diff --git a/README.md b/README.md index 3fe6cc745..d62845d4c 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ > :warning: **py4DSTEM version 0.14 update** :warning: Warning: this is a major update and we expect some workflows to break. You can still install previous versions of py4DSTEM [as discussed here](#legacyinstall) -> :warning: **Phase retrieval refactor version 0.14.9** :warning: Warning: The phase-retrieval modules in py4DSTEM (DPC, parallax, and ptychography) underwent a major refactor in version 0.14.9 and as such older tutorial notebooks will not work as expected. Notably, class names have been pruned to remove the trailing "Reconstruction" (`DPCReconstruction` -> `DPC` etc.), and regularization functions have dropped the `_iter` suffix (and are instead specified as boolean flags). We are working on updating the tutorial notebooks to reflect these changes. In the meantime, there's some more information in the relevant pull request [here](https://github.com/py4dstem/py4DSTEM/pull/597#issuecomment-1890325568). +> :warning: **Phase retrieval refactor version 0.14.9** :warning: Warning: The phase-retrieval modules in py4DSTEM (DPC, parallax, and ptychography) underwent a major refactor in version 0.14.9 and as such older tutorial notebooks will not work as expected. Notably, class names have been pruned to remove the trailing "Reconstruction" (`DPCReconstruction` -> `DPC` etc.), and regularization functions have dropped the `_iter` suffix (and are instead specified as boolean flags). See the [updated tutorials](https://github.com/py4dstem/py4DSTEM_tutorials) for more information. ![py4DSTEM logo](/images/py4DSTEM_logo.png) diff --git a/py4DSTEM/__init__.py b/py4DSTEM/__init__.py index d5df63f5e..0f4490eb2 100644 --- a/py4DSTEM/__init__.py +++ b/py4DSTEM/__init__.py @@ -1,6 +1,9 @@ from py4DSTEM.version import __version__ from emdfile import tqdmnd +from importlib.metadata import packages_distributions + +is_package_lite = "py4DSTEM-lite" in packages_distributions()["py4DSTEM"] ### io @@ -52,8 +55,11 @@ BraggVectorMap, ) -from py4DSTEM.process import classification - +try: + from py4DSTEM.process import classification +except (ImportError, ModuleNotFoundError) as exc: + if not is_package_lite: + raise exc # diffraction from py4DSTEM.process.diffraction import Crystal, Orientation @@ -70,7 +76,11 @@ # strain from py4DSTEM.process.strain.strain import StrainMap -from py4DSTEM.process import wholepatternfit +try: + from py4DSTEM.process import wholepatternfit +except (ImportError, ModuleNotFoundError) as exc: + if not is_package_lite: + raise exc ### more submodules diff --git a/py4DSTEM/braggvectors/diskdetection.py b/py4DSTEM/braggvectors/diskdetection.py index 99818b75e..500dbd2e9 100644 --- a/py4DSTEM/braggvectors/diskdetection.py +++ b/py4DSTEM/braggvectors/diskdetection.py @@ -5,12 +5,18 @@ from scipy.ndimage import gaussian_filter from emdfile import tqdmnd +from py4DSTEM import is_package_lite from py4DSTEM.braggvectors.braggvectors import BraggVectors from py4DSTEM.data import QPoints from py4DSTEM.datacube import DataCube from py4DSTEM.preprocess.utils import get_maxima_2D from py4DSTEM.process.utils.cross_correlate import get_cross_correlation_FT -from py4DSTEM.braggvectors.diskdetection_aiml import find_Bragg_disks_aiml + +try: + from py4DSTEM.braggvectors.diskdetection_aiml import find_Bragg_disks_aiml +except (ImportError, ModuleNotFoundError) as exc: + if not is_package_lite: + raise exc def find_Bragg_disks( diff --git a/py4DSTEM/datacube/virtualimage.py b/py4DSTEM/datacube/virtualimage.py index 627223d23..d4fe15241 100644 --- a/py4DSTEM/datacube/virtualimage.py +++ b/py4DSTEM/datacube/virtualimage.py @@ -5,7 +5,6 @@ # for bragg virtual imaging methods, goto diskdetection.virtualimage.py import numpy as np -import dask.array as da from typing import Optional import inspect @@ -220,7 +219,9 @@ def get_virtual_image( virtual_image[rx, ry] = np.sum(self.data[rx, ry] * mask) # dask - if dask is True: + if dask: + import dask.array as da + # set up a generalized universal function for dask distribution def _apply_mask_dask(self, mask): virtual_image = np.sum( diff --git a/py4DSTEM/io/filereaders/__init__.py b/py4DSTEM/io/filereaders/__init__.py index b6f4eb0a2..4b89175a7 100644 --- a/py4DSTEM/io/filereaders/__init__.py +++ b/py4DSTEM/io/filereaders/__init__.py @@ -1,6 +1,12 @@ +from py4DSTEM import is_package_lite +from py4DSTEM.io.filereaders.empad import read_empad from py4DSTEM.io.filereaders.read_dm import read_dm from py4DSTEM.io.filereaders.read_K2 import read_gatan_K2_bin -from py4DSTEM.io.filereaders.empad import read_empad from py4DSTEM.io.filereaders.read_mib import load_mib -from py4DSTEM.io.filereaders.read_arina import read_arina + +try: + from py4DSTEM.io.filereaders.read_arina import read_arina +except (ImportError, ModuleNotFoundError) as exc: + if not is_package_lite: + raise exc from py4DSTEM.io.filereaders.read_abTEM import read_abTEM diff --git a/py4DSTEM/io/importfile.py b/py4DSTEM/io/importfile.py index ff3d1c37c..b3002c77e 100644 --- a/py4DSTEM/io/importfile.py +++ b/py4DSTEM/io/importfile.py @@ -7,7 +7,6 @@ from py4DSTEM.io.filereaders import ( load_mib, read_abTEM, - read_arina, read_dm, read_empad, read_gatan_K2_bin, @@ -90,6 +89,8 @@ def import_file( elif filetype == "mib": data = load_mib(filepath, mem=mem, binfactor=binfactor, **kwargs) elif filetype == "arina": + from py4DSTEM.io.filereaders import read_arina + data = read_arina(filepath, mem=mem, binfactor=binfactor, **kwargs) elif filetype == "abTEM": data = read_abTEM(filepath, mem=mem, binfactor=binfactor, **kwargs) diff --git a/py4DSTEM/process/__init__.py b/py4DSTEM/process/__init__.py index 0509d181e..6d7d36b28 100644 --- a/py4DSTEM/process/__init__.py +++ b/py4DSTEM/process/__init__.py @@ -1,9 +1,21 @@ +from py4DSTEM import is_package_lite from py4DSTEM.process.polar import PolarDatacube from py4DSTEM.process.strain.strain import StrainMap from py4DSTEM.process import phase from py4DSTEM.process import calibration from py4DSTEM.process import utils -from py4DSTEM.process import classification + +try: + from py4DSTEM.process import classification +except (ImportError, ModuleNotFoundError) as exc: + if not is_package_lite: + raise exc + from py4DSTEM.process import diffraction -from py4DSTEM.process import wholepatternfit + +try: + from py4DSTEM.process import wholepatternfit +except (ImportError, ModuleNotFoundError) as exc: + if not is_package_lite: + raise exc diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index e0fe59eee..33e6b07b8 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -710,41 +710,69 @@ def generate_diffraction_pattern( zone_axis_cartesian: Optional[np.ndarray] = None, proj_x_cartesian: Optional[np.ndarray] = None, foil_normal_cartesian: Optional[Union[list, tuple, np.ndarray]] = None, - sigma_excitation_error: float = 0.02, + sigma_excitation_error: float = 0.01, tol_excitation_error_mult: float = 3, - tol_intensity: float = 1e-4, + tol_intensity: float = 1e-5, k_max: Optional[float] = None, + precession_angle_degrees=None, keep_qz=False, return_orientation_matrix=False, ): """ - Generate a single diffraction pattern, return all peaks as a pointlist. + Generate a single diffraction pattern, return all peaks as a pointlist. This function performs a + kinematical calculation, with optional precession of the beam. - Args: - orientation (Orientation): an Orientation class object - ind_orientation If input is an Orientation class object with multiple orientations, - this input can be used to select a specific orientation. - - orientation_matrix (array): (3,3) orientation matrix, where columns represent projection directions. - zone_axis_lattice (array): (3,) projection direction in lattice indices - proj_x_lattice (array): (3,) x-axis direction in lattice indices - zone_axis_cartesian (array): (3,) cartesian projection direction - proj_x_cartesian (array): (3,) cartesian projection direction - - foil_normal: 3 element foil normal - set to None to use zone_axis - proj_x_axis (np float vector): 3 element vector defining image x axis (vertical) - accel_voltage (float): Accelerating voltage in Volts. If not specified, - we check to see if crystal already has voltage specified. - sigma_excitation_error (float): sigma value for envelope applied to s_g (excitation errors) in units of inverse Angstroms - tol_excitation_error_mult (float): tolerance in units of sigma for s_g inclusion - tol_intensity (np float): tolerance in intensity units for inclusion of diffraction spots - k_max (float): Maximum scattering vector - keep_qz (bool): Flag to return out-of-plane diffraction vectors - return_orientation_matrix (bool): Return the orientation matrix + TODO - switch from numerical precession to analytic (requires geometry projection). + TODO - verify projection geometry for 2D material diffraction. + + Parameters + ---------- + + orientation (Orientation) + an Orientation class object + ind_orientation + If input is an Orientation class object with multiple orientations, + this input can be used to select a specific orientation. + + orientation_matrix (3,3) numpy.array + orientation matrix, where columns represent projection directions. + zone_axis_lattice (3,) numpy.array + projection direction in lattice indices + proj_x_lattice (3,) numpy.array + x-axis direction in lattice indices + zone_axis_cartesian (3,) numpy.array + cartesian projection direction + proj_x_cartesian (3,) numpy.array + cartesian projection direction + foil_normal + 3 element foil normal - set to None to use zone_axis + proj_x_axis (3,) numpy.array + 3 element vector defining image x axis (vertical) + accel_voltage (float) + Accelerating voltage in Volts. If not specified, + we check to see if crystal already has voltage specified. + sigma_excitation_error (float) + sigma value for envelope applied to s_g (excitation errors) in units of inverse Angstroms + tol_excitation_error_mult (float) + tolerance in units of sigma for s_g inclusion + tol_intensity (numpy float) + tolerance in intensity units for inclusion of diffraction spots + k_max (float) + Maximum scattering vector + precession_angle_degrees (float) + Precession angle for library calculation. Set to None for no precession. + keep_qz (bool) + Flag to return out-of-plane diffraction vectors + return_orientation_matrix (bool) + Return the orientation matrix + + Returns + ---------- + bragg_peaks (PointList) + list of all Bragg peaks with fields [qx, qy, intensity, h, k, l] + orientation_matrix (array, optional) + 3x3 orientation matrix - Returns: - bragg_peaks (PointList): list of all Bragg peaks with fields [qx, qy, intensity, h, k, l] - orientation_matrix (array): 3x3 orientation matrix (optional) """ if not (hasattr(self, "wavelength") and hasattr(self, "accel_voltage")): @@ -779,17 +807,27 @@ def generate_diffraction_pattern( # Calculate excitation errors if foil_normal is None: - sg = self.excitation_errors(g) + sg = self.excitation_errors( + g, + precession_angle_degrees=precession_angle_degrees, + ) else: foil_normal = ( orientation_matrix.T @ (-1 * foil_normal[:, None] / np.linalg.norm(foil_normal)) ).ravel() - sg = self.excitation_errors(g, foil_normal) + sg = self.excitation_errors( + g, + foil_normal=foil_normal, + precession_angle_degrees=precession_angle_degrees, + ) # Threshold for inclusion in diffraction pattern sg_max = sigma_excitation_error * tol_excitation_error_mult - keep = np.abs(sg) <= sg_max + if precession_angle_degrees is None: + keep = np.abs(sg) <= sg_max + else: + keep = np.min(np.abs(sg), axis=1) <= sg_max # Maximum scattering angle cutoff if k_max is not None: @@ -799,9 +837,15 @@ def generate_diffraction_pattern( g_diff = g[:, keep] # Diffracted peak intensities and labels - g_int = self.struct_factors_int[keep] * np.exp( - (sg[keep] ** 2) / (-2 * sigma_excitation_error**2) - ) + if precession_angle_degrees is None: + g_int = self.struct_factors_int[keep] * np.exp( + (sg[keep] ** 2) / (-2 * sigma_excitation_error**2) + ) + else: + g_int = self.struct_factors_int[keep] * np.mean( + np.exp((sg[keep] ** 2) / (-2 * sigma_excitation_error**2)), + axis=1, + ) hkl = self.hkl[:, keep] # Intensity tolerance @@ -975,6 +1019,7 @@ def generate_projected_potential( potential_radius_angstroms=3.0, sigma_image_blur_angstroms=0.1, thickness_angstroms=100, + max_num_proj=200, power_scale=1.0, plot_result=False, figsize=(6, 6), @@ -989,9 +1034,6 @@ def generate_projected_potential( """ Generate an image of the projected potential of crystal in real space, using cell tiling, and a lookup table of the atomic potentials. - Note that we round atomic positions to the nearest pixel for speed. - - TODO - fix scattering prefactor so that output units are sensible. Parameters ---------- @@ -1006,6 +1048,9 @@ def generate_projected_potential( thickness_angstroms: float Thickness of the sample in Angstroms. Set thickness_thickness_angstroms = 0 to skip thickness projection. + max_num_proj: int + This value prevents this function from projecting a large number of unit + cells along the beam direction, which could be potentially quite slow. power_scale: float Power law scaling of potentials. Set to 2.0 to approximate Z^2 images. plot_result: bool @@ -1054,52 +1099,22 @@ def generate_projected_potential( # Rotate unit cell into projection direction lat_real = self.lat_real.copy() @ orientation_matrix - # Determine unit cell axes to tile over, by selecting 2/3 with largest in-plane component - inds_tile = np.argsort(np.linalg.norm(lat_real[:, 0:2], axis=1))[1:3] + # Determine unit cell axes to tile over, by selecting 2/3 with smallest out-of-plane component + inds_tile = np.argsort(np.abs(lat_real[:, 2]))[0:2] m_tile = lat_real[inds_tile, :] + # Vector projected along optic axis m_proj = np.squeeze(np.delete(lat_real, inds_tile, axis=0)) - # Thickness - if thickness_angstroms > 0: - num_proj = np.round(thickness_angstroms / np.abs(m_proj[2])).astype("int") - if num_proj > 1: - vec_proj = m_proj[:2] / pixel_size_angstroms - shifts = np.arange(num_proj).astype("float") - shifts -= np.mean(shifts) - x_proj = shifts * vec_proj[0] - y_proj = shifts * vec_proj[1] - else: - num_proj = 1 - else: - num_proj = 1 - # Determine tiling range - if thickness_angstroms > 0: - # include the cell height - dz = m_proj[2] * num_proj * 0.5 - p_corners = np.array( - [ - [-im_size_Ang[0] * 0.5, -im_size_Ang[1] * 0.5, dz], - [im_size_Ang[0] * 0.5, -im_size_Ang[1] * 0.5, dz], - [-im_size_Ang[0] * 0.5, im_size_Ang[1] * 0.5, dz], - [im_size_Ang[0] * 0.5, im_size_Ang[1] * 0.5, dz], - [-im_size_Ang[0] * 0.5, -im_size_Ang[1] * 0.5, -dz], - [im_size_Ang[0] * 0.5, -im_size_Ang[1] * 0.5, -dz], - [-im_size_Ang[0] * 0.5, im_size_Ang[1] * 0.5, -dz], - [im_size_Ang[0] * 0.5, im_size_Ang[1] * 0.5, -dz], - ] - ) - else: - p_corners = np.array( - [ - [-im_size_Ang[0] * 0.5, -im_size_Ang[1] * 0.5, 0.0], - [im_size_Ang[0] * 0.5, -im_size_Ang[1] * 0.5, 0.0], - [-im_size_Ang[0] * 0.5, im_size_Ang[1] * 0.5, 0.0], - [im_size_Ang[0] * 0.5, im_size_Ang[1] * 0.5, 0.0], - ] - ) - + p_corners = np.array( + [ + [-im_size_Ang[0] * 0.5, -im_size_Ang[1] * 0.5, 0.0], + [im_size_Ang[0] * 0.5, -im_size_Ang[1] * 0.5, 0.0], + [-im_size_Ang[0] * 0.5, im_size_Ang[1] * 0.5, 0.0], + [im_size_Ang[0] * 0.5, im_size_Ang[1] * 0.5, 0.0], + ] + ) ab = np.linalg.lstsq(m_tile[:, :2].T, p_corners[:, :2].T, rcond=None)[0] ab = np.floor(ab) a_range = np.array((np.min(ab[0]) - 1, np.max(ab[0]) + 2)) @@ -1115,31 +1130,17 @@ def generate_projected_potential( abc_atoms[:, inds_tile[0]] += a_ind.ravel() abc_atoms[:, inds_tile[1]] += b_ind.ravel() xyz_atoms_ang = abc_atoms @ lat_real - atoms_ID_all_0 = self.numbers[atoms_ind.ravel()] + atoms_ID_all = self.numbers[atoms_ind.ravel()] # Center atoms on image plane - x0 = xyz_atoms_ang[:, 0] / pixel_size_angstroms + im_size[0] / 2.0 - y0 = xyz_atoms_ang[:, 1] / pixel_size_angstroms + im_size[1] / 2.0 - - # if needed, tile atoms in the projection direction - if num_proj > 1: - x = (x0[:, None] + x_proj[None, :]).ravel() - y = (y0[:, None] + y_proj[None, :]).ravel() - atoms_ID_all = np.tile(atoms_ID_all_0, (num_proj, 1)) - else: - x = x0 - y = y0 - atoms_ID_all = atoms_ID_all_0 - # print(x.shape, y.shape) - - # delete atoms outside the field of view - bound = potential_radius_angstroms / pixel_size_angstroms + x = xyz_atoms_ang[:, 0] / pixel_size_angstroms + im_size[0] / 2.0 + y = xyz_atoms_ang[:, 1] / pixel_size_angstroms + im_size[1] / 2.0 atoms_del = np.logical_or.reduce( ( - x <= -bound, - y <= -bound, - x >= im_size[0] + bound, - y >= im_size[1] + bound, + x <= -potential_radius_angstroms / 2, + y <= -potential_radius_angstroms / 2, + x >= im_size[0] + potential_radius_angstroms / 2, + y >= im_size[1] + potential_radius_angstroms / 2, ) ) x = np.delete(x, atoms_del) @@ -1164,16 +1165,18 @@ def generate_projected_potential( for a0 in range(atoms_ID.shape[0]): atom_sf = single_atom_scatter([atoms_ID[a0]]) atoms_lookup[a0, :, :] = atom_sf.projected_potential(atoms_ID[a0], R_2D) - - # if needed, apply gaussian blurring to each atom - if sigma_image_blur_angstroms > 0: - atoms_lookup[a0, :, :] = gaussian_filter( - atoms_lookup[a0, :, :], - sigma_image_blur_angstroms / pixel_size_angstroms, - mode="nearest", - ) atoms_lookup **= power_scale + # Thickness + if thickness_angstroms > 0: + thickness_proj = thickness_angstroms / m_proj[2] + vec_proj = thickness_proj / pixel_size_angstroms * m_proj[:2] + num_proj = (np.ceil(np.linalg.norm(vec_proj)) + 1).astype("int") + num_proj = np.minimum(num_proj, max_num_proj) + + x_proj = np.linspace(-0.5, 0.5, num_proj) * vec_proj[0] + y_proj = np.linspace(-0.5, 0.5, num_proj) * vec_proj[1] + # initialize potential im_potential = np.zeros(im_size) @@ -1181,41 +1184,68 @@ def generate_projected_potential( for a0 in range(atoms_ID_all.shape[0]): ind = np.argmin(np.abs(atoms_ID - atoms_ID_all[a0])) - x_ind = np.round(x[a0]).astype("int") + R_ind - y_ind = np.round(y[a0]).astype("int") + R_ind - x_sub = np.logical_and( - x_ind >= 0, - x_ind < im_size[0], - ) - y_sub = np.logical_and( - y_ind >= 0, - y_ind < im_size[1], - ) - im_potential[x_ind[x_sub][:, None], y_ind[y_sub][None, :]] += atoms_lookup[ - ind - ][x_sub][:, y_sub] + if thickness_angstroms > 0: + for a1 in range(num_proj): + x_ind = np.round(x[a0] + x_proj[a1]).astype("int") + R_ind + y_ind = np.round(y[a0] + y_proj[a1]).astype("int") + R_ind + x_sub = np.logical_and( + x_ind >= 0, + x_ind < im_size[0], + ) + y_sub = np.logical_and( + y_ind >= 0, + y_ind < im_size[1], + ) + + im_potential[ + x_ind[x_sub][:, None], y_ind[y_sub][None, :] + ] += atoms_lookup[ind][x_sub, :][:, y_sub] + + else: + x_ind = np.round(x[a0]).astype("int") + R_ind + y_ind = np.round(y[a0]).astype("int") + R_ind + x_sub = np.logical_and( + x_ind >= 0, + x_ind < im_size[0], + ) + y_sub = np.logical_and( + y_ind >= 0, + y_ind < im_size[1], + ) + + im_potential[ + x_ind[x_sub][:, None], y_ind[y_sub][None, :] + ] += atoms_lookup[ind][x_sub, :][:, y_sub] if thickness_angstroms > 0: im_potential /= num_proj + # if needed, apply gaussian blurring + if sigma_image_blur_angstroms > 0: + sigma_image_blur = sigma_image_blur_angstroms / pixel_size_angstroms + im_potential = gaussian_filter( + im_potential, + sigma_image_blur, + mode="nearest", + ) + if plot_result: # quick plotting of the result int_vals = np.sort(im_potential.ravel()) int_range = np.array( ( int_vals[np.round(0.02 * int_vals.size).astype("int")], - int_vals[np.round(0.999 * int_vals.size).astype("int")], + int_vals[np.round(0.98 * int_vals.size).astype("int")], ) ) fig, ax = plt.subplots(figsize=figsize) ax.imshow( im_potential, - cmap="gray", + cmap="turbo", vmin=int_range[0], vmax=int_range[1], ) - # ax.scatter(y,x,c='r') # for testing ax.set_axis_off() ax.set_aspect("equal") @@ -1323,20 +1353,55 @@ def excitation_errors( self, g, foil_normal=None, + precession_angle_degrees=None, + precession_steps=72, ): """ Calculate the excitation errors, assuming k0 = [0, 0, -1/lambda]. If foil normal is not specified, we assume it is [0,0,-1]. + + Precession is currently implemented using numerical integration. """ - if foil_normal is None: - return (2 * g[2, :] - self.wavelength * np.sum(g * g, axis=0)) / ( - 2 - 2 * self.wavelength * g[2, :] - ) + + if precession_angle_degrees is None: + if foil_normal is None: + return (2 * g[2, :] - self.wavelength * np.sum(g * g, axis=0)) / ( + 2 - 2 * self.wavelength * g[2, :] + ) + else: + return (2 * g[2, :] - self.wavelength * np.sum(g * g, axis=0)) / ( + 2 * self.wavelength * np.sum(g * foil_normal[:, None], axis=0) + - 2 * foil_normal[2] + ) + else: - return (2 * g[2, :] - self.wavelength * np.sum(g * g, axis=0)) / ( - 2 * self.wavelength * np.sum(g * foil_normal[:, None], axis=0) - - 2 * foil_normal[2] + t = np.deg2rad(precession_angle_degrees) + p = np.linspace( + 0, + 2.0 * np.pi, + precession_steps, + endpoint=False, ) + if foil_normal is None: + foil_normal = np.array((0.0, 0.0, -1.0)) + + k = np.reshape( + (-1 / self.wavelength) + * np.vstack( + ( + np.sin(t) * np.cos(p), + np.sin(t) * np.sin(p), + np.cos(t) * np.ones(p.size), + ) + ), + (3, 1, p.size), + ) + + term1 = np.sum((g[:, :, None] + k) * foil_normal[:, None, None], axis=0) + term2 = np.sum((g[:, :, None] + 2 * k) * g[:, :, None], axis=0) + sg = np.sqrt(term1**2 - term2) - term1 + + return sg def calculate_bragg_peak_histogram( self, diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index fc6e691c3..f7a4a20db 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -27,8 +27,11 @@ def orientation_plan( angle_step_in_plane: float = 2.0, accel_voltage: float = 300e3, corr_kernel_size: float = 0.08, - radial_power: float = 1.0, - intensity_power: float = 0.25, # New default intensity power scaling + sigma_excitation_error: float = 0.02, + precession_angle_degrees=None, + power_radial: float = 1.0, + power_intensity: float = 0.25, + power_intensity_experiment: float = 0.25, calculate_correlation_array=True, tol_peak_delete=None, tol_distance: float = 0.01, @@ -41,39 +44,66 @@ def orientation_plan( """ Calculate the rotation basis arrays for an SO(3) rotation correlogram. - Args: - zone_axis_range (float): Row vectors give the range for zone axis orientations. - If user specifies 2 vectors (2x3 array), we start at [0,0,1] - to make z-x-z rotation work. - If user specifies 3 vectors (3x3 array), plan will span these vectors. - Setting to 'full' as a string will use a hemispherical range. - Setting to 'half' as a string will use a quarter sphere range. - Setting to 'fiber' as a string will make a spherical cap around a given vector. - Setting to 'auto' will use pymatgen to determine the point group symmetry - of the structure and choose an appropriate zone_axis_range - angle_step_zone_axis (float): Approximate angular step size for zone axis search [degrees] - angle_coarse_zone_axis (float): Coarse step size for zone axis search [degrees]. Setting to - None uses the same value as angle_step_zone_axis. - angle_refine_range (float): Range of angles to use for zone axis refinement. Setting to - None uses same value as angle_coarse_zone_axis. - - angle_step_in_plane (float): Approximate angular step size for in-plane rotation [degrees] - accel_voltage (float): Accelerating voltage for electrons [Volts] - corr_kernel_size (float): Correlation kernel size length. The size of the overlap kernel between the measured Bragg peaks and diffraction library Bragg peaks. [1/Angstroms] - radial_power (float): Power for scaling the correlation intensity as a function of the peak radius - intensity_power (float): Power for scaling the correlation intensity as a function of the peak intensity - calculate_correlation_array (bool): Set to false to skip calculating the correlation array. - This is useful when we only want the angular range / rotation matrices. - tol_peak_delete (float): Distance to delete peaks for multiple matches. - Default is kernel_size * 0.5 - tol_distance (float): Distance tolerance for radial shell assignment [1/Angstroms] - fiber_axis (float): (3,) vector specifying the fiber axis - fiber_angles (float): (2,) vector specifying angle range from fiber axis, and in-plane angular range [degrees] - cartesian_directions (bool): When set to true, all zone axes and projection directions - are specified in Cartesian directions. - figsize (float): (2,) vector giving the figure size - CUDA (bool): Use CUDA for the Fourier operations. - progress_bar (bool): If false no progress bar is displayed + Parameters + ---------- + zone_axis_range (float): + Row vectors give the range for zone axis orientations. + If user specifies 2 vectors (2x3 array), we start at [0,0,1] + to make z-x-z rotation work. + If user specifies 3 vectors (3x3 array), plan will span these vectors. + Setting to 'full' as a string will use a hemispherical range. + Setting to 'half' as a string will use a quarter sphere range. + Setting to 'fiber' as a string will make a spherical cap around a given vector. + Setting to 'auto' will use pymatgen to determine the point group symmetry + of the structure and choose an appropriate zone_axis_range + angle_step_zone_axis (float): + Approximate angular step size for zone axis search [degrees] + angle_coarse_zone_axis (float): + Coarse step size for zone axis search [degrees]. Setting to + None uses the same value as angle_step_zone_axis. + angle_refine_range (float): + Range of angles to use for zone axis refinement. Setting to + None uses same value as angle_coarse_zone_axis. + + angle_step_in_plane (float): + Approximate angular step size for in-plane rotation [degrees] + accel_voltage (float): + Accelerating voltage for electrons [Volts] + corr_kernel_size (float): + Correlation kernel size length. The size of the overlap kernel between the + measured Bragg peaks and diffraction library Bragg peaks. [1/Angstroms] + sigma_excitation_error (float): + The out of plane excitation error tolerance. [1/Angstroms] + precession_angle_degrees (float) + Tilt angle of illuminaiton cone in degrees for precession electron diffraction (PED). + + power_radial (float): + Power for scaling the correlation intensity as a function of the peak radius + power_intensity (float): + Power for scaling the correlation intensity as a function of simulated peak intensity + power_intensity_experiment (float): + Power for scaling the correlation intensity as a function of experimental peak intensity + calculate_correlation_array (bool): + Set to false to skip calculating the correlation array. + This is useful when we only want the angular range / rotation matrices. + tol_peak_delete (float): + Distance to delete peaks for multiple matches. + Default is kernel_size * 0.5 + tol_distance (float): + Distance tolerance for radial shell assignment [1/Angstroms] + fiber_axis (float): + (3,) vector specifying the fiber axis + fiber_angles (float): + (2,) vector specifying angle range from fiber axis, and in-plane angular range [degrees] + cartesian_directions (bool): + When set to true, all zone axes and projection directions + are specified in Cartesian directions. + figsize (float): + (2,) vector giving the figure size + CUDA (bool): + Use CUDA for the Fourier operations. + progress_bar (bool): + If false no progress bar is displayed, """ # Check to make sure user has calculated the structure factors if needed @@ -86,6 +116,14 @@ def orientation_plan( # Store inputs self.accel_voltage = np.asarray(accel_voltage) self.orientation_kernel_size = np.asarray(corr_kernel_size) + self.orientation_sigma_excitation_error = sigma_excitation_error + if precession_angle_degrees is None: + self.orientation_precession_angle_degrees = None + else: + self.orientation_precession_angle_degrees = np.asarray(precession_angle_degrees) + self.orientation_precession_angle = np.deg2rad( + np.asarray(precession_angle_degrees) + ) if tol_peak_delete is None: self.orientation_tol_peak_delete = self.orientation_kernel_size * 0.5 else: @@ -104,8 +142,9 @@ def orientation_plan( self.wavelength = electron_wavelength_angstrom(self.accel_voltage) # store the radial and intensity scaling to use later for generating test patterns - self.orientation_radial_power = radial_power - self.orientation_intensity_power = intensity_power + self.orientation_power_radial = power_radial + self.orientation_power_intensity = power_intensity + self.orientation_power_intensity_experiment = power_intensity_experiment # Calculate the ratio between coarse and fine refinement if angle_coarse_zone_axis is not None: @@ -717,45 +756,100 @@ def orientation_plan( ): # reciprocal lattice spots and excitation errors g = self.orientation_rotation_matrices[a0, :, :].T @ self.g_vec_all + # if precession_angle_degrees is None: sg = self.excitation_errors(g) + # else: + # sg = np.min( + # np.abs( + # self.excitation_errors( + # g, + # precession_angle_degrees = precession_angle_degrees, + # ), + # ), + # axis = 1, + # ) # Keep only points that will contribute to this orientation plan slice - keep = np.abs(sg) < self.orientation_kernel_size + keep = np.logical_and( + np.abs(sg) < self.orientation_kernel_size, + self.orientation_shell_index >= 0, + ) + + # calculate intensity of spots + if precession_angle_degrees is None: + Ig = np.exp(sg[keep] ** 2 / (-2 * sigma_excitation_error**2)) + else: + # precession extension + prec = np.cos(np.linspace(0, 2 * np.pi, 90, endpoint=False)) + dsg = np.tan(self.orientation_precession_angle) * np.sum( + g[:2, keep] ** 2, axis=0 + ) + Ig = np.mean( + np.exp( + (sg[keep, None] + dsg[:, None] * prec[None, :]) ** 2 + / (-2 * sigma_excitation_error**2) + ), + axis=1, + ) # in-plane rotation angle - phi = np.arctan2(g[1, :], g[0, :]) - - # Loop over all peaks - for a1 in np.arange(self.g_vec_all.shape[1]): - ind_radial = self.orientation_shell_index[a1] - - if keep[a1] and ind_radial >= 0: - # 2D orientation plan - self.orientation_ref[a0, ind_radial, :] += ( - np.power(self.orientation_shell_radii[ind_radial], radial_power) - * np.power(self.struct_factors_int[a1], intensity_power) - * np.maximum( - 1 - - np.sqrt( - sg[a1] ** 2 - + ( - ( - np.mod( - self.orientation_gamma - phi[a1] + np.pi, - 2 * np.pi, - ) - - np.pi - ) - * self.orientation_shell_radii[ind_radial] - ) - ** 2 - ) - / self.orientation_kernel_size, - 0, - ) - ) + phi = np.arctan2(g[1, keep], g[0, keep]) + phi_ind = phi / self.orientation_gamma[1] # step size of annular bins + phi_floor = np.floor(phi_ind).astype("int") + dphi = phi_ind - phi_floor + + # write intensities into orientation plan slice + radial_inds = self.orientation_shell_index[keep] + self.orientation_ref[a0, radial_inds, phi_floor] += ( + (1 - dphi) + * np.power(self.struct_factors_int[keep] * Ig, power_intensity) + * np.power(self.orientation_shell_radii[radial_inds], power_radial) + ) + self.orientation_ref[ + a0, radial_inds, np.mod(phi_floor + 1, self.orientation_in_plane_steps) + ] += ( + dphi + * np.power(self.struct_factors_int[keep] * Ig, power_intensity) + * np.power(self.orientation_shell_radii[radial_inds], power_radial) + ) + # # Loop over all peaks + # for a1 in np.arange(self.g_vec_all.shape[1]): + # if keep[a1]: + + # for a1 in np.arange(self.g_vec_all.shape[1]): + # ind_radial = self.orientation_shell_index[a1] + + # if keep[a1] and ind_radial >= 0: + # # 2D orientation plan + # self.orientation_ref[a0, ind_radial, :] += ( + # np.power(self.orientation_shell_radii[ind_radial], power_radial) + # * np.power(self.struct_factors_int[a1], power_intensity) + # * np.maximum( + # 1 + # - np.sqrt( + # sg[a1] ** 2 + # + ( + # ( + # np.mod( + # self.orientation_gamma - phi[a1] + np.pi, + # 2 * np.pi, + # ) + # - np.pi + # ) + # * self.orientation_shell_radii[ind_radial] + # ) + # ** 2 + # ) + # / self.orientation_kernel_size, + # 0, + # ) + # ) + + # normalization + # self.orientation_ref[a0, :, :] -= np.mean(self.orientation_ref[a0, :, :]) orientation_ref_norm = np.sqrt(np.sum(self.orientation_ref[a0, :, :] ** 2)) + # orientation_ref_norm = np.sum(self.orientation_ref[a0, :, :]) if orientation_ref_norm > 0: self.orientation_ref[a0, :, :] /= orientation_ref_norm @@ -959,14 +1053,12 @@ def match_single_pattern( if np.any(sub): im_polar[ind_radial, :] = np.sum( - np.power(radius, self.orientation_radial_power) - * np.power( + np.power( np.maximum(intensity[sub, None], 0.0), - self.orientation_intensity_power, + self.orientation_power_intensity_experiment, ) - * np.maximum( - 1 - - np.sqrt( + * np.exp( + ( dqr[sub, None] ** 2 + ( ( @@ -982,12 +1074,73 @@ def match_single_pattern( ) ** 2 ) - / self.orientation_kernel_size, - 0, + / (-2 * self.orientation_kernel_size**2) ), axis=0, ) + # im_polar[ind_radial, :] = np.sum( + # np.power( + # np.maximum(intensity[sub, None], 0.0), + # self.orientation_power_intensity_experiment, + # ) + # * np.maximum( + # 1 + # - np.sqrt( + # dqr[sub, None] ** 2 + # + ( + # ( + # np.mod( + # self.orientation_gamma[None, :] + # - qphi[sub, None] + # + np.pi, + # 2 * np.pi, + # ) + # - np.pi + # ) + # * radius + # ) + # ** 2 + # ) + # / self.orientation_kernel_size, + # 0, + # ), + # axis=0, + # ) + + # im_polar[ind_radial, :] = np.sum( + # np.power(radius, self.orientation_power_radial) + # * np.power( + # np.maximum(intensity[sub, None], 0.0), + # self.orientation_power_intensity, + # ) + # * np.maximum( + # 1 + # - np.sqrt( + # dqr[sub, None] ** 2 + # + ( + # ( + # np.mod( + # self.orientation_gamma[None, :] + # - qphi[sub, None] + # + np.pi, + # 2 * np.pi, + # ) + # - np.pi + # ) + # * radius + # ) + # ** 2 + # ) + # / self.orientation_kernel_size, + # 0, + # ), + # axis=0, + # ) + + # normalization + # im_polar -= np.mean(im_polar) + # Determine the RMS signal from im_polar for the first match. # Note that we use scaling slightly below RMS so that following matches # don't have higher correlating scores than previous matches. @@ -1186,6 +1339,7 @@ def match_single_pattern( np.clip( np.sum( self.orientation_vecs + # self.orientation_vecs * np.array([1,-1,-1])[None,:] * self.orientation_vecs[inds_previous[a0], :], axis=1, ), @@ -1484,7 +1638,8 @@ def match_single_pattern( bragg_peaks_fit = self.generate_diffraction_pattern( orientation, ind_orientation=match_ind, - sigma_excitation_error=self.orientation_kernel_size, + sigma_excitation_error=self.orientation_sigma_excitation_error, + precession_angle_degrees=self.orientation_precession_angle_degrees, ) remove = np.zeros_like(qx, dtype="bool") diff --git a/py4DSTEM/process/diffraction/crystal_phase.py b/py4DSTEM/process/diffraction/crystal_phase.py index d28616aa9..c161304f9 100644 --- a/py4DSTEM/process/diffraction/crystal_phase.py +++ b/py4DSTEM/process/diffraction/crystal_phase.py @@ -3,13 +3,14 @@ from scipy.optimize import nnls import matplotlib as mpl import matplotlib.pyplot as plt +from scipy.ndimage import gaussian_filter from emdfile import tqdmnd, PointListArray from py4DSTEM.visualize import show, show_image_grid from py4DSTEM.process.diffraction.crystal_viz import plot_diffraction_pattern -class Crystal_Phase: +class CrystalPhase: """ A class storing multiple crystal structures, and associated diffraction data. Must be initialized after matching orientations to a pointlistarray??? @@ -19,8 +20,9 @@ class Crystal_Phase: def __init__( self, crystals, - orientation_maps, - name, + crystal_names=None, + orientation_maps=None, + name=None, ): """ Args: @@ -33,306 +35,1368 @@ def __init__( self.num_crystals = len(crystals) else: raise TypeError("crystals must be a list of crystal instances.") - if isinstance(orientation_maps, list): + + # List of orientation maps + if orientation_maps is None: + self.orientation_maps = [ + crystals[ind].orientation_map for ind in range(self.num_crystals) + ] + else: if len(self.crystals) != len(orientation_maps): raise ValueError( "Orientation maps must have the same number of entries as crystals." ) self.orientation_maps = orientation_maps + + # Names of all crystal phases + if crystal_names is None: + self.crystal_names = [ + "crystal" + str(ind) for ind in range(self.num_crystals) + ] + else: + self.crystal_names = crystal_names + + # Name of the phase map + if name is None: + self.name = "phase map" else: - raise TypeError("orientation_maps must be a list of orientation maps.") - self.name = name - return + self.name = name - def plot_all_phase_maps(self, map_scale_values=None, index=0): + # Get some attributes from crystals + self.k_max = np.zeros(self.num_crystals) + self.num_matches = np.zeros(self.num_crystals, dtype="int") + self.crystal_identity = np.zeros((0, 2), dtype="int") + for a0 in range(self.num_crystals): + self.k_max[a0] = self.crystals[a0].k_max + self.num_matches[a0] = self.crystals[a0].orientation_map.num_matches + for a1 in range(self.num_matches[a0]): + self.crystal_identity = np.append( + self.crystal_identity, + np.array((a0, a1), dtype="int")[None, :], + axis=0, + ) + + self.num_fits = np.sum(self.num_matches) + + def quantify_single_pattern( + self, + pointlistarray: PointListArray, + xy_position=(0, 0), + corr_kernel_size=0.04, + sigma_excitation_error: float = 0.02, + precession_angle_degrees=None, + power_intensity: float = 0.25, + power_intensity_experiment: float = 0.25, + k_max=None, + max_number_patterns=2, + single_phase=False, + allow_strain=False, + strain_iterations=3, + strain_max=0.02, + include_false_positives=True, + weight_false_positives=1.0, + weight_unmatched_peaks=1.0, + plot_result=True, + plot_only_nonzero_phases=True, + plot_unmatched_peaks=False, + plot_correlation_radius=False, + scale_markers_experiment=40, + scale_markers_calculated=200, + crystal_inds_plot=None, + phase_colors=None, + figsize=(10, 7), + verbose=True, + returnfig=False, + ): """ - Visualize phase maps of dataset. + Quantify the phase for a single diffraction pattern. + + TODO - determine the difference between false positive peaks and unmatched peaks (if any). + + Parameters + ---------- + + pointlistarray: (PointListArray) + Full array of all calibrated experimental bragg peaks, with shape = (num_x,num_y) + xy_position: (int,int) + The (x,y) or (row,column) position to be quantified. + corr_kernel_size: (float) + Correlation kernel size length. The size of the overlap kernel between the + measured Bragg peaks and diffraction library Bragg peaks. [1/Angstroms] + sigma_excitation_error: (float) + The out of plane excitation error tolerance. [1/Angstroms] + precession_angle_degrees: (float) + Tilt angle of illuminaiton cone in degrees for precession electron diffraction (PED). + power_intensity: (float) + Power for scaling the correlation intensity as a function of simulated peak intensity. + power_intensity_experiment: (float): + Power for scaling the correlation intensity as a function of experimental peak intensity. + k_max: (float) + Max k values included in fits, for both x and y directions. + max_number_patterns: int + Max number of orientations which can be included in a match. + single_phase: bool + Set to true to force result to output only the best-fit phase (minimum intensity residual). + allow_strain: bool, + Allow the simulated diffraction patterns to be distorted to improve the matches. + strain_iterations: int + Number of pattern position refinement iterations. + strain_max: float + Maximum strain fraction allowed - this value should be low, typically a few percent (~0.02). + include_false_positives: bool + Penalize patterns which generate false positive peaks. + weight_false_positives: float + Weight strength of false positive peaks. + weight_unmatched_peaks: float + Penalize unmatched peaks. + plot_result: bool + Plot the resulting fit. + plot_only_nonzero_phases: bool + Only plot phases with phase weights > 0. + plot_unmatched_peaks: bool + Plot the false postive peaks. + plot_correlation_radius: bool + In the visualization, draw the correlation radius. + scale_markers_experiment: float + Size of experimental diffraction peak markers. + scale_markers_calculated: float + Size of the calculate diffraction peak markers. + crystal_inds_plot: tuple of ints + Which crystal index / indices to plot. + phase_colors: np.array + Color of each phase, should have shape = (num_phases, 3) + figsize: (float,float) + Size of the output figure. + verbose: bool + Print the resulting fit weights to console. + returnfig: bool + Return the figure and axis handles for the plot. + + + Returns + ------- + phase_weights: (np.array) + Estimated relative fraction of each phase for all probe positions. + shape = (num_x, num_y, num_orientations) + where num_orientations is the total number of all orientations for all phases. + phase_residual: (np.array) + Residual intensity not represented by the best fit phase weighting for all probe positions. + shape = (num_x, num_y) + phase_reliability: (np.array) + Estimated reliability of match(es) for all probe positions. + Typically calculated as the best fit score minus the second best fit. + shape = (num_x, num_y) + int_total: (np.array) + Sum of experimental peak intensities for all probe positions. + shape = (num_x, num_y) + fig,ax: (optional) + matplotlib figure and axis handles - Args: - map_scale_values (float): Value to scale correlations by """ - phase_maps = [] - if map_scale_values is None: - map_scale_values = [1] * len(self.orientation_maps) - corr_sum = np.sum( - [ - (self.orientation_maps[m].corr[:, :, index] * map_scale_values[m]) - for m in range(len(self.orientation_maps)) - ] - ) - for m in range(len(self.orientation_maps)): - phase_maps.append(self.orientation_maps[m].corr[:, :, index] / corr_sum) - show_image_grid(lambda i: phase_maps[i], 1, len(phase_maps), cmap="inferno") - return - - def plot_phase_map(self, index=0, cmap=None): - corr_array = np.dstack( - [maps.corr[:, :, index] for maps in self.orientation_maps] - ) - best_corr_score = np.max(corr_array, axis=2) - best_match_phase = [ - np.where(corr_array[:, :, p] == best_corr_score, True, False) - for p in range(len(self.orientation_maps)) - ] - - if cmap is None: - cm = plt.get_cmap("rainbow") - cmap = [ - cm(1.0 * i / len(self.orientation_maps)) - for i in range(len(self.orientation_maps)) - ] - fig, (ax) = plt.subplots(figsize=(6, 6)) - ax.matshow( - np.zeros((self.orientation_maps[0].num_x, self.orientation_maps[0].num_y)), - cmap="gray", + # tolerance for separating the origin peak. + tolerance_origin_2 = 1e-6 + + # calibrations + center = pointlistarray.calstate["center"] + ellipse = pointlistarray.calstate["ellipse"] + pixel = pointlistarray.calstate["pixel"] + rotate = pointlistarray.calstate["rotate"] + if center is False: + raise ValueError("Bragg peaks must be center calibration") + if pixel is False: + raise ValueError("Bragg peaks must have pixel size calibration") + # TODO - potentially warn the user if ellipse / rotate calibration not available + + if phase_colors is None: + phase_colors = np.array( + ( + (1.0, 0.0, 0.0, 1.0), + (0.0, 0.8, 1.0, 1.0), + (0.0, 0.6, 0.0, 1.0), + (1.0, 0.0, 1.0, 1.0), + (0.0, 0.2, 1.0, 1.0), + (1.0, 0.8, 0.0, 1.0), + ) + ) + + # Experimental values + bragg_peaks = pointlistarray.get_vectors( + xy_position[0], + xy_position[1], + center=center, + ellipse=ellipse, + pixel=pixel, + rotate=rotate, ) - ax.axis("off") + # bragg_peaks = pointlistarray.get_pointlist(xy_position[0],xy_position[1]).copy() + if k_max is None: + keep = ( + bragg_peaks.data["qx"] ** 2 + bragg_peaks.data["qy"] ** 2 + > tolerance_origin_2 + ) + else: + keep = np.logical_and.reduce( + ( + bragg_peaks.data["qx"] ** 2 + bragg_peaks.data["qy"] ** 2 + > tolerance_origin_2, + np.abs(bragg_peaks.data["qx"]) < k_max, + np.abs(bragg_peaks.data["qy"]) < k_max, + ) + ) - for m in range(len(self.orientation_maps)): - c0, c1 = (cmap[m][0] * 0.35, cmap[m][1] * 0.35, cmap[m][2] * 0.35, 1), cmap[ - m - ] - cm = mpl.colors.LinearSegmentedColormap.from_list("cmap", [c0, c1], N=10) - ax.matshow( - np.ma.array( - self.orientation_maps[m].corr[:, :, index], mask=best_match_phase[m] + # ind_center_beam = np.argmin( + # bragg_peaks.data["qx"]**2 + bragg_peaks.data["qy"]**2) + # mask = np.ones_like(bragg_peaks.data["qx"], dtype='bool') + # mask[ind_center_beam] = False + # bragg_peaks.remove(ind_center_beam) + qx = bragg_peaks.data["qx"][keep] + qy = bragg_peaks.data["qy"][keep] + qx0 = bragg_peaks.data["qx"][np.logical_not(keep)] + qy0 = bragg_peaks.data["qy"][np.logical_not(keep)] + if power_intensity_experiment == 0: + intensity = np.ones_like(qx) + intensity0 = np.ones_like(qx0) + else: + intensity = ( + bragg_peaks.data["intensity"][keep] ** power_intensity_experiment + ) + intensity0 = ( + bragg_peaks.data["intensity"][np.logical_not(keep)] + ** power_intensity_experiment + ) + int_total = np.sum(intensity) + + # init basis array + if include_false_positives: + basis = np.zeros((intensity.shape[0], self.num_fits)) + unpaired_peaks = [] + else: + basis = np.zeros((intensity.shape[0], self.num_fits)) + if allow_strain: + m_strains = np.zeros((self.num_fits, 2, 2)) + m_strains[:, 0, 0] = 1.0 + m_strains[:, 1, 1] = 1.0 + + # kernel radius squared + radius_max_2 = corr_kernel_size**2 + + # init for plotting + if plot_result: + library_peaks = [] + library_int = [] + library_matches = [] + + # Generate point list data, match to experimental peaks + for a0 in range(self.num_fits): + c = self.crystal_identity[a0, 0] + m = self.crystal_identity[a0, 1] + # for c in range(self.num_crystals): + # for m in range(self.num_matches[c]): + # ind_match += 1 + + # Generate simulated peaks + bragg_peaks_fit = self.crystals[c].generate_diffraction_pattern( + self.crystals[c].orientation_map.get_orientation( + xy_position[0], xy_position[1] ), - cmap=cm, + ind_orientation=m, + sigma_excitation_error=sigma_excitation_error, + precession_angle_degrees=precession_angle_degrees, ) - plt.show() - - return - - # Potentially introduce a way to check best match out of all orientations in phase plan and plug into model - # to quantify phase - - # def phase_plan( - # self, - # method, - # zone_axis_range: np.ndarray = np.array([[0, 1, 1], [1, 1, 1]]), - # angle_step_zone_axis: float = 2.0, - # angle_coarse_zone_axis: float = None, - # angle_refine_range: float = None, - # angle_step_in_plane: float = 2.0, - # accel_voltage: float = 300e3, - # intensity_power: float = 0.25, - # tol_peak_delete=None, - # tol_distance: float = 0.01, - # fiber_axis = None, - # fiber_angles = None, - # ): - # return + if k_max is None: + del_peak = ( + bragg_peaks_fit.data["qx"] ** 2 + bragg_peaks_fit.data["qy"] ** 2 + < tolerance_origin_2 + ) + else: + del_peak = np.logical_or.reduce( + ( + bragg_peaks_fit.data["qx"] ** 2 + + bragg_peaks_fit.data["qy"] ** 2 + < tolerance_origin_2, + np.abs(bragg_peaks_fit.data["qx"]) > k_max, + np.abs(bragg_peaks_fit.data["qy"]) > k_max, + ) + ) + bragg_peaks_fit.remove(del_peak) + + # peak intensities + if power_intensity == 0: + int_fit = np.ones_like(bragg_peaks_fit.data["qx"]) + else: + int_fit = bragg_peaks_fit.data["intensity"] ** power_intensity + + # Pair peaks to experiment + if plot_result: + matches = np.zeros((bragg_peaks_fit.data.shape[0]), dtype="bool") + + if allow_strain: + for a1 in range(strain_iterations): + # Initial peak pairing to find best-fit strain distortion + pair_sub = np.zeros(bragg_peaks_fit.data.shape[0], dtype="bool") + pair_inds = np.zeros(bragg_peaks_fit.data.shape[0], dtype="int") + for a1 in range(bragg_peaks_fit.data.shape[0]): + dist2 = (bragg_peaks_fit.data["qx"][a1] - qx) ** 2 + ( + bragg_peaks_fit.data["qy"][a1] - qy + ) ** 2 + ind_min = np.argmin(dist2) + val_min = dist2[ind_min] + + if val_min < radius_max_2: + pair_sub[a1] = True + pair_inds[a1] = ind_min + + # calculate best-fit strain tensor, weighted by the intensities. + # requires at least 4 peak pairs + if np.sum(pair_sub) >= 4: + # pair_obs = bragg_peaks_fit.data[['qx','qy']][pair_sub] + pair_basis = np.vstack( + ( + bragg_peaks_fit.data["qx"][pair_sub], + bragg_peaks_fit.data["qy"][pair_sub], + ) + ).T + pair_obs = np.vstack( + ( + qx[pair_inds[pair_sub]], + qy[pair_inds[pair_sub]], + ) + ).T + + # weights + dists = np.sqrt( + ( + bragg_peaks_fit.data["qx"][pair_sub] + - qx[pair_inds[pair_sub]] + ) + ** 2 + + ( + bragg_peaks_fit.data["qx"][pair_sub] + - qx[pair_inds[pair_sub]] + ) + ** 2 + ) + weights = np.sqrt( + int_fit[pair_sub] * intensity[pair_inds[pair_sub]] + ) * (1 - dists / corr_kernel_size) + # weights = 1 - dists / corr_kernel_size + + # strain tensor + m_strain = np.linalg.lstsq( + pair_basis * weights[:, None], + pair_obs * weights[:, None], + rcond=None, + )[0] + + # Clamp strains to be within the user-specified limit + m_strain = np.clip( + m_strain, + np.eye(2) - strain_max, + np.eye(2) + strain_max, + ) + m_strains[a0] *= m_strain + + # Transformed peak positions + qx_copy = bragg_peaks_fit.data["qx"] + qy_copy = bragg_peaks_fit.data["qy"] + bragg_peaks_fit.data["qx"] = ( + qx_copy * m_strain[0, 0] + qy_copy * m_strain[1, 0] + ) + bragg_peaks_fit.data["qy"] = ( + qx_copy * m_strain[0, 1] + qy_copy * m_strain[1, 1] + ) + + # Loop over all peaks, pair experiment to library + for a1 in range(bragg_peaks_fit.data.shape[0]): + dist2 = (bragg_peaks_fit.data["qx"][a1] - qx) ** 2 + ( + bragg_peaks_fit.data["qy"][a1] - qy + ) ** 2 + ind_min = np.argmin(dist2) + val_min = dist2[ind_min] + + if include_false_positives: + weight = np.clip( + 1 - np.sqrt(dist2[ind_min]) / corr_kernel_size, 0, 1 + ) + basis[ind_min, a0] = int_fit[a1] * weight + unpaired_peaks.append( + [ + a0, + int_fit[a1] * (1 - weight), + ] + ) + if weight > 1e-8 and plot_result: + matches[a1] = True + else: + if val_min < radius_max_2: + basis[ind_min, a0] = int_fit[a1] + if plot_result: + matches[a1] = True + + # if val_min < radius_max_2: + # # weight = 1 - np.sqrt(dist2[ind_min]) / corr_kernel_size + # # weight = 1 + corr_distance_scale * \ + # # np.sqrt(dist2[ind_min]) / corr_kernel_size + # # basis[ind_min,a0] = weight * int_fit[a1] + # basis[ind_min,a0] = int_fit[a1] + # if plot_result: + # matches[a1] = True + # elif include_false_positives: + # # unpaired_peaks.append([a0,int_fit[a1]*(1 + corr_distance_scale)]) + # unpaired_peaks.append([a0,int_fit[a1]]) + + if plot_result: + library_peaks.append(bragg_peaks_fit) + library_int.append(int_fit) + library_matches.append(matches) + + # If needed, augment basis and observations with false positives + if include_false_positives: + basis_aug = np.zeros((len(unpaired_peaks), self.num_fits)) + for a0 in range(len(unpaired_peaks)): + basis_aug[a0, unpaired_peaks[a0][0]] = unpaired_peaks[a0][1] + + basis = np.vstack((basis, basis_aug * weight_false_positives)) + obs = np.hstack((intensity, np.zeros(len(unpaired_peaks)))) + + else: + obs = intensity + + # Solve for phase weight coefficients + try: + phase_weights = np.zeros(self.num_fits) + + if single_phase: + # loop through each crystal structure and determine the best fit structure, + # which can contain multiple orientations up to max_number_patterns + crystal_res = np.zeros(self.num_crystals) + + for a0 in range(self.num_crystals): + inds_solve = self.crystal_identity[:, 0] == a0 + search = True + + while search is True: + basis_solve = basis[:, inds_solve] + obs_solve = obs.copy() + + if weight_unmatched_peaks > 1.0: + sub_unmatched = np.sum(basis_solve, axis=1) < 1e-8 + obs_solve[sub_unmatched] *= weight_unmatched_peaks + + phase_weights_cand, phase_residual_cand = nnls( + basis_solve, + obs_solve, + ) + + if ( + np.count_nonzero(phase_weights_cand > 0.0) + <= max_number_patterns + ): + phase_weights[inds_solve] = phase_weights_cand + crystal_res[a0] = phase_residual_cand + search = False + else: + inds = np.where(inds_solve)[0] + inds_solve[inds[np.argmin(phase_weights_cand)]] = False + + ind_best_fit = np.argmin(crystal_res) + # ind_best_fit = np.argmax(phase_weights) + + phase_residual = crystal_res[ind_best_fit] + sub = np.logical_not(self.crystal_identity[:, 0] == ind_best_fit) + phase_weights[sub] = 0.0 + + # Estimate reliability as difference between best fit and 2nd best fit + crystal_res = np.sort(crystal_res) + phase_reliability = crystal_res[1] - crystal_res[0] + + else: + # Allow all crystals and orientation matches in the pattern + inds_solve = np.ones(self.num_fits, dtype="bool") + search = True + while search is True: + phase_weights_cand, phase_residual_cand = nnls( + basis[:, inds_solve], + obs, + ) + + if ( + np.count_nonzero(phase_weights_cand > 0.0) + <= max_number_patterns + ): + phase_weights[inds_solve] = phase_weights_cand + phase_residual = phase_residual_cand + search = False + else: + inds = np.where(inds_solve)[0] + inds_solve[inds[np.argmin(phase_weights_cand)]] = False + + # Estimate the phase reliability + inds_solve = np.ones(self.num_fits, dtype="bool") + inds_solve[phase_weights > 1e-8] = False + + if np.all(inds_solve == False): + phase_reliability = 0.0 + else: + search = True + while search is True: + phase_weights_cand, phase_residual_cand = nnls( + basis[:, inds_solve], + obs, + ) + if ( + np.count_nonzero(phase_weights_cand > 0.0) + <= max_number_patterns + ): + phase_residual_2nd = phase_residual_cand + search = False + else: + inds = np.where(inds_solve)[0] + inds_solve[inds[np.argmin(phase_weights_cand)]] = False + + phase_weights_cand, phase_residual_cand = nnls( + basis[:, inds_solve], + obs, + ) + phase_reliability = phase_residual_2nd - phase_residual + + except: + phase_weights = np.zeros(self.num_fits) + phase_residual = np.sqrt(np.sum(intensity**2)) + phase_reliability = 0.0 + + if verbose: + ind_max = np.argmax(phase_weights) + # print() + print("\033[1m" + "phase_weight or_ind name" + "\033[0m") + # print() + for a0 in range(self.num_fits): + c = self.crystal_identity[a0, 0] + m = self.crystal_identity[a0, 1] + line = "{:>12} {:>8} {:<12}".format( + f"{phase_weights[a0]:.2f}", m, self.crystal_names[c] + ) + if a0 == ind_max: + print("\033[1m" + line + "\033[0m") + else: + print(line) + print("----------------------------") + line = "{:>12} {:>15}".format(f"{sum(phase_weights):.2f}", "fit total") + print("\033[1m" + line + "\033[0m") + line = "{:>12} {:>15}".format(f"{phase_residual:.2f}", "fit residual") + print(line) + + # Plotting + if plot_result: + # fig, ax = plt.subplots(figsize=figsize) + fig = plt.figure(figsize=figsize) + # if plot_layout == 0: + # ax_x = fig.add_axes( + # [0.0+figbound[0], 0.0, 0.4-2*+figbound[0], 1.0]) + ax = fig.add_axes([0.0, 0.0, 0.66, 1.0]) + ax_leg = fig.add_axes([0.68, 0.0, 0.3, 1.0]) + + if plot_correlation_radius: + # plot the experimental radii + t = np.linspace(0, 2 * np.pi, 91, endpoint=True) + ct = np.cos(t) * corr_kernel_size + st = np.sin(t) * corr_kernel_size + for a0 in range(qx.shape[0]): + ax.plot( + qy[a0] + st, + qx[a0] + ct, + color="k", + linewidth=1, + ) + + # plot the experimental peaks + ax.scatter( + qy0, + qx0, + # s = scale_markers_experiment * intensity0, + s=scale_markers_experiment + * bragg_peaks.data["intensity"][np.logical_not(keep)], + marker="o", + facecolor=[0.7, 0.7, 0.7], + ) + ax.scatter( + qy, + qx, + # s = scale_markers_experiment * intensity, + s=scale_markers_experiment * bragg_peaks.data["intensity"][keep], + marker="o", + facecolor=[0.7, 0.7, 0.7], + ) + # legend + if k_max is None: + k_max = np.max(self.k_max) + dx_leg = -0.05 * k_max + dy_leg = 0.04 * k_max + text_params = { + "va": "center", + "ha": "left", + "family": "sans-serif", + "fontweight": "normal", + "color": "k", + "size": 12, + } + if plot_correlation_radius: + ax_leg.plot( + 0 + st * 0.5, + -dx_leg + ct * 0.5, + color="k", + linewidth=1, + ) + ax_leg.scatter( + 0, + 0, + s=200, + marker="o", + facecolor=[0.7, 0.7, 0.7], + ) + ax_leg.text(dy_leg, 0, "Experimental peaks", **text_params) + if plot_correlation_radius: + ax_leg.text(dy_leg, -dx_leg, "Correlation radius", **text_params) + + # plot calculated diffraction patterns + uvals = phase_colors.copy() + uvals[:, 3] = 0.3 + # uvals = np.array(( + # (1.0,0.0,0.0,0.2), + # (0.0,0.8,1.0,0.2), + # (0.0,0.6,0.0,0.2), + # (1.0,0.0,1.0,0.2), + # (0.0,0.2,1.0,0.2), + # (1.0,0.8,0.0,0.2), + # )) + mvals = [ + "v", + "^", + "<", + ">", + "d", + "s", + ] + + count_leg = 0 + for a0 in range(self.num_fits): + c = self.crystal_identity[a0, 0] + m = self.crystal_identity[a0, 1] + + if ( + crystal_inds_plot == None + or np.min(np.abs(c - crystal_inds_plot)) == 0 + ): + qx_fit = library_peaks[a0].data["qx"] + qy_fit = library_peaks[a0].data["qy"] + + if allow_strain: + m_strain = m_strains[a0] + # Transformed peak positions + qx_copy = qx_fit.copy() + qy_copy = qy_fit.copy() + qx_fit = qx_copy * m_strain[0, 0] + qy_copy * m_strain[1, 0] + qy_fit = qx_copy * m_strain[0, 1] + qy_copy * m_strain[1, 1] + + int_fit = library_int[a0] + matches_fit = library_matches[a0] + + if plot_only_nonzero_phases is False or phase_weights[a0] > 0: + # if np.mod(m,2) == 0: + ax.scatter( + qy_fit[matches_fit], + qx_fit[matches_fit], + s=scale_markers_calculated * int_fit[matches_fit], + marker=mvals[c], + facecolor=phase_colors[c, :], + ) + if plot_unmatched_peaks: + ax.scatter( + qy_fit[np.logical_not(matches_fit)], + qx_fit[np.logical_not(matches_fit)], + s=scale_markers_calculated + * int_fit[np.logical_not(matches_fit)], + marker=mvals[c], + facecolor=phase_colors[c, :], + ) + + # legend + if m == 0: + ax_leg.text( + dy_leg, + (count_leg + 1) * dx_leg, + self.crystal_names[c], + **text_params, + ) + ax_leg.scatter( + 0, + (count_leg + 1) * dx_leg, + s=200, + marker=mvals[c], + facecolor=phase_colors[c, :], + ) + count_leg += 1 + # else: + # ax.scatter( + # qy_fit[matches_fit], + # qx_fit[matches_fit], + # s = scale_markers_calculated * int_fit[matches_fit], + # marker = mvals[c], + # edgecolors = uvals[c,:], + # facecolors = (uvals[c,0],uvals[c,1],uvals[c,2],0.3), + # # facecolors = (1,1,1,0.5), + # linewidth = 2, + # ) + # if plot_unmatched_peaks: + # ax.scatter( + # qy_fit[np.logical_not(matches_fit)], + # qx_fit[np.logical_not(matches_fit)], + # s = scale_markers_calculated * int_fit[np.logical_not(matches_fit)], + # marker = mvals[c], + # edgecolors = uvals[c,:], + # facecolors = (1,1,1,0.5), + # linewidth = 2, + # ) + + # # legend + # ax_leg.scatter( + # 0, + # dx_leg*(a0+1), + # s = 200, + # marker = mvals[c], + # edgecolors = uvals[c,:], + # facecolors = (uvals[c,0],uvals[c,1],uvals[c,2],0.3), + # # facecolors = (1,1,1,0.5), + # ) + + # appearance + ax.set_xlim((-k_max, k_max)) + ax.set_ylim((k_max, -k_max)) + + ax_leg.set_xlim((-0.1 * k_max, 0.4 * k_max)) + ax_leg.set_ylim((-0.5 * k_max, 0.5 * k_max)) + ax_leg.set_axis_off() + + if returnfig: + return phase_weights, phase_residual, phase_reliability, int_total, fig, ax + else: + return phase_weights, phase_residual, phase_reliability, int_total def quantify_phase( self, - pointlistarray, - tolerance_distance=0.08, - method="nnls", - intensity_power=0, - mask_peaks=None, + pointlistarray: PointListArray, + corr_kernel_size=0.04, + sigma_excitation_error=0.02, + precession_angle_degrees=None, + power_intensity: float = 0.25, + power_intensity_experiment: float = 0.25, + k_max=None, + max_number_patterns=2, + single_phase=False, + allow_strain=True, + strain_iterations=3, + strain_max=0.02, + include_false_positives=True, + weight_false_positives=1.0, + progress_bar=True, ): """ - Quantification of the phase of a crystal based on the crystal instances and the pointlistarray. + Quantify phase of all diffraction patterns. - Args: - pointlisarray (pointlistarray): Pointlistarray to quantify phase of - tolerance_distance (float): Distance allowed between a peak and match - method (str): Numerical method used to quantify phase - intensity_power (float): ... - mask_peaks (list, optional): A pointer of which positions to mask peaks from + Parameters + ---------- + pointlistarray: (PointListArray) + Full array of all calibrated experimental bragg peaks, with shape = (num_x,num_y) + corr_kernel_size: (float) + Correlation kernel size length. The size of the overlap kernel between the + measured Bragg peaks and diffraction library Bragg peaks. [1/Angstroms] + sigma_excitation_error: (float) + The out of plane excitation error tolerance. [1/Angstroms] + precession_angle_degrees: (float) + Tilt angle of illuminaiton cone in degrees for precession electron diffraction (PED). + power_intensity: (float) + Power for scaling the correlation intensity as a function of simulated peak intensity. + power_intensity_experiment: (float): + Power for scaling the correlation intensity as a function of experimental peak intensity. + k_max: (float) + Max k values included in fits, for both x and y directions. + max_number_patterns: int + Max number of orientations which can be included in a match. + single_phase: bool + Set to true to force result to output only the best-fit phase (minimum intensity residual). + allow_strain: bool, + Allow the simulated diffraction patterns to be distorted to improve the matches. + strain_iterations: int + Number of pattern position refinement iterations. + strain_max: float + Maximum strain fraction allowed - this value should be low, typically a few percent (~0.02). + include_false_positives: bool + Penalize patterns which generate false positive peaks. + weight_false_positives: float + Weight strength of false positive peaks. + progressbar: bool + Display progress. + + Returns + ----------- + + """ + + # init results arrays + self.phase_weights = np.zeros( + ( + pointlistarray.shape[0], + pointlistarray.shape[1], + self.num_fits, + ) + ) + self.phase_residuals = np.zeros( + ( + pointlistarray.shape[0], + pointlistarray.shape[1], + ) + ) + self.phase_reliability = np.zeros( + ( + pointlistarray.shape[0], + pointlistarray.shape[1], + ) + ) + self.int_total = np.zeros( + ( + pointlistarray.shape[0], + pointlistarray.shape[1], + ) + ) + self.single_phase = single_phase + + for rx, ry in tqdmnd( + *pointlistarray.shape, + desc="Quantifying Phase", + unit=" PointList", + disable=not progress_bar, + ): + # calculate phase weights + ( + phase_weights, + phase_residual, + phase_reliability, + int_peaks, + ) = self.quantify_single_pattern( + pointlistarray=pointlistarray, + xy_position=(rx, ry), + corr_kernel_size=corr_kernel_size, + sigma_excitation_error=sigma_excitation_error, + precession_angle_degrees=precession_angle_degrees, + power_intensity=power_intensity, + power_intensity_experiment=power_intensity_experiment, + k_max=k_max, + max_number_patterns=max_number_patterns, + single_phase=single_phase, + allow_strain=allow_strain, + strain_iterations=strain_iterations, + strain_max=strain_max, + include_false_positives=include_false_positives, + weight_false_positives=weight_false_positives, + plot_result=False, + verbose=False, + returnfig=False, + ) + self.phase_weights[rx, ry] = phase_weights + self.phase_residuals[rx, ry] = phase_residual + self.phase_reliability[rx, ry] = phase_reliability + self.int_total[rx, ry] = int_peaks + + def plot_phase_weights( + self, + weight_range=(0.0, 1.0), + weight_normalize=False, + total_intensity_normalize=True, + cmap="gray", + show_ticks=False, + show_axes=True, + layout=0, + figsize=(6, 6), + returnfig=False, + ): + """ + Plot the individual phase weight maps and residuals. + + Parameters + ---------- + weight_range: (float, float) + Plotting weight range. + weight_normalize: bool + Normalize weights before plotting. + total_intensity_normalize: bool + Normalize the total intensity. + cmap: matplotlib.cm.cmap + Colormap to use for plots. + show_ticks: bool + Show ticks on plots. + show_axes: bool + Show axes. + layout: int + Layout type for figures. + figsize: (float,float) + Size of figure panel. + returnfig: bool + Return the figure and axes handles. + + Returns + ---------- + fig,ax: (optional) + Figure and axes handles. - Details: """ - if isinstance(pointlistarray, PointListArray): - phase_weights = np.zeros( + + # Normalization if required to total DF peak intensity + phase_weights = self.phase_weights.copy() + phase_residuals = self.phase_residuals.copy() + if total_intensity_normalize: + sub = self.int_total > 0.0 + for a0 in range(self.num_fits): + phase_weights[:, :, a0][sub] /= self.int_total[sub] + phase_residuals[sub] /= self.int_total[sub] + + # intensity range for plotting + if weight_normalize: + scale = np.median(np.max(phase_weights, axis=2)) + else: + scale = 1 + weight_range = np.array(weight_range) * scale + + # plotting + if layout == 0: + fig, ax = plt.subplots( + 1, + self.num_crystals + 1, + figsize=(figsize[0], (self.num_fits + 1) * figsize[1]), + ) + elif layout == 1: + fig, ax = plt.subplots( + self.num_crystals + 1, + 1, + figsize=(figsize[0], (self.num_fits + 1) * figsize[1]), + ) + + for a0 in range(self.num_crystals): + sub = self.crystal_identity[:, 0] == a0 + im = np.sum(phase_weights[:, :, sub], axis=2) + im = np.clip( + (im - weight_range[0]) / (weight_range[1] - weight_range[0]), 0, 1 + ) + ax[a0].imshow( + im, + vmin=0, + vmax=1, + cmap=cmap, + ) + ax[a0].set_title( + self.crystal_names[a0], + fontsize=16, + ) + if not show_ticks: + ax[a0].set_xticks([]) + ax[a0].set_yticks([]) + if not show_axes: + ax[a0].set_axis_off() + + # plot residuals + im = np.clip( + (phase_residuals - weight_range[0]) / (weight_range[1] - weight_range[0]), + 0, + 1, + ) + ax[self.num_crystals].imshow( + im, + vmin=0, + vmax=1, + cmap=cmap, + ) + ax[self.num_crystals].set_title( + "Residuals", + fontsize=16, + ) + if not show_ticks: + ax[self.num_crystals].set_xticks([]) + ax[self.num_crystals].set_yticks([]) + if not show_axes: + ax[self.num_crystals].set_axis_off() + + if returnfig: + return fig, ax + + def plot_phase_maps( + self, + weight_threshold=0.5, + weight_normalize=True, + total_intensity_normalize=True, + plot_combine=False, + crystal_inds_plot=None, + phase_colors=None, + show_ticks=False, + show_axes=True, + layout=0, + figsize=(6, 6), + return_phase_estimate=False, + return_rgb_images=False, + returnfig=False, + ): + """ + Plot the individual phase weight maps and residuals. + + Parameters + ---------- + weight_threshold: float + Threshold for showing each phase. + weight_normalize: bool + Normalize weights before plotting. + total_intensity_normalize: bool + Normalize the total intensity. + plot_combine: bool + Combine all figures into a single plot. + crystal_inds_plot: (tuple of ints) + Which crystals to plot phase maps for. + phase_colors: np.array + (Nx3) shaped array giving the colors for each phase + show_ticks: bool + Show ticks on plots. + show_axes: bool + Show axes. + layout: int + Layout type for figures. + figsize: (float,float) + Size of figure panel. + return_phase_estimate: bool + Return the phase estimate array. + return_rgb_images: bool + Return the rgb images. + returnfig: bool + Return the figure and axes handles. + + Returns + ---------- + im_all: (np.array, optional) + images showing phase maps. + im_rgb, im_rgb_all: (np.array, optional) + rgb colored output images, possibly combined + fig,ax: (optional) + Figure and axes handles. + + """ + + if phase_colors is None: + phase_colors = np.array( ( - pointlistarray.shape[0], - pointlistarray.shape[1], - np.sum([map.num_matches for map in self.orientation_maps]), + (1.0, 0.0, 0.0), + (0.0, 0.8, 1.0), + (0.0, 0.8, 0.0), + (1.0, 0.0, 1.0), + (0.0, 0.4, 1.0), + (1.0, 0.8, 0.0), ) ) - phase_residuals = np.zeros(pointlistarray.shape) - for Rx, Ry in tqdmnd(pointlistarray.shape[0], pointlistarray.shape[1]): - ( - _, - phase_weight, - phase_residual, - crystal_identity, - ) = self.quantify_phase_pointlist( - pointlistarray, - position=[Rx, Ry], - tolerance_distance=tolerance_distance, - method=method, - intensity_power=intensity_power, - mask_peaks=mask_peaks, + + phase_weights = self.phase_weights.copy() + if total_intensity_normalize: + sub = self.int_total > 0.0 + for a0 in range(self.num_fits): + phase_weights[:, :, a0][sub] /= self.int_total[sub] + + # intensity range for plotting + if weight_normalize: + scale = np.median(np.max(phase_weights, axis=2)) + else: + scale = 1 + weight_threshold = weight_threshold * scale + + # init + im_all = np.zeros( + ( + self.num_crystals, + self.phase_weights.shape[0], + self.phase_weights.shape[1], + ) + ) + im_rgb_all = np.zeros( + ( + self.num_crystals, + self.phase_weights.shape[0], + self.phase_weights.shape[1], + 3, + ) + ) + + # phase weights over threshold + for a0 in range(self.num_crystals): + sub = self.crystal_identity[:, 0] == a0 + im = np.sum(phase_weights[:, :, sub], axis=2) + im_all[a0] = np.maximum(im - weight_threshold, 0) + + # estimate compositions + im_sum = np.sum(im_all, axis=0) + sub = im_sum > 0.0 + for a0 in range(self.num_crystals): + im_all[a0][sub] /= im_sum[sub] + + for a1 in range(3): + im_rgb_all[a0, :, :, a1] = im_all[a0] * phase_colors[a0, a1] + + if plot_combine: + if crystal_inds_plot is None: + im_rgb = np.sum(im_rgb_all, axis=0) + else: + im_rgb = np.sum(im_rgb_all[np.array(crystal_inds_plot)], axis=0) + + im_rgb = np.clip(im_rgb, 0, 1) + + fig, ax = plt.subplots(1, 1, figsize=figsize) + ax.imshow( + im_rgb, + ) + ax.set_title( + "Phase Maps", + fontsize=16, + ) + if not show_ticks: + ax.set_xticks([]) + ax.set_yticks([]) + if not show_axes: + ax.set_axis_off() + + else: + # plotting + if layout == 0: + fig, ax = plt.subplots( + 1, + self.num_crystals, + figsize=(figsize[0], (self.num_fits + 1) * figsize[1]), ) - phase_weights[Rx, Ry, :] = phase_weight - phase_residuals[Rx, Ry] = phase_residual - self.phase_weights = phase_weights - self.phase_residuals = phase_residuals - self.crystal_identity = crystal_identity - return + elif layout == 1: + fig, ax = plt.subplots( + self.num_crystals, + 1, + figsize=(figsize[0], (self.num_fits + 1) * figsize[1]), + ) + + for a0 in range(self.num_crystals): + ax[a0].imshow( + im_rgb_all[a0], + ) + ax[a0].set_title( + self.crystal_names[a0], + fontsize=16, + ) + if not show_ticks: + ax[a0].set_xticks([]) + ax[a0].set_yticks([]) + if not show_axes: + ax[a0].set_axis_off() + + # All possible returns + if return_phase_estimate: + if returnfig: + return im_all, fig, ax + else: + return im_all + elif return_rgb_images: + if plot_combine: + if returnfig: + return im_rgb, fig, ax + else: + return im_rgb + else: + if returnfig: + return im_rgb_all, fig, ax + else: + return im_rgb_all else: - return TypeError("pointlistarray must be of type pointlistarray.") - return + if returnfig: + return fig, ax - def quantify_phase_pointlist( + def plot_dominant_phase( self, - pointlistarray, - position, - method="nnls", - tolerance_distance=0.08, - intensity_power=0, - mask_peaks=None, + use_correlation_scores=False, + reliability_range=(0.0, 1.0), + sigma=0.0, + phase_colors=None, + ticks=True, + figsize=(6, 6), + legend_add=True, + legend_fraction=0.2, + print_fractions=False, + returnfig=True, ): """ - Args: - pointlisarray (pointlistarray): Pointlistarray to quantify phase of - position (tuple/list): Position of pointlist in pointlistarray - tolerance_distance (float): Distance allowed between a peak and match - method (str): Numerical method used to quantify phase - intensity_power (float): ... - mask_peaks (list, optional): A pointer of which positions to mask peaks from - - Returns: - pointlist_peak_intensity_matches (np.ndarray): Peak matches in the rows of array and the crystals in the columns - phase_weights (np.ndarray): Weights of each phase - phase_residuals (np.ndarray): Residuals - crystal_identity (list): List of lists, where the each entry represents the position in the - crystal and orientation match that is associated with the phase - weights. for example, if the output was [[0,0], [0,1], [1,0], [0,1]], - the first entry [0,0] in phase weights is associated with the first crystal - the first match within that crystal. [0,1] is the first crystal and the - second match within that crystal. + Plot a combined figure showing the primary phase at each probe position. + Mask by the reliability index (best match minus 2nd best match). + + Parameters + ---------- + use_correlation_scores: bool + Set to True to use correlation scores instead of reliabiltiy from intensity residuals. + reliability_range: (float, float) + Plotting intensity range + sigma: float + Smoothing in units of probe position. + phase_colors: np.array + (N,3) shaped array giving colors of all phases + ticks: bool + Show ticks on plots. + figsize: (float,float) + Size of output figure + legend_add: bool + Add legend to plot + legend_fraction: float + Fractional size of legend in plot. + print_fractions: bool + Print the estimated fraction of all phases. + returnfig: bool + Return the figure and axes handles. + + Returns + ---------- + fig,ax: (optional) + Figure and axes handles. + """ - # Things to add: - # 1. Better cost for distance from peaks in pointlists - # 2. Iterate through multiple tolerance_distance values to find best value. Cost function residuals, or something else? - pointlist = pointlistarray.get_pointlist(position[0], position[1]) - pl_mask = np.where((pointlist["qx"] == 0) & (pointlist["qy"] == 0), 1, 0) - pointlist.remove(pl_mask) - # False Negatives (exp peak with no match in crystal instances) will appear here, already coded in + if phase_colors is None: + phase_colors = np.array( + [ + [1.0, 0.9, 0.6], + [1, 0, 0], + [0, 0.7, 0], + [0, 0.7, 1], + [1, 0, 1], + ] + ) - if intensity_power == 0: - pl_intensities = np.ones(pointlist["intensity"].shape) - else: - pl_intensities = pointlist["intensity"] ** intensity_power - # Prepare matches for modeling - pointlist_peak_matches = [] - crystal_identity = [] - - for c in range(len(self.crystals)): - for m in range(self.orientation_maps[c].num_matches): - crystal_identity.append([c, m]) - phase_peak_match_intensities = np.zeros((pointlist["intensity"].shape)) - bragg_peaks_fit = self.crystals[c].generate_diffraction_pattern( - self.orientation_maps[c].get_orientation(position[0], position[1]), - ind_orientation=m, - ) - # Find the best match peak within tolerance_distance and add value in the right position - for d in range(pointlist["qx"].shape[0]): - distances = [] - for p in range(bragg_peaks_fit["qx"].shape[0]): - distances.append( - np.sqrt( - (pointlist["qx"][d] - bragg_peaks_fit["qx"][p]) ** 2 - + (pointlist["qy"][d] - bragg_peaks_fit["qy"][p]) ** 2 - ) - ) - ind = np.where(distances == np.min(distances))[0][0] + # init arrays + scan_shape = self.phase_weights.shape[:2] + phase_map = np.zeros(scan_shape) + phase_corr = np.zeros(scan_shape) + phase_corr_2nd = np.zeros(scan_shape) + phase_sig = np.zeros((self.num_crystals, scan_shape[0], scan_shape[1])) - # Potentially for-loop over multiple values for 'tolerance_distance' to find best tolerance_distance value - if distances[ind] <= tolerance_distance: - ## Somewhere in this if statement is probably where better distances from the peak should be coded in - if ( - intensity_power == 0 - ): # This could potentially be a different intensity_power arg - phase_peak_match_intensities[d] = 1 ** ( - (tolerance_distance - distances[ind]) - / tolerance_distance - ) - else: - phase_peak_match_intensities[d] = bragg_peaks_fit[ - "intensity" - ][ind] ** ( - (tolerance_distance - distances[ind]) - / tolerance_distance - ) - else: - ## This is probably where the false positives (peaks in crystal but not in experiment) should be handled - continue - - pointlist_peak_matches.append(phase_peak_match_intensities) - pointlist_peak_intensity_matches = np.dstack(pointlist_peak_matches) - pointlist_peak_intensity_matches = ( - pointlist_peak_intensity_matches.reshape( - pl_intensities.shape[0], - pointlist_peak_intensity_matches.shape[-1], - ) + if use_correlation_scores: + # Calculate scores from highest correlation match + for a0 in range(self.num_crystals): + phase_sig[a0] = np.maximum( + phase_sig[a0], + np.max(self.crystals[a0].orientation_map.corr, axis=2), ) + else: + # sum up phase weights by crystal type + for a0 in range(self.num_fits): + ind = self.crystal_identity[a0, 0] + phase_sig[ind] += self.phase_weights[:, :, a0] - if len(pointlist["qx"]) > 0: - if mask_peaks is not None: - for i in range(len(mask_peaks)): - if mask_peaks[i] == None: # noqa: E711 - continue - inds_mask = np.where( - pointlist_peak_intensity_matches[:, mask_peaks[i]] != 0 - )[0] - for mask in range(len(inds_mask)): - pointlist_peak_intensity_matches[inds_mask[mask], i] = 0 - - if method == "nnls": - phase_weights, phase_residuals = nnls( - pointlist_peak_intensity_matches, pl_intensities + # smoothing of the outputs + if sigma > 0.0: + for a0 in range(self.num_crystals): + phase_sig[a0] = gaussian_filter( + phase_sig[a0], + sigma=sigma, + mode="nearest", ) - elif method == "lstsq": - phase_weights, phase_residuals, rank, singluar_vals = lstsq( - pointlist_peak_intensity_matches, pl_intensities, rcond=-1 + # find highest correlation score for each crystal and match index + for a0 in range(self.num_crystals): + sub = phase_sig[a0] > phase_corr + phase_map[sub] = a0 + phase_corr[sub] = phase_sig[a0][sub] + + if self.single_phase: + phase_scale = np.clip( + (self.phase_reliability - reliability_range[0]) + / (reliability_range[1] - reliability_range[0]), + 0, + 1, + ) + + else: + # find the second correlation score for each crystal and match index + for a0 in range(self.num_crystals): + corr = phase_sig[a0].copy() + corr[phase_map == a0] = 0.0 + sub = corr > phase_corr_2nd + phase_corr_2nd[sub] = corr[sub] + + # Estimate the reliability + phase_rel = phase_corr - phase_corr_2nd + phase_scale = np.clip( + (phase_rel - reliability_range[0]) + / (reliability_range[1] - reliability_range[0]), + 0, + 1, + ) + + # Print the total area of fraction of each phase + if print_fractions: + phase_mask = phase_scale >= 0.5 + phase_total = np.sum(phase_mask) + + print("Phase Fractions") + print("---------------") + for a0 in range(self.num_crystals): + phase_frac = np.sum((phase_map == a0) * phase_mask) / phase_total + + print(self.crystal_names[a0] + " - " + f"{phase_frac*100:.4f}" + "%") + + self.phase_rgb = np.zeros((scan_shape[0], scan_shape[1], 3)) + for a0 in range(self.num_crystals): + sub = phase_map == a0 + for a1 in range(3): + self.phase_rgb[:, :, a1][sub] = phase_colors[a0, a1] * phase_scale[sub] + # normalize + # self.phase_rgb = np.clip( + # (self.phase_rgb - rel_range[0]) / (rel_range[1] - rel_range[0]), + # 0,1) + + fig = plt.figure(figsize=figsize) + if legend_add: + width = 1 + + ax = fig.add_axes((0, legend_fraction, 1, 1 - legend_fraction)) + ax_leg = fig.add_axes((0, 0, 1, legend_fraction)) + + for a0 in range(self.num_crystals): + ax_leg.scatter( + a0 * width, + 0, + s=200, + marker="s", + edgecolor=(0, 0, 0, 1), + facecolor=phase_colors[a0], ) - phase_residuals = np.sum(phase_residuals) - else: - raise ValueError(method + " Not yet implemented. Try nnls or lstsq.") + ax_leg.text( + a0 * width + 0.1, + 0, + self.crystal_names[a0], + fontsize=16, + verticalalignment="center", + ) + ax_leg.axis("off") + ax_leg.set_xlim( + ( + width * -0.5, + width * (self.num_crystals + 0.5), + ) + ) + else: - phase_weights = np.zeros((pointlist_peak_intensity_matches.shape[1],)) - phase_residuals = np.NaN - return ( - pointlist_peak_intensity_matches, - phase_weights, - phase_residuals, - crystal_identity, + ax = fig.add_axes((0, 0, 1, 1)) + + ax.imshow( + self.phase_rgb, ) - # def plot_peak_matches( - # self, - # pointlistarray, - # position, - # tolerance_distance, - # ind_orientation, - # pointlist_peak_intensity_matches, - # ): - # """ - # A method to view how the tolerance distance impacts the peak matches associated with - # the quantify_phase_pointlist method. - - # Args: - # pointlistarray, - # position, - # tolerance_distance - # pointlist_peak_intensity_matches - # """ - # pointlist = pointlistarray.get_pointlist(position[0],position[1]) - - # for m in range(pointlist_peak_intensity_matches.shape[1]): - # bragg_peaks_fit = self.crystals[m].generate_diffraction_pattern( - # self.orientation_maps[m].get_orientation(position[0], position[1]), - # ind_orientation = ind_orientation - # ) - # peak_inds = np.where(bragg_peaks_fit.data['intensity'] == pointlist_peak_intensity_matches[:,m]) - - # fig, (ax1, ax2) = plt.subplots(2,1,figsize = figsize) - # ax1 = plot_diffraction_pattern(pointlist,) - # return + if ticks is False: + ax.set_xticks([]) + ax.set_yticks([]) + + if returnfig: + return fig, ax diff --git a/py4DSTEM/process/diffraction/crystal_viz.py b/py4DSTEM/process/diffraction/crystal_viz.py index 47df2e6ca..5bc16c1bc 100644 --- a/py4DSTEM/process/diffraction/crystal_viz.py +++ b/py4DSTEM/process/diffraction/crystal_viz.py @@ -8,7 +8,6 @@ from scipy.signal import medfilt from scipy.ndimage import gaussian_filter from scipy.ndimage import distance_transform_edt -from skimage.morphology import dilation, erosion import warnings import numpy as np @@ -393,7 +392,7 @@ def plot_scattering_intensity( k_step=0.001, k_broadening=0.0, k_power_scale=0.0, - int_power_scale=0.5, + int_power_scale=1.0, int_scale=1.0, remove_origin=True, bragg_peaks=None, @@ -538,6 +537,7 @@ def plot_orientation_zones( proj_dir_cartesian: Optional[Union[list, tuple, np.ndarray]] = None, tol_den=10, marker_size: float = 20, + plot_zone_axis_labels: bool = True, plot_limit: Union[list, tuple, np.ndarray] = np.array([-1.1, 1.1]), figsize: Union[list, tuple, np.ndarray] = (8, 8), returnfig: bool = False, @@ -553,6 +553,7 @@ def plot_orientation_zones( dir_proj (float): projection direction, either [elev azim] or normal vector Default is mean vector of self.orientation_zone_axis_range rows marker_size (float): size of markers + plot_zone_axis_labels (bool): plot the zone axis labels plot_limit (float): x y z plot limits, default is [0, 1.05] figsize (2 element float): size scaling of figure axes returnfig (bool): set to True to return figure and axes handles @@ -727,47 +728,47 @@ def plot_orientation_zones( zorder=10, ) - text_scale_pos = 1.2 - text_params = { - "va": "center", - "family": "sans-serif", - "fontweight": "normal", - "color": "k", - "size": 16, - } - # 'ha': 'center', - - ax.text( - self.orientation_vecs[inds[0], 1] * text_scale_pos, - self.orientation_vecs[inds[0], 0] * text_scale_pos, - self.orientation_vecs[inds[0], 2] * text_scale_pos, - label_0, - None, - zorder=11, - ha="center", - **text_params, - ) - if self.orientation_full is False and self.orientation_half is False: - ax.text( - self.orientation_vecs[inds[1], 1] * text_scale_pos, - self.orientation_vecs[inds[1], 0] * text_scale_pos, - self.orientation_vecs[inds[1], 2] * text_scale_pos, - label_1, - None, - zorder=12, - ha="center", - **text_params, - ) + if plot_zone_axis_labels: + text_scale_pos = 1.2 + text_params = { + "va": "center", + "family": "sans-serif", + "fontweight": "normal", + "color": "k", + "size": 16, + } + # 'ha': 'center', ax.text( - self.orientation_vecs[inds[2], 1] * text_scale_pos, - self.orientation_vecs[inds[2], 0] * text_scale_pos, - self.orientation_vecs[inds[2], 2] * text_scale_pos, - label_2, + self.orientation_vecs[inds[0], 1] * text_scale_pos, + self.orientation_vecs[inds[0], 0] * text_scale_pos, + self.orientation_vecs[inds[0], 2] * text_scale_pos, + label_0, None, - zorder=13, + zorder=11, ha="center", **text_params, ) + if self.orientation_full is False and self.orientation_half is False: + ax.text( + self.orientation_vecs[inds[1], 1] * text_scale_pos, + self.orientation_vecs[inds[1], 0] * text_scale_pos, + self.orientation_vecs[inds[1], 2] * text_scale_pos, + label_1, + None, + zorder=12, + ha="center", + **text_params, + ) + ax.text( + self.orientation_vecs[inds[2], 1] * text_scale_pos, + self.orientation_vecs[inds[2], 0] * text_scale_pos, + self.orientation_vecs[inds[2], 2] * text_scale_pos, + label_2, + None, + zorder=13, + ha="center", + **text_params, + ) # ax.scatter( # xs=self.g_vec_all[0,:], @@ -848,6 +849,7 @@ def plot_orientation_plan( bragg_peaks = self.generate_diffraction_pattern( orientation_matrix=self.orientation_rotation_matrices[index_plot, :], sigma_excitation_error=self.orientation_kernel_size / 3, + precession_angle_degrees=self.orientation_precession_angle_degrees, ) plot_diffraction_pattern( @@ -947,12 +949,24 @@ def plot_diffraction_pattern( scale_markers: float = 500, scale_markers_compare: Optional[float] = None, power_markers: float = 1, + power_markers_compare: float = 1, + color=(0.0, 0.0, 0.0), + color_compare=None, + facecolor=None, + facecolor_compare=(0.0, 0.7, 1.0), + edgecolor=None, + edgecolor_compare=None, + linewidth=1, + linewidth_compare=1, + marker="+", + marker_compare="o", plot_range_kx_ky: Optional[Union[list, tuple, np.ndarray]] = None, add_labels: bool = True, shift_labels: float = 0.08, shift_marker: float = 0.005, min_marker_size: float = 1e-6, max_marker_size: float = 1000, + show_axes: bool = True, figsize: Union[list, tuple, np.ndarray] = (12, 6), returnfig: bool = False, input_fig_handle=None, @@ -988,9 +1002,7 @@ def plot_diffraction_pattern( if power_markers == 2: marker_size = scale_markers * bragg_peaks.data["intensity"] else: - marker_size = scale_markers * ( - bragg_peaks.data["intensity"] ** (power_markers / 2) - ) + marker_size = scale_markers * (bragg_peaks.data["intensity"] ** power_markers) # Apply marker size limits to primary plot marker_size = np.clip(marker_size, min_marker_size, max_marker_size) @@ -1012,7 +1024,7 @@ def plot_diffraction_pattern( else: marker_size_compare = np.clip( scale_markers_compare - * (bragg_peaks_compare.data["intensity"] ** (power_markers / 2)), + * (bragg_peaks_compare.data["intensity"] ** power_markers), min_marker_size, max_marker_size, ) @@ -1021,19 +1033,29 @@ def plot_diffraction_pattern( bragg_peaks_compare.data["qy"], bragg_peaks_compare.data["qx"], s=marker_size_compare, - marker="o", - facecolor=[0.0, 0.7, 1.0], + marker=marker_compare, + facecolor=facecolor_compare, + edgecolor=edgecolor_compare, + color=color_compare, + linewidth=linewidth_compare, ) ax.scatter( bragg_peaks.data["qy"], bragg_peaks.data["qx"], s=marker_size, - marker="+", - facecolor="k", + marker=marker, + facecolor=facecolor, + edgecolor=edgecolor, + color=color, + linewidth=linewidth, ) - ax.set_xlabel("$q_y$ [Ã…$^{-1}$]") - ax.set_ylabel("$q_x$ [Ã…$^{-1}$]") + if show_axes: + ax.set_xlabel("$q_y$ [Ã…$^{-1}$]") + ax.set_ylabel("$q_x$ [Ã…$^{-1}$]") + else: + ax.get_xaxis().set_visible(False) + ax.get_yaxis().set_visible(False) if plot_range_kx_ky is not None: plot_range_kx_ky = np.array(plot_range_kx_ky) @@ -1097,6 +1119,7 @@ def plot_orientation_maps( dir_in_plane_degrees: float = 0.0, corr_range: np.ndarray = np.array([0, 5]), corr_normalize: bool = True, + show_legend: bool = True, scale_legend: bool = None, figsize: Union[list, tuple, np.ndarray] = (16, 5), figbound: Union[list, tuple, np.ndarray] = (0.01, 0.005), @@ -1118,6 +1141,7 @@ def plot_orientation_maps( dir_in_plane_degrees (float): In-plane angle to plot in degrees. Default is 0 / x-axis / vertical down. corr_range (np.ndarray): Correlation intensity range for the plot corr_normalize (bool): If true, set mean correlation to 1. + show_legend (bool): Show the legend scale_legend (float): 2 elements, x and y scaling of legend panel figsize (array): 2 elements defining figure size figbound (array): 2 elements defining figure boundary @@ -1392,189 +1416,192 @@ def plot_orientation_maps( ax_x.axis("off") ax_z.axis("off") - # Triangulate faces - p = self.orientation_vecs[:, (1, 0, 2)] - tri = mtri.Triangulation( - self.orientation_inds[:, 1] - self.orientation_inds[:, 0] * 1e-3, - self.orientation_inds[:, 0] - self.orientation_inds[:, 1] * 1e-3, - ) - # convert rgb values of pixels to faces - rgb_faces = ( - rgb_legend[tri.triangles[:, 0], :] - + rgb_legend[tri.triangles[:, 1], :] - + rgb_legend[tri.triangles[:, 2], :] - ) / 3 - # Add triangulated surface plot to axes - pc = art3d.Poly3DCollection( - p[tri.triangles], - facecolors=rgb_faces, - alpha=1, - ) - pc.set_antialiased(False) - ax_l.add_collection(pc) - - if plot_limit is None: - plot_limit = np.array( - [ - [np.min(p[:, 0]), np.min(p[:, 1]), np.min(p[:, 2])], - [np.max(p[:, 0]), np.max(p[:, 1]), np.max(p[:, 2])], - ] - ) - # plot_limit = (plot_limit - np.mean(plot_limit, axis=0)) * 1.5 + np.mean( - # plot_limit, axis=0 - # ) - plot_limit[:, 0] = ( - plot_limit[:, 0] - np.mean(plot_limit[:, 0]) - ) * 1.5 + np.mean(plot_limit[:, 0]) - plot_limit[:, 1] = ( - plot_limit[:, 2] - np.mean(plot_limit[:, 1]) - ) * 1.5 + np.mean(plot_limit[:, 1]) - plot_limit[:, 2] = ( - plot_limit[:, 1] - np.mean(plot_limit[:, 2]) - ) * 1.1 + np.mean(plot_limit[:, 2]) - - # ax_l.view_init(elev=el, azim=az) - # Appearance - ax_l.invert_yaxis() - if swap_axes_xy_limits: - ax_l.axes.set_xlim3d(left=plot_limit[0, 0], right=plot_limit[1, 0]) - ax_l.axes.set_ylim3d(bottom=plot_limit[0, 1], top=plot_limit[1, 1]) - ax_l.axes.set_zlim3d(bottom=plot_limit[0, 2], top=plot_limit[1, 2]) - else: - ax_l.axes.set_xlim3d(left=plot_limit[0, 1], right=plot_limit[1, 1]) - ax_l.axes.set_ylim3d(bottom=plot_limit[0, 0], top=plot_limit[1, 0]) - ax_l.axes.set_zlim3d(bottom=plot_limit[0, 2], top=plot_limit[1, 2]) - axisEqual3D(ax_l) - if camera_dist is not None: - ax_l.dist = camera_dist - ax_l.axis("off") - - # Add text labels - text_scale_pos = 0.1 - text_params = { - "va": "center", - "family": "sans-serif", - "fontweight": "normal", - "color": "k", - "size": 14, - } - format_labels = "{0:.2g}" - vec = self.orientation_vecs[inds_legend[0], :] - cam_dir - vec = vec / np.linalg.norm(vec) - if np.abs(self.cell[5] - 120.0) > 1e-6: - ax_l.text( - self.orientation_vecs[inds_legend[0], 1] + vec[1] * text_scale_pos, - self.orientation_vecs[inds_legend[0], 0] + vec[0] * text_scale_pos, - self.orientation_vecs[inds_legend[0], 2] + vec[2] * text_scale_pos, - "[" - + format_labels.format(label_0[0]) - + " " - + format_labels.format(label_0[1]) - + " " - + format_labels.format(label_0[2]) - + "]", - None, - zorder=11, - ha="center", - **text_params, - ) - else: - ax_l.text( - self.orientation_vecs[inds_legend[0], 1] + vec[1] * text_scale_pos, - self.orientation_vecs[inds_legend[0], 0] + vec[0] * text_scale_pos, - self.orientation_vecs[inds_legend[0], 2] + vec[2] * text_scale_pos, - "[" - + format_labels.format(label_0[0]) - + " " - + format_labels.format(label_0[1]) - + " " - + format_labels.format(label_0[2]) - + " " - + format_labels.format(label_0[3]) - + "]", - None, - zorder=11, - ha="center", - **text_params, - ) - vec = self.orientation_vecs[inds_legend[1], :] - cam_dir - vec = vec / np.linalg.norm(vec) - if np.abs(self.cell[5] - 120.0) > 1e-6: - ax_l.text( - self.orientation_vecs[inds_legend[1], 1] + vec[1] * text_scale_pos, - self.orientation_vecs[inds_legend[1], 0] + vec[0] * text_scale_pos, - self.orientation_vecs[inds_legend[1], 2] + vec[2] * text_scale_pos, - "[" - + format_labels.format(label_1[0]) - + " " - + format_labels.format(label_1[1]) - + " " - + format_labels.format(label_1[2]) - + "]", - None, - zorder=12, - ha=ha_1, - **text_params, - ) - else: - ax_l.text( - self.orientation_vecs[inds_legend[1], 1] + vec[1] * text_scale_pos, - self.orientation_vecs[inds_legend[1], 0] + vec[0] * text_scale_pos, - self.orientation_vecs[inds_legend[1], 2] + vec[2] * text_scale_pos, - "[" - + format_labels.format(label_1[0]) - + " " - + format_labels.format(label_1[1]) - + " " - + format_labels.format(label_1[2]) - + " " - + format_labels.format(label_1[3]) - + "]", - None, - zorder=12, - ha=ha_1, - **text_params, + if show_legend: + # Triangulate faces + p = self.orientation_vecs[:, (1, 0, 2)] + tri = mtri.Triangulation( + self.orientation_inds[:, 1] - self.orientation_inds[:, 0] * 1e-3, + self.orientation_inds[:, 0] - self.orientation_inds[:, 1] * 1e-3, ) - vec = self.orientation_vecs[inds_legend[2], :] - cam_dir - vec = vec / np.linalg.norm(vec) - if np.abs(self.cell[5] - 120.0) > 1e-6: - ax_l.text( - self.orientation_vecs[inds_legend[2], 1] + vec[1] * text_scale_pos, - self.orientation_vecs[inds_legend[2], 0] + vec[0] * text_scale_pos, - self.orientation_vecs[inds_legend[2], 2] + vec[2] * text_scale_pos, - "[" - + format_labels.format(label_2[0]) - + " " - + format_labels.format(label_2[1]) - + " " - + format_labels.format(label_2[2]) - + "]", - None, - zorder=13, - ha=ha_2, - **text_params, - ) - else: - ax_l.text( - self.orientation_vecs[inds_legend[2], 1] + vec[1] * text_scale_pos, - self.orientation_vecs[inds_legend[2], 0] + vec[0] * text_scale_pos, - self.orientation_vecs[inds_legend[2], 2] + vec[2] * text_scale_pos, - "[" - + format_labels.format(label_2[0]) - + " " - + format_labels.format(label_2[1]) - + " " - + format_labels.format(label_2[2]) - + " " - + format_labels.format(label_2[3]) - + "]", - None, - zorder=13, - ha=ha_2, - **text_params, + # convert rgb values of pixels to faces + rgb_faces = ( + rgb_legend[tri.triangles[:, 0], :] + + rgb_legend[tri.triangles[:, 1], :] + + rgb_legend[tri.triangles[:, 2], :] + ) / 3 + # Add triangulated surface plot to axes + pc = art3d.Poly3DCollection( + p[tri.triangles], + facecolors=rgb_faces, + alpha=1, ) + pc.set_antialiased(False) + ax_l.add_collection(pc) + + if plot_limit is None: + plot_limit = np.array( + [ + [np.min(p[:, 0]), np.min(p[:, 1]), np.min(p[:, 2])], + [np.max(p[:, 0]), np.max(p[:, 1]), np.max(p[:, 2])], + ] + ) + # plot_limit = (plot_limit - np.mean(plot_limit, axis=0)) * 1.5 + np.mean( + # plot_limit, axis=0 + # ) + plot_limit[:, 0] = ( + plot_limit[:, 0] - np.mean(plot_limit[:, 0]) + ) * 1.5 + np.mean(plot_limit[:, 0]) + plot_limit[:, 1] = ( + plot_limit[:, 2] - np.mean(plot_limit[:, 1]) + ) * 1.5 + np.mean(plot_limit[:, 1]) + plot_limit[:, 2] = ( + plot_limit[:, 1] - np.mean(plot_limit[:, 2]) + ) * 1.1 + np.mean(plot_limit[:, 2]) + + # ax_l.view_init(elev=el, azim=az) + # Appearance + ax_l.invert_yaxis() + if swap_axes_xy_limits: + ax_l.axes.set_xlim3d(left=plot_limit[0, 0], right=plot_limit[1, 0]) + ax_l.axes.set_ylim3d(bottom=plot_limit[0, 1], top=plot_limit[1, 1]) + ax_l.axes.set_zlim3d(bottom=plot_limit[0, 2], top=plot_limit[1, 2]) + else: + ax_l.axes.set_xlim3d(left=plot_limit[0, 1], right=plot_limit[1, 1]) + ax_l.axes.set_ylim3d(bottom=plot_limit[0, 0], top=plot_limit[1, 0]) + ax_l.axes.set_zlim3d(bottom=plot_limit[0, 2], top=plot_limit[1, 2]) + axisEqual3D(ax_l) + if camera_dist is not None: + ax_l.dist = camera_dist + ax_l.axis("off") + + # Add text labels + text_scale_pos = 0.1 + text_params = { + "va": "center", + "family": "sans-serif", + "fontweight": "normal", + "color": "k", + "size": 14, + } + format_labels = "{0:.2g}" + vec = self.orientation_vecs[inds_legend[0], :] - cam_dir + vec = vec / np.linalg.norm(vec) + if np.abs(self.cell[5] - 120.0) > 1e-6: + ax_l.text( + self.orientation_vecs[inds_legend[0], 1] + vec[1] * text_scale_pos, + self.orientation_vecs[inds_legend[0], 0] + vec[0] * text_scale_pos, + self.orientation_vecs[inds_legend[0], 2] + vec[2] * text_scale_pos, + "[" + + format_labels.format(label_0[0]) + + " " + + format_labels.format(label_0[1]) + + " " + + format_labels.format(label_0[2]) + + "]", + None, + zorder=11, + ha="center", + **text_params, + ) + else: + ax_l.text( + self.orientation_vecs[inds_legend[0], 1] + vec[1] * text_scale_pos, + self.orientation_vecs[inds_legend[0], 0] + vec[0] * text_scale_pos, + self.orientation_vecs[inds_legend[0], 2] + vec[2] * text_scale_pos, + "[" + + format_labels.format(label_0[0]) + + " " + + format_labels.format(label_0[1]) + + " " + + format_labels.format(label_0[2]) + + " " + + format_labels.format(label_0[3]) + + "]", + None, + zorder=11, + ha="center", + **text_params, + ) + vec = self.orientation_vecs[inds_legend[1], :] - cam_dir + vec = vec / np.linalg.norm(vec) + if np.abs(self.cell[5] - 120.0) > 1e-6: + ax_l.text( + self.orientation_vecs[inds_legend[1], 1] + vec[1] * text_scale_pos, + self.orientation_vecs[inds_legend[1], 0] + vec[0] * text_scale_pos, + self.orientation_vecs[inds_legend[1], 2] + vec[2] * text_scale_pos, + "[" + + format_labels.format(label_1[0]) + + " " + + format_labels.format(label_1[1]) + + " " + + format_labels.format(label_1[2]) + + "]", + None, + zorder=12, + ha=ha_1, + **text_params, + ) + else: + ax_l.text( + self.orientation_vecs[inds_legend[1], 1] + vec[1] * text_scale_pos, + self.orientation_vecs[inds_legend[1], 0] + vec[0] * text_scale_pos, + self.orientation_vecs[inds_legend[1], 2] + vec[2] * text_scale_pos, + "[" + + format_labels.format(label_1[0]) + + " " + + format_labels.format(label_1[1]) + + " " + + format_labels.format(label_1[2]) + + " " + + format_labels.format(label_1[3]) + + "]", + None, + zorder=12, + ha=ha_1, + **text_params, + ) + vec = self.orientation_vecs[inds_legend[2], :] - cam_dir + vec = vec / np.linalg.norm(vec) + if np.abs(self.cell[5] - 120.0) > 1e-6: + ax_l.text( + self.orientation_vecs[inds_legend[2], 1] + vec[1] * text_scale_pos, + self.orientation_vecs[inds_legend[2], 0] + vec[0] * text_scale_pos, + self.orientation_vecs[inds_legend[2], 2] + vec[2] * text_scale_pos, + "[" + + format_labels.format(label_2[0]) + + " " + + format_labels.format(label_2[1]) + + " " + + format_labels.format(label_2[2]) + + "]", + None, + zorder=13, + ha=ha_2, + **text_params, + ) + else: + ax_l.text( + self.orientation_vecs[inds_legend[2], 1] + vec[1] * text_scale_pos, + self.orientation_vecs[inds_legend[2], 0] + vec[0] * text_scale_pos, + self.orientation_vecs[inds_legend[2], 2] + vec[2] * text_scale_pos, + "[" + + format_labels.format(label_2[0]) + + " " + + format_labels.format(label_2[1]) + + " " + + format_labels.format(label_2[2]) + + " " + + format_labels.format(label_2[3]) + + "]", + None, + zorder=13, + ha=ha_2, + **text_params, + ) - plt.show() + plt.show() + else: + ax_l.set_axis_off() images_orientation = np.zeros((orientation_map.num_x, orientation_map.num_y, 3, 2)) if self.pymatgen_available: @@ -1885,6 +1912,8 @@ def plot_clusters( for a0 in range(self.cluster_sizes.shape[0]): if self.cluster_sizes[a0] >= area_min: if outline_grains: + from skimage.morphology import erosion + im_grain[:] = False im_grain[ self.cluster_inds[a0][0, :], diff --git a/py4DSTEM/process/phase/__init__.py b/py4DSTEM/process/phase/__init__.py index ecfeaa1d2..e6c0070e1 100644 --- a/py4DSTEM/process/phase/__init__.py +++ b/py4DSTEM/process/phase/__init__.py @@ -2,6 +2,7 @@ _emd_hook = True +from py4DSTEM import is_package_lite from py4DSTEM.process.phase.dpc import DPC from py4DSTEM.process.phase.magnetic_ptychographic_tomography import MagneticPtychographicTomography from py4DSTEM.process.phase.magnetic_ptychography import MagneticPtychography @@ -11,6 +12,11 @@ from py4DSTEM.process.phase.parallax import Parallax from py4DSTEM.process.phase.ptychographic_tomography import PtychographicTomography from py4DSTEM.process.phase.singleslice_ptychography import SingleslicePtychography -from py4DSTEM.process.phase.parameter_optimize import OptimizationParameter, PtychographyOptimizer +from py4DSTEM.process.phase.xray_magnetic_ptychography import XRayMagneticPtychography +try: + from py4DSTEM.process.phase.parameter_optimize import OptimizationParameter, PtychographyOptimizer +except (ImportError, ModuleNotFoundError) as exc: + if not is_package_lite: + raise exc # fmt: on diff --git a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py index be6332c74..8265c1325 100644 --- a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py +++ b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py @@ -225,6 +225,7 @@ def preprocess( padded_diffraction_intensities_shape: Tuple[int, int] = None, region_of_interest_shape: Tuple[int, int] = None, dp_mask: np.ndarray = None, + in_place_datacube_modification: bool = False, fit_function: str = "plane", plot_probe_overlaps: bool = True, rotation_real_space_degrees: float = None, @@ -266,6 +267,9 @@ def preprocess( at the diffraction plane to allow comparison with experimental data dp_mask: ndarray, optional Mask for datacube intensities (Qx,Qy) + in_place_datacube_modification: bool, optional + If True, the datacube will be preprocessed in-place. Note this is not possible + when either crop_patterns or positions_mask are used. fit_function: str, optional 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' plot_probe_overlaps: bool, optional @@ -479,12 +483,14 @@ def preprocess( amplitudes, mean_diffraction_intensity_temp, self._crop_mask, + self._crop_mask_shape, ) = self._normalize_diffraction_intensities( intensities, com_fitted_x, com_fitted_y, self._positions_mask[index], crop_patterns, + in_place_datacube_modification, ) self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp) diff --git a/py4DSTEM/process/phase/magnetic_ptychography.py b/py4DSTEM/process/phase/magnetic_ptychography.py index 19f306188..975f6ac84 100644 --- a/py4DSTEM/process/phase/magnetic_ptychography.py +++ b/py4DSTEM/process/phase/magnetic_ptychography.py @@ -101,6 +101,9 @@ class MagneticPtychography( initial_scan_positions: np.ndarray, optional Probe positions in Ã… for each diffraction intensity If None, initialized to a grid scan + object_fov_ang: Tuple[int,int], optional + Fixed object field of view in Ã…. If None, the fov is initialized using the + probe positions and object_padding_px positions_offset_ang: np.ndarray, optional Offset of positions in A verbose: bool, optional @@ -138,6 +141,7 @@ def __init__( initial_object_guess: np.ndarray = None, initial_probe_guess: np.ndarray = None, initial_scan_positions: np.ndarray = None, + object_fov_ang: Tuple[float, float] = None, positions_offset_ang: np.ndarray = None, object_type: str = "complex", verbose: bool = True, @@ -189,6 +193,7 @@ def __init__( self._rolloff = rolloff self._object_type = object_type self._object_padding_px = object_padding_px + self._object_fov_ang = object_fov_ang self._positions_mask = positions_mask self._verbose = verbose self._preprocessed = False @@ -203,6 +208,7 @@ def preprocess( padded_diffraction_intensities_shape: Tuple[int, int] = None, region_of_interest_shape: Tuple[int, int] = None, dp_mask: np.ndarray = None, + in_place_datacube_modification: bool = False, fit_function: str = "plane", plot_rotation: bool = True, maximize_divergence: bool = False, @@ -219,6 +225,7 @@ def preprocess( progress_bar: bool = True, object_fov_mask: np.ndarray = True, crop_patterns: bool = False, + center_positions_in_fov: bool = True, store_initial_arrays: bool = True, device: str = None, clear_fft_cache: bool = None, @@ -253,6 +260,9 @@ def preprocess( at the diffraction plane to allow comparison with experimental data dp_mask: ndarray, optional Mask for datacube intensities (Qx,Qy) + in_place_datacube_modification: bool, optional + If True, the datacube will be preprocessed in-place. Note this is not possible + when either crop_patterns or positions_mask are used. fit_function: str, optional 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' plot_rotation: bool, optional @@ -286,6 +296,8 @@ def preprocess( If None, probe_overlap intensity is thresholded crop_patterns: bool if True, crop patterns to avoid wrap around of patterns when centering + center_positions_in_fov: bool + If True (default), probe positions are centered in the fov. store_initial_arrays: bool If True, preprocesed object and probe arrays are stored allowing reset=True in reconstruct. device: str, optional @@ -365,10 +377,6 @@ def preprocess( f"datacube must be the same length as magnetic_contribution_sign, not length {len(self._datacube)}." ) - dc_shapes = [dc.shape for dc in self._datacube] - if dc_shapes.count(dc_shapes[0]) != self._num_measurements: - raise ValueError("datacube intensities must be the same size.") - if self._positions_mask is not None: self._positions_mask = np.asarray(self._positions_mask, dtype="bool") @@ -543,12 +551,14 @@ def preprocess( amplitudes, mean_diffraction_intensity_temp, self._crop_mask, + self._crop_mask_shape, ) = self._normalize_diffraction_intensities( intensities, com_fitted_x, com_fitted_y, self._positions_mask[index], crop_patterns, + in_place_datacube_modification, ) self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp) @@ -610,14 +620,17 @@ def preprocess( self._positions_px_all, dtype=xp_storage.float32 ) - for index in range(self._num_measurements): - idx_start = self._cum_probes_per_measurement[index] - idx_end = self._cum_probes_per_measurement[index + 1] + if center_positions_in_fov: + for index in range(self._num_measurements): + idx_start = self._cum_probes_per_measurement[index] + idx_end = self._cum_probes_per_measurement[index + 1] - positions_px = self._positions_px_all[idx_start:idx_end] - positions_px_com = positions_px.mean(0) - positions_px -= positions_px_com - xp_storage.array(self._object_shape) / 2 - self._positions_px_all[idx_start:idx_end] = positions_px.copy() + positions_px = self._positions_px_all[idx_start:idx_end] + positions_px_com = positions_px.mean(0) + positions_px -= ( + positions_px_com - xp_storage.array(self._object_shape) / 2 + ) + self._positions_px_all[idx_start:idx_end] = positions_px.copy() self._positions_px_initial_all = self._positions_px_all.copy() self._positions_initial_all = self._positions_px_initial_all.copy() @@ -1785,55 +1798,55 @@ def _visualize_last_iteration( if fig is None: fig = plt.figure(figsize=figsize) - if plot_probe or plot_fourier_probe: - # Object_e - ax = fig.add_subplot(spec[0, 0]) - im = ax.imshow( - obj[0], - extent=extent, - cmap=cmap_e, - vmin=vmin_e, - vmax=vmax_e, - **kwargs, - ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - - if self._object_type == "potential": - ax.set_title("Electrostatic potential") - elif self._object_type == "complex": - ax.set_title("Electrostatic phase") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) + # Object_e + ax = fig.add_subplot(spec[0, 0]) + im = ax.imshow( + obj[0], + extent=extent, + cmap=cmap_e, + vmin=vmin_e, + vmax=vmax_e, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") - # Object_m - ax = fig.add_subplot(spec[0, 1]) - im = ax.imshow( - obj[1], - extent=extent, - cmap=cmap_m, - vmin=vmin_m, - vmax=vmax_m, - **kwargs, - ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") + if self._object_type == "potential": + ax.set_title("Electrostatic potential") + elif self._object_type == "complex": + ax.set_title("Electrostatic phase") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + # Object_m + ax = fig.add_subplot(spec[0, 1]) + im = ax.imshow( + obj[1], + extent=extent, + cmap=cmap_m, + vmin=vmin_m, + vmax=vmax_m, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") - if self._object_type == "potential": - ax.set_title("Magnetic potential") - elif self._object_type == "complex": - ax.set_title("Magnetic phase") + if self._object_type == "potential": + ax.set_title("Magnetic potential") + elif self._object_type == "complex": + ax.set_title("Magnetic phase") - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + if plot_probe or plot_fourier_probe: # Probe ax = fig.add_subplot(spec[0, 2]) if plot_fourier_probe: @@ -1872,55 +1885,6 @@ def _visualize_last_iteration( ax_cb = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) - else: - # Object_e - ax = fig.add_subplot(spec[0, 0]) - im = ax.imshow( - obj[0], - extent=extent, - cmap=cmap_e, - vmin=vmin_e, - vmax=vmax_e, - **kwargs, - ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - - if self._object_type == "potential": - ax.set_title("Electrostatic potential") - elif self._object_type == "complex": - ax.set_title("Electrostatic phase") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - # Object_e - ax = fig.add_subplot(spec[0, 1]) - im = ax.imshow( - obj[1], - extent=extent, - cmap=cmap_m, - vmin=vmin_m, - vmax=vmax_m, - **kwargs, - ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - - if self._object_type == "potential": - ax.set_title("Magnetic potential") - elif self._object_type == "complex": - ax.set_title("Magnetic phase") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - if plot_convergence and hasattr(self, "error_iterations"): errors = np.array(self.error_iterations) diff --git a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py index b22d0a0bb..3bacf1870 100644 --- a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py @@ -107,6 +107,9 @@ class MixedstateMultislicePtychography( initial_scan_positions: np.ndarray, optional Probe positions in Ã… for each diffraction intensity If None, initialized to a grid scan + object_fov_ang: Tuple[int,int], optional + Fixed object field of view in Ã…. If None, the fov is initialized using the + probe positions and object_padding_px positions_offset_ang: np.ndarray, optional Offset of positions in A theta_x: float @@ -159,6 +162,7 @@ def __init__( initial_object_guess: np.ndarray = None, initial_probe_guess: np.ndarray = None, initial_scan_positions: np.ndarray = None, + object_fov_ang: Tuple[float, float] = None, positions_offset_ang: np.ndarray = None, theta_x: float = 0, theta_y: float = 0, @@ -245,6 +249,7 @@ def __init__( self._object_type = object_type self._positions_mask = positions_mask self._object_padding_px = object_padding_px + self._object_fov_ang = object_fov_ang self._verbose = verbose self._preprocessed = False @@ -262,6 +267,7 @@ def preprocess( padded_diffraction_intensities_shape: Tuple[int, int] = None, region_of_interest_shape: Tuple[int, int] = None, dp_mask: np.ndarray = None, + in_place_datacube_modification: bool = False, fit_function: str = "plane", plot_center_of_mass: str = "default", plot_rotation: bool = True, @@ -278,6 +284,7 @@ def preprocess( force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, crop_patterns: bool = False, + center_positions_in_fov: bool = True, store_initial_arrays: bool = True, device: str = None, clear_fft_cache: bool = None, @@ -312,6 +319,9 @@ def preprocess( at the diffraction plane to allow comparison with experimental data dp_mask: ndarray, optional Mask for datacube intensities (Qx,Qy) + in_place_datacube_modification: bool, optional + If True, the datacube will be preprocessed in-place. Note this is not possible + when either crop_patterns or positions_mask are used. fit_function: str, optional 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' plot_center_of_mass: str, optional @@ -348,6 +358,8 @@ def preprocess( If None, probe_overlap intensity is thresholded crop_patterns: bool if True, crop patterns to avoid wrap around of patterns when centering + center_positions_in_fov: bool + If True (default), probe positions are centered in the fov. store_initial_arrays: bool If True, preprocesed object and probe arrays are stored allowing reset=True in reconstruct. device: str, optional @@ -478,17 +490,21 @@ def preprocess( self._amplitudes, self._mean_diffraction_intensity, self._crop_mask, + self._crop_mask_shape, ) = self._normalize_diffraction_intensities( _intensities, self._com_fitted_x, self._com_fitted_y, self._positions_mask, crop_patterns, + in_place_datacube_modification, ) # explicitly transfer arrays to storage + if not in_place_datacube_modification: + del _intensities + self._amplitudes = copy_to_device(self._amplitudes, storage) - del _intensities self._num_diffraction_patterns = self._amplitudes.shape[0] self._amplitudes_shape = np.array(self._amplitudes.shape[-2:]) @@ -524,15 +540,18 @@ def preprocess( self._object_type_initial = self._object_type self._object_shape = self._object.shape[-2:] - # center probe positions self._positions_px = xp_storage.asarray( self._positions_px, dtype=xp_storage.float32 ) self._positions_px_initial_com = self._positions_px.mean(0) - self._positions_px -= ( - self._positions_px_initial_com - xp_storage.array(self._object_shape) / 2 - ) - self._positions_px_initial_com = self._positions_px.mean(0) + + # center probe positions + if center_positions_in_fov: + self._positions_px -= ( + self._positions_px_initial_com + - xp_storage.array(self._object_shape) / 2 + ) + self._positions_px_initial_com = self._positions_px.mean(0) self._positions_px_initial = self._positions_px.copy() self._positions_initial = self._positions_px_initial.copy() diff --git a/py4DSTEM/process/phase/mixedstate_ptychography.py b/py4DSTEM/process/phase/mixedstate_ptychography.py index 436338555..9b12d09e0 100644 --- a/py4DSTEM/process/phase/mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_ptychography.py @@ -96,6 +96,9 @@ class MixedstatePtychography( initial_scan_positions: np.ndarray, optional Probe positions in Ã… for each diffraction intensity If None, initialized to a grid scan + object_fov_ang: Tuple[int,int], optional + Fixed object field of view in Ã…. If None, the fov is initialized using the + probe positions and object_padding_px positions_offset_ang: np.ndarray, optional Offset of positions in A positions_mask: np.ndarray, optional @@ -127,6 +130,7 @@ def __init__( initial_object_guess: np.ndarray = None, initial_probe_guess: np.ndarray = None, initial_scan_positions: np.ndarray = None, + object_fov_ang: Tuple[float, float] = None, positions_offset_ang: np.ndarray = None, object_type: str = "complex", positions_mask: np.ndarray = None, @@ -194,6 +198,7 @@ def __init__( self._rolloff = rolloff self._object_type = object_type self._object_padding_px = object_padding_px + self._object_fov_ang = object_fov_ang self._positions_mask = positions_mask self._verbose = verbose self._preprocessed = False @@ -208,6 +213,7 @@ def preprocess( padded_diffraction_intensities_shape: Tuple[int, int] = None, region_of_interest_shape: Tuple[int, int] = None, dp_mask: np.ndarray = None, + in_place_datacube_modification: bool = False, fit_function: str = "plane", plot_center_of_mass: str = "default", plot_rotation: bool = True, @@ -224,6 +230,7 @@ def preprocess( force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, crop_patterns: bool = False, + center_positions_in_fov: bool = True, store_initial_arrays: bool = True, device: str = None, clear_fft_cache: bool = None, @@ -258,6 +265,9 @@ def preprocess( at the diffraction plane to allow comparison with experimental data dp_mask: ndarray, optional Mask for datacube intensities (Qx,Qy) + in_place_datacube_modification: bool, optional + If True, the datacube will be preprocessed in-place. Note this is not possible + when either crop_patterns or positions_mask are used. fit_function: str, optional 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' plot_center_of_mass: str, optional @@ -294,6 +304,8 @@ def preprocess( If None, probe_overlap intensity is thresholded crop_patterns: bool if True, crop patterns to avoid wrap around of patterns when centering + center_positions_in_fov: bool + If True (default), probe positions are centered in the fov. store_initial_arrays: bool If True, preprocesed object and probe arrays are stored allowing reset=True in reconstruct. device: str, optional @@ -424,17 +436,21 @@ def preprocess( self._amplitudes, self._mean_diffraction_intensity, self._crop_mask, + self._crop_mask_shape, ) = self._normalize_diffraction_intensities( _intensities, self._com_fitted_x, self._com_fitted_y, self._positions_mask, crop_patterns, + in_place_datacube_modification, ) # explicitly transfer arrays to storage + if not in_place_datacube_modification: + del _intensities + self._amplitudes = copy_to_device(self._amplitudes, storage) - del _intensities self._num_diffraction_patterns = self._amplitudes.shape[0] self._amplitudes_shape = np.array(self._amplitudes.shape[-2:]) @@ -469,15 +485,18 @@ def preprocess( self._object_type_initial = self._object_type self._object_shape = self._object.shape - # center probe positions self._positions_px = xp_storage.asarray( self._positions_px, dtype=xp_storage.float32 ) self._positions_px_initial_com = self._positions_px.mean(0) - self._positions_px -= ( - self._positions_px_initial_com - xp_storage.array(self._object_shape) / 2 - ) - self._positions_px_initial_com = self._positions_px.mean(0) + + # center probe positions + if center_positions_in_fov: + self._positions_px -= ( + self._positions_px_initial_com + - xp_storage.array(self._object_shape) / 2 + ) + self._positions_px_initial_com = self._positions_px.mean(0) self._positions_px_initial = self._positions_px.copy() self._positions_initial = self._positions_px_initial.copy() diff --git a/py4DSTEM/process/phase/multislice_ptychography.py b/py4DSTEM/process/phase/multislice_ptychography.py index 6de8e7970..65a347b83 100644 --- a/py4DSTEM/process/phase/multislice_ptychography.py +++ b/py4DSTEM/process/phase/multislice_ptychography.py @@ -99,6 +99,9 @@ class MultislicePtychography( initial_scan_positions: np.ndarray, optional Probe positions in Ã… for each diffraction intensity If None, initialized to a grid scan + object_fov_ang: Tuple[int,int], optional + Fixed object field of view in Ã…. If None, the fov is initialized using the + probe positions and object_padding_px positions_offset_ang: np.ndarray, optional Offset of positions in A theta_x: float @@ -149,6 +152,7 @@ def __init__( initial_object_guess: np.ndarray = None, initial_probe_guess: np.ndarray = None, initial_scan_positions: np.ndarray = None, + object_fov_ang: Tuple[float, float] = None, positions_offset_ang: np.ndarray = None, theta_x: float = None, theta_y: float = None, @@ -220,6 +224,7 @@ def __init__( self._object_type = object_type self._positions_mask = positions_mask self._object_padding_px = object_padding_px + self._object_fov_ang = object_fov_ang self._verbose = verbose self._preprocessed = False @@ -236,6 +241,7 @@ def preprocess( padded_diffraction_intensities_shape: Tuple[int, int] = None, region_of_interest_shape: Tuple[int, int] = None, dp_mask: np.ndarray = None, + in_place_datacube_modification: bool = False, fit_function: str = "plane", plot_center_of_mass: str = "default", plot_rotation: bool = True, @@ -252,6 +258,7 @@ def preprocess( force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, crop_patterns: bool = False, + center_positions_in_fov: bool = True, store_initial_arrays: bool = True, device: str = None, clear_fft_cache: bool = None, @@ -286,6 +293,9 @@ def preprocess( at the diffraction plane to allow comparison with experimental data dp_mask: ndarray, optional Mask for datacube intensities (Qx,Qy) + in_place_datacube_modification: bool, optional + If True, the datacube will be preprocessed in-place. Note this is not possible + when either crop_patterns or positions_mask are used. fit_function: str, optional 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' plot_center_of_mass: str, optional @@ -322,6 +332,8 @@ def preprocess( If None, probe_overlap intensity is thresholded crop_patterns: bool if True, crop patterns to avoid wrap around of patterns when centering + center_positions_in_fov: bool + If True (default), probe positions are centered in the fov. store_initial_arrays: bool If True, preprocesed object and probe arrays are stored allowing reset=True in reconstruct. device: str, optional @@ -452,17 +464,21 @@ def preprocess( self._amplitudes, self._mean_diffraction_intensity, self._crop_mask, + self._crop_mask_shape, ) = self._normalize_diffraction_intensities( _intensities, self._com_fitted_x, self._com_fitted_y, self._positions_mask, crop_patterns, + in_place_datacube_modification, ) # explicitly transfer arrays to storage + if not in_place_datacube_modification: + del _intensities + self._amplitudes = copy_to_device(self._amplitudes, storage) - del _intensities self._num_diffraction_patterns = self._amplitudes.shape[0] self._amplitudes_shape = np.array(self._amplitudes.shape[-2:]) @@ -498,15 +514,18 @@ def preprocess( self._object_type_initial = self._object_type self._object_shape = self._object.shape[-2:] - # center probe positions self._positions_px = xp_storage.asarray( self._positions_px, dtype=xp_storage.float32 ) self._positions_px_initial_com = self._positions_px.mean(0) - self._positions_px -= ( - self._positions_px_initial_com - xp_storage.array(self._object_shape) / 2 - ) - self._positions_px_initial_com = self._positions_px.mean(0) + + # center probe positions + if center_positions_in_fov: + self._positions_px -= ( + self._positions_px_initial_com + - xp_storage.array(self._object_shape) / 2 + ) + self._positions_px_initial_com = self._positions_px.mean(0) self._positions_px_initial = self._positions_px.copy() self._positions_initial = self._positions_px_initial.copy() diff --git a/py4DSTEM/process/phase/parallax.py b/py4DSTEM/process/phase/parallax.py index e12d6a133..e5768f3cc 100644 --- a/py4DSTEM/process/phase/parallax.py +++ b/py4DSTEM/process/phase/parallax.py @@ -19,13 +19,15 @@ AffineTransform, bilinear_kernel_density_estimate, bilinearly_interpolate_array, + calculate_aberration_gradient_basis, + generate_batches, lanczos_interpolate_array, lanczos_kernel_density_estimate, pixel_rolling_kernel_density_estimate, ) from py4DSTEM.process.utils.cross_correlate import align_images_fourier from py4DSTEM.process.utils.utils import electron_wavelength_angstrom -from py4DSTEM.visualize import return_scaled_histogram_ordering, show +from py4DSTEM.visualize import return_scaled_histogram_ordering from scipy.linalg import polar from scipy.ndimage import distance_transform_edt from scipy.optimize import minimize @@ -260,12 +262,14 @@ def preprocess( descan_correction_fit_function: str = None, defocus_guess: float = None, rotation_guess: float = None, + aligned_bf_image_guess: np.ndarray = None, plot_average_bf: bool = True, realspace_mask: np.ndarray = None, apply_realspace_mask_to_stack: bool = True, vectorized_com_calculation: bool = True, device: str = None, clear_fft_cache: bool = None, + max_batch_size: int = None, store_initial_arrays: bool = True, **kwargs, ): @@ -284,16 +288,18 @@ def preprocess( If True, bright images normalized to have a mean of 1 normalize_order: integer, optional Polynomial order for normalization. 0 means constant, 1 means linear, etc. - Higher orders not yet implemented. defocus_guess: float, optional Initial guess of defocus value (defocus dF) in A If None, first iteration is assumed to be in-focus + aligned_bf_image_guess: np.ndarray, optional + Guess for the reference BF image to cross-correlate against during the first iteration + If None, the incoherent BF image is used instead. + rotation_guess: float, optional + Initial guess of rotation value in degrees + If None, first iteration assumed to be 0 descan_correction_fit_function: str, optional If not None, descan correction will be performed using fit function. One of "constant", "plane", "parabola", or "bezier_two". - rotation_guess: float, optional - Initial guess of defocus value in degrees - If None, first iteration assumed to be 0 plot_average_bf: bool, optional If True, plots the average bright field image, using defocus_guess realspace_mask: np.array, optional @@ -308,6 +314,8 @@ def preprocess( if not none, overwrites self._device to set device preprocess will be perfomed on. clear_fft_cache: bool, optional If True, and device = 'gpu', clears the cached fft plan at the end of function calls + max_batch_size: int, optional + Max number of virtual BF images to use at once in computing cross-correlation store_initial_arrays: bool, optional If True, stores a copy of the arrays necessary to reinitialize in reconstruct @@ -474,7 +482,6 @@ def preprocess( self._stack_BF_unshifted = xp.ones(stack_shape, xp.float32) if normalize_order == 0: - # all_bfs /= xp.mean(all_bfs, axis=(1, 2))[:, None, None] weights = xp.average( all_bfs.reshape((self._num_bf_images, -1)), weights=self._window_edge.ravel(), @@ -517,7 +524,6 @@ def preprocess( weights = np.sqrt(self._window_edge).ravel() for a0 in range(all_bfs.shape[0]): - # coefs = np.linalg.lstsq(basis, all_bfs[a0].ravel(), rcond=None) # weighted least squares coefs = np.linalg.lstsq( weights[:, None] * basis, @@ -574,7 +580,6 @@ def preprocess( weights = np.sqrt(self._window_edge).ravel() for a0 in range(all_bfs.shape[0]): - # coefs = np.linalg.lstsq(basis, all_bfs[a0].ravel(), rcond=None) # weighted least squares coefs = np.linalg.lstsq( weights[:, None] * basis, @@ -645,49 +650,84 @@ def preprocess( # Initialization utilities self._stack_mask = xp.tile(self._window_pad[None], (self._num_bf_images, 1, 1)) + + if max_batch_size is None: + max_batch_size = self._num_bf_images + + self._xy_shifts = xp.zeros((self._num_bf_images, 2), dtype=xp.float32) + if defocus_guess is not None: - Gs = xp.fft.fft2(self._stack_BF_shifted) + for start, end in generate_batches( + self._num_bf_images, max_batch=max_batch_size + ): + shifted_BFs = self._stack_BF_shifted[start:end] + probe_angles = self._probe_angles[start:end] + stack_mask = self._stack_mask[start:end] - self._xy_shifts = ( - -self._probe_angles - * defocus_guess - / xp.array(self._scan_sampling, dtype=xp.float32) - ) + Gs = xp.fft.fft2(shifted_BFs) - if rotation_guess: - angle = xp.deg2rad(rotation_guess) - rotation_matrix = xp.array( - [[np.cos(angle), np.sin(angle)], [-np.sin(angle), np.cos(angle)]], - dtype=xp.float32, + xy_shifts = ( + -probe_angles + * defocus_guess + / xp.array(self._scan_sampling, dtype=xp.float32) ) - self._xy_shifts = xp.dot(self._xy_shifts, rotation_matrix) - dx = self._xy_shifts[:, 0] - dy = self._xy_shifts[:, 1] + if rotation_guess is not None: + angle = xp.deg2rad(rotation_guess) + rotation_matrix = xp.array( + [ + [np.cos(angle), np.sin(angle)], + [-np.sin(angle), np.cos(angle)], + ], + dtype=xp.float32, + ) + xy_shifts = xp.dot(xy_shifts, rotation_matrix) - shift_op = xp.exp( - self._qx_shift[None] * dx[:, None, None] - + self._qy_shift[None] * dy[:, None, None] - ) - self._stack_BF_shifted = xp.real(xp.fft.ifft2(Gs * shift_op)) - self._stack_mask = xp.real( - xp.fft.ifft2(xp.fft.fft2(self._stack_mask) * shift_op) - ) + dx = xy_shifts[:, 0] + dy = xy_shifts[:, 1] - del Gs - else: - self._xy_shifts = xp.zeros((self._num_bf_images, 2), dtype=xp.float32) + shift_op = xp.exp( + self._qx_shift[None] * dx[:, None, None] + + self._qy_shift[None] * dy[:, None, None] + ) + stack_BF_shifted = xp.real(xp.fft.ifft2(Gs * shift_op)) + stack_mask = xp.real(xp.fft.ifft2(xp.fft.fft2(stack_mask) * shift_op)) + + self._xy_shifts[start:end] = xy_shifts + self._stack_BF_shifted[start:end] = stack_BF_shifted + self._stack_mask[start:end] = stack_mask + + del Gs self._stack_mean = xp.mean(self._stack_BF_shifted) self._mask_sum = xp.sum(self._window_edge) * self._num_bf_images - self._recon_mask = xp.sum(self._stack_mask, axis=0) + self._recon_mask = xp.mean(self._stack_mask, axis=0) mask_inv = 1 - xp.clip(self._recon_mask, 0, 1) - self._recon_BF = ( - self._stack_mean * mask_inv - + xp.sum(self._stack_BF_shifted * self._stack_mask, axis=0) - ) / (self._recon_mask + mask_inv) + if aligned_bf_image_guess is not None: + aligned_bf_image_guess = xp.asarray(aligned_bf_image_guess) + if normalize_images: + self._recon_BF = xp.ones(stack_shape[-2:], dtype=xp.float32) + aligned_bf_image_guess /= aligned_bf_image_guess.mean() + else: + self._recon_BF = xp.full(stack_shape[-2:], self._stack_mean) + + self._recon_BF[ + self._object_padding_px[0] // 2 : self._grid_scan_shape[0] + + self._object_padding_px[0] // 2, + self._object_padding_px[1] // 2 : self._grid_scan_shape[1] + + self._object_padding_px[1] // 2, + ] = ( + self._window_inv * self._stack_mean + + self._window_edge * aligned_bf_image_guess + ) + + else: + self._recon_BF = ( + self._stack_mean * mask_inv + + xp.mean(self._stack_BF_shifted * self._stack_mask, axis=0) + ) / (self._recon_mask + mask_inv) self._recon_error = ( xp.atleast_1d( @@ -697,6 +737,7 @@ def preprocess( ) ) / self._mask_sum + / self._stack_mean ) if store_initial_arrays: @@ -737,12 +778,247 @@ def preprocess( return self + def guess_common_aberrations( + self, + rotation_angle_deg=0, + transpose=False, + kde_upsample_factor=None, + kde_sigma_px=0.125, + kde_lowpass_filter=False, + lanczos_interpolation_order=None, + defocus=0, + astigmatism=0, + astigmatism_angle_deg=0, + coma=0, + coma_angle_deg=0, + spherical_aberration=0, + max_batch_size=None, + plot_shifts_and_aligned_bf=True, + return_shifts_and_aligned_bf=False, + plot_arrow_freq=1, + scale_arrows=1, + **kwargs, + ): + """ + Generates analytical BF shifts and uses them to align the virtual BF stack, + based on the experimental geometry (rotation, transpose), and common aberrations. + + Parameters + ---------- + rotation_angle_deg: float, optional + Relative rotation between the scan and the diffraction space coordinate systems + transpose: bool, optional + Whether the diffraction intensities are transposed + kde_upsample_factor: int, optional + Real-space upsampling factor + kde_sigma_px: float, optional + KDE gaussian kernel bandwidth in non-upsampled pixels + kde_lowpass_filter: bool, optional + If True, the resulting KDE upsampled image is lowpass-filtered using a sinc-function + lanczos_interpolation_order: int, optional + If not None, Lanczos interpolation with the specified order is used instead of bilinear + defocus: float, optional + Defocus value to use in computing analytical BF shifts + astigmatism: float, optional + Astigmatism value to use in computing analytical BF shifts + astigmatism_angle_deg: float, optional + Astigmatism angle to use in computing analytical BF shifts + coma: float, optional + Coma value to use in computing analytical BF shifts + coma_angle_deg: float, optional + Coma angle to use in computing analytical BF shifts + spherical_aberration: float, optional + Spherical aberration value to use in computing analytical BF shifts + max_batch_size: int, optional + Max number of virtual BF images to use at once in computing cross-correlation + plot_shifts_and_aligned_bf: bool, optional + If True, the analytical shifts and the aligned virtual VF image are plotted + return_shifts_and_aligned_bf: bool, optional + If True, the analytical shifts and the aligned virtual VF image are returned + plot_arrow_freq: int, optional + Frequency of shifts to plot in quiver plot + scale_arrows: float, optional + Scale to multiply shifts by + + """ + xp = self._xp + asnumpy = self._asnumpy + + if not hasattr(self, "_recon_BF"): + raise ValueError( + ( + "Aberration guessing is meant to be ran after preprocessing. " + "Please run the `preprocess()` function first." + ) + ) + + # aberrations_coefs + aberrations_mn = [ + [1, 0, 0], + [1, 2, 0], + [1, 2, 1], + [2, 1, 0], + [2, 1, 1], + [3, 0, 0], + ] + astigmatism_x = astigmatism * np.cos(np.deg2rad(astigmatism_angle_deg) * 2) + astigmatism_y = astigmatism * np.sin(np.deg2rad(astigmatism_angle_deg) * 2) + coma_x = coma * np.cos(np.deg2rad(coma_angle_deg) * 1) + coma_y = coma * np.sin(np.deg2rad(coma_angle_deg) * 1) + aberrations_coefs = xp.array( + [ + -defocus, + astigmatism_x, + astigmatism_y, + coma_x, + coma_y, + spherical_aberration, + ] + ) + + # transpose rotation matrix + if transpose: + rotation_angle_deg *= -1 + + # aberrations_basis + sampling = 1 / ( + np.array(self._reciprocal_sampling) * self._region_of_interest_shape + ) + aberrations_basis, aberrations_basis_du, aberrations_basis_dv = ( + calculate_aberration_gradient_basis( + aberrations_mn, + sampling, + self._region_of_interest_shape, + self._wavelength, + rotation_angle=np.deg2rad(rotation_angle_deg), + xp=xp, + ) + ) + + # shifts + corner_indices = self._xy_inds - xp.array(self._region_of_interest_shape // 2) + raveled_indices = xp.ravel_multi_index( + corner_indices.T, self._region_of_interest_shape, mode="wrap" + ) + gradients = xp.array( + ( + aberrations_basis_du[raveled_indices, :], + aberrations_basis_dv[raveled_indices, :], + ) + ) + shifts_ang = xp.tensordot(gradients, aberrations_coefs, axes=1).T + + # transpose predicted shifts + if transpose: + shifts_ang = xp.flip(shifts_ang, axis=1) + + shifts_px = shifts_ang / xp.array(self._scan_sampling) + + # upsampled stack + if kde_upsample_factor is not None: + BF_size = np.array(self._stack_BF_unshifted.shape[-2:]) + pixel_output_shape = np.round(BF_size * kde_upsample_factor).astype("int") + + x = xp.arange(BF_size[0], dtype=xp.float32) + y = xp.arange(BF_size[1], dtype=xp.float32) + xa_init, ya_init = xp.meshgrid(x, y, indexing="ij") + + # kernel density output the upsampled BF image + xa = (xa_init + shifts_px[:, 0, None, None]) * kde_upsample_factor + ya = (ya_init + shifts_px[:, 1, None, None]) * kde_upsample_factor + + pix_output = self._kernel_density_estimate( + xa, + ya, + self._stack_BF_unshifted, + pixel_output_shape, + kde_sigma_px * kde_upsample_factor, + lanczos_alpha=lanczos_interpolation_order, + lowpass_filter=kde_lowpass_filter, + ) + + # hack since cropping requires "_kde_upsample_factor" + old_upsample_factor = getattr(self, "_kde_upsample_factor", None) + self._kde_upsample_factor = kde_upsample_factor + cropped_image = asnumpy( + self._crop_padded_object(pix_output, upsampled=True) + ) + if old_upsample_factor is not None: + self._kde_upsample_factor = old_upsample_factor + else: + del self._kde_upsample_factor + + # shifted stack + else: + kde_upsample_factor = 1 + aligned_stack = xp.zeros_like(self._stack_BF_shifted_initial[0]) + + if max_batch_size is None: + max_batch_size = self._num_bf_images + + for start, end in generate_batches( + self._num_bf_images, max_batch=max_batch_size + ): + shifted_BFs = self._stack_BF_shifted_initial[start:end] + + Gs = xp.fft.fft2(shifted_BFs) + + dx = shifts_px[start:end, 0] + dy = shifts_px[start:end, 1] + + shift_op = xp.exp( + self._qx_shift[None] * dx[:, None, None] + + self._qy_shift[None] * dy[:, None, None] + ) + stack_BF_shifted = xp.real(xp.fft.ifft2(Gs * shift_op)) + aligned_stack += stack_BF_shifted.sum(0) + + cropped_image = asnumpy( + self._crop_padded_object(aligned_stack, upsampled=False) + ) + + if plot_shifts_and_aligned_bf: + figsize = kwargs.pop("figsize", (8, 4)) + color = kwargs.pop("color", (1, 0, 0, 1)) + cmap = kwargs.pop("cmap", "magma") + + fig, axs = plt.subplots(1, 2, figsize=figsize) + + self.show_shifts( + shifts_ang=shifts_ang, + plot_arrow_freq=plot_arrow_freq, + scale_arrows=scale_arrows, + plot_rotated_shifts=False, + color=color, + figax=(fig, axs[0]), + ) + + axs[0].set_title("Predicted BF Shifts") + + extent = [ + 0, + self._scan_sampling[1] * cropped_image.shape[1] / kde_upsample_factor, + self._scan_sampling[0] * cropped_image.shape[0] / kde_upsample_factor, + 0, + ] + + axs[1].imshow(cropped_image, cmap=cmap, extent=extent, **kwargs) + axs[1].set_ylabel("x [A]") + axs[1].set_xlabel("y [A]") + axs[1].set_title("Predicted Aligned BF Image") + + fig.tight_layout() + + if return_shifts_and_aligned_bf: + return shifts_ang, cropped_image + def reconstruct( self, max_alignment_bin: int = None, min_alignment_bin: int = 1, num_iter_at_min_bin: int = 2, alignment_bin_values: list = None, + centered_alignment_bins: bool = True, cross_correlation_upsample_factor: int = 8, regularizer_matrix_size: Tuple[int, int] = (1, 1), regularize_shifts: bool = False, @@ -753,6 +1029,7 @@ def reconstruct( reset: bool = None, device: str = None, clear_fft_cache: bool = None, + max_batch_size: int = None, **kwargs, ): """ @@ -787,6 +1064,8 @@ def reconstruct( If True, the reconstruction is reset device: str, optional if not none, overwrites self._device to set device preprocess will be perfomed on. + max_batch_size: int, optional + Max number of virtual BF images to use at once in computing cross-correlation clear_fft_cache: bool, optional if true, and device = 'gpu', clears the cached fft plan at the end of function calls @@ -879,6 +1158,8 @@ def reconstruct( (bin_vals, np.repeat(bin_vals[-1], num_iter_at_min_bin - 1)) ) + bin_shift = 0 if centered_alignment_bins else 0.5 + if plot_aligned_bf: num_plots = bin_vals.shape[0] nrows = int(np.sqrt(num_plots)) @@ -908,12 +1189,15 @@ def reconstruct( xy_center = (self._xy_inds - xp.median(self._xy_inds, axis=0)).astype("float") + if max_batch_size is None: + max_batch_size = self._num_bf_images + # Loop over all binning values for a0 in range(bin_vals.shape[0]): G_ref = xp.fft.fft2(self._recon_BF) # Segment the virtual images with current binning values - xy_inds = xp.round(xy_center / bin_vals[a0] + 0.5).astype("int") + xy_inds = xp.round(xy_center / bin_vals[a0] + bin_shift).astype("int") xy_vals = np.unique( asnumpy(xy_inds), axis=0 ) # axis is not yet supported in cupy @@ -978,31 +1262,33 @@ def reconstruct( shifts_update = xy_shifts_fit - self._xy_shifts # apply shifts - Gs = xp.fft.fft2(self._stack_BF_shifted) + for start, end in generate_batches( + self._num_bf_images, max_batch=max_batch_size + ): + shifted_BFs = self._stack_BF_shifted[start:end] + stack_mask = self._stack_mask[start:end] - dx = shifts_update[:, 0] - dy = shifts_update[:, 1] - self._xy_shifts[:, 0] += dx - self._xy_shifts[:, 1] += dy + Gs = xp.fft.fft2(shifted_BFs) - shift_op = xp.exp( - self._qx_shift[None] * dx[:, None, None] - + self._qy_shift[None] * dy[:, None, None] - ) + dx = shifts_update[start:end, 0] + dy = shifts_update[start:end, 1] - self._stack_BF_shifted = xp.real(xp.fft.ifft2(Gs * shift_op)) - self._stack_mask = xp.real( - xp.fft.ifft2(xp.fft.fft2(self._stack_mask) * shift_op) - ) + shift_op = xp.exp( + self._qx_shift[None] * dx[:, None, None] + + self._qy_shift[None] * dy[:, None, None] + ) - self._stack_BF_shifted = xp.asarray( - self._stack_BF_shifted, dtype=xp.float32 - ) # numpy fft upcasts? - self._stack_mask = xp.asarray( - self._stack_mask, dtype=xp.float32 - ) # numpy fft upcasts? + stack_BF_shifted = xp.real(xp.fft.ifft2(Gs * shift_op)) + stack_mask = xp.real(xp.fft.ifft2(xp.fft.fft2(stack_mask) * shift_op)) - del Gs + self._stack_BF_shifted[start:end] = xp.asarray( + stack_BF_shifted, dtype=xp.float32 + ) + self._stack_mask[start:end] = xp.asarray(stack_mask, dtype=xp.float32) + self._xy_shifts[start:end, 0] += dx + self._xy_shifts[start:end, 1] += dy + + del Gs # Center the shifts xy_shifts_median = xp.round(xp.median(self._xy_shifts, axis=0)).astype(int) @@ -1013,12 +1299,12 @@ def reconstruct( self._stack_mask = xp.roll(self._stack_mask, -xy_shifts_median, axis=(1, 2)) # Generate new estimate - self._recon_mask = xp.sum(self._stack_mask, axis=0) + self._recon_mask = xp.mean(self._stack_mask, axis=0) mask_inv = 1 - np.clip(self._recon_mask, 0, 1) self._recon_BF = ( self._stack_mean * mask_inv - + xp.sum(self._stack_BF_shifted * self._stack_mask, axis=0) + + xp.mean(self._stack_BF_shifted * self._stack_mask, axis=0) ) / (self._recon_mask + mask_inv) self._recon_error = ( @@ -1029,6 +1315,7 @@ def reconstruct( ) ) / self._mask_sum + / self._stack_mean ) self.error_iterations.append(float(self._recon_error)) @@ -2042,75 +2329,21 @@ def calculate_CTF_FFT(alpha_shape, *coefs): # Direct Shifts Fitting if fit_BF_shifts: - # FFT coordinates - sx = 1 / (self._reciprocal_sampling[0] * self._region_of_interest_shape[0]) - sy = 1 / (self._reciprocal_sampling[1] * self._region_of_interest_shape[1]) - qx = xp.fft.fftfreq(self._region_of_interest_shape[0], sx) - qy = xp.fft.fftfreq(self._region_of_interest_shape[1], sy) - qx, qy = np.meshgrid(qx, qy, indexing="ij") - - # passive rotation basis by -theta - rotation_angle = -self.rotation_Q_to_R_rads - qx, qy = qx * np.cos(rotation_angle) + qy * np.sin( - rotation_angle - ), -qx * np.sin(rotation_angle) + qy * np.cos(rotation_angle) - - qr2 = qx**2 + qy**2 - u = qx * self._wavelength - v = qy * self._wavelength - alpha = xp.sqrt(qr2) * self._wavelength - theta = xp.arctan2(qy, qx) - - # Aberration basis - self._aberrations_basis = xp.zeros((alpha.size, self._aberrations_num)) - self._aberrations_basis_du = xp.zeros((alpha.size, self._aberrations_num)) - self._aberrations_basis_dv = xp.zeros((alpha.size, self._aberrations_num)) - for a0 in range(self._aberrations_num): - m, n, a = self._aberrations_mn[a0] - - if n == 0: - # Radially symmetric basis - self._aberrations_basis[:, a0] = ( - alpha ** (m + 1) / (m + 1) - ).ravel() - self._aberrations_basis_du[:, a0] = (u * alpha ** (m - 1)).ravel() - self._aberrations_basis_dv[:, a0] = (v * alpha ** (m - 1)).ravel() - - elif a == 0: - # cos coef - self._aberrations_basis[:, a0] = ( - alpha ** (m + 1) * xp.cos(n * theta) / (m + 1) - ).ravel() - self._aberrations_basis_du[:, a0] = ( - alpha ** (m - 1) - * ((m + 1) * u * xp.cos(n * theta) + n * v * xp.sin(n * theta)) - / (m + 1) - ).ravel() - self._aberrations_basis_dv[:, a0] = ( - alpha ** (m - 1) - * ((m + 1) * v * xp.cos(n * theta) - n * u * xp.sin(n * theta)) - / (m + 1) - ).ravel() - - else: - # sin coef - self._aberrations_basis[:, a0] = ( - alpha ** (m + 1) * xp.sin(n * theta) / (m + 1) - ).ravel() - self._aberrations_basis_du[:, a0] = ( - alpha ** (m - 1) - * ((m + 1) * u * xp.sin(n * theta) - n * v * xp.cos(n * theta)) - / (m + 1) - ).ravel() - self._aberrations_basis_dv[:, a0] = ( - alpha ** (m - 1) - * ((m + 1) * v * xp.sin(n * theta) + n * u * xp.cos(n * theta)) - / (m + 1) - ).ravel() - - # global scaling - self._aberrations_basis *= 2 * np.pi / self._wavelength - self._aberrations_surface_shape = alpha.shape + sampling = 1 / ( + np.array(self._reciprocal_sampling) * self._region_of_interest_shape + ) + ( + self._aberrations_babis, + self._aberrations_basis_du, + self._aberrations_basis_dv, + ) = calculate_aberration_gradient_basis( + self._aberrations_mn, + sampling, + self._region_of_interest_shape, + self._wavelength, + rotation_angle=self.rotation_Q_to_R_rads, + xp=xp, + ) # CTF function def calculate_CTF(alpha_shape, *coefs): @@ -2199,19 +2432,6 @@ def score_CTF(coefs): # Plot the measured/fitted shifts comparison if plot_BF_shifts_comparison: - measured_shifts_sx = xp.zeros( - self._region_of_interest_shape, dtype=xp.float32 - ) - measured_shifts_sx[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( - self._xy_shifts_Ang[:, 0] - ) - - measured_shifts_sy = xp.zeros( - self._region_of_interest_shape, dtype=xp.float32 - ) - measured_shifts_sy[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( - self._xy_shifts_Ang[:, 1] - ) fitted_shifts = ( xp.tensordot(gradients, xp.array(self._aberrations_coefs), axes=1) @@ -2219,53 +2439,28 @@ def score_CTF(coefs): .T ) - fitted_shifts_sx = xp.zeros( - self._region_of_interest_shape, dtype=xp.float32 - ) - fitted_shifts_sx[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( - fitted_shifts[:, 0] - ) + scale_arrows = kwargs.pop("scale_arrows", 1) + plot_arrow_freq = kwargs.pop("plot_arrow_freq", 1) + figsize = kwargs.pop("figsize", (4, 4)) - fitted_shifts_sy = xp.zeros( - self._region_of_interest_shape, dtype=xp.float32 - ) - fitted_shifts_sy[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( - fitted_shifts[:, 1] - ) + fig, ax = plt.subplots(figsize=figsize) - max_shift = xp.max( - xp.array( - [ - xp.abs(measured_shifts_sx).max(), - xp.abs(measured_shifts_sy).max(), - xp.abs(fitted_shifts_sx).max(), - xp.abs(fitted_shifts_sy).max(), - ] - ) + self.show_shifts( + shifts_ang=self._xy_shifts_Ang, + plot_rotated_shifts=False, + plot_arrow_freq=plot_arrow_freq, + scale_arrows=scale_arrows, + color=(1, 0, 0, 0.5), + figax=(fig, ax), ) - axsize = kwargs.pop("axsize", (4, 4)) - cmap = kwargs.pop("cmap", "PiYG") - vmin = kwargs.pop("vmin", -max_shift) - vmax = kwargs.pop("vmax", max_shift) - - show( - [ - [asnumpy(measured_shifts_sx), asnumpy(fitted_shifts_sx)], - [asnumpy(measured_shifts_sy), asnumpy(fitted_shifts_sy)], - ], - cmap=cmap, - vmin=vmin, - vmax=vmax, - intensity_range="absolute", - axsize=axsize, - ticks=False, - title=[ - "Measured Vertical Shifts", - "Fitted Vertical Shifts", - "Measured Horizontal Shifts", - "Fitted Horizontal Shifts", - ], + self.show_shifts( + shifts_ang=fitted_shifts, + plot_rotated_shifts=False, + plot_arrow_freq=plot_arrow_freq, + scale_arrows=scale_arrows, + color=(0, 0, 1, 0.5), + figax=(fig, ax), ) # Plot the CTF comparison between experiment and fit @@ -2823,9 +3018,11 @@ def _visualize_figax( def show_shifts( self, + shifts_ang=None, scale_arrows=1, plot_arrow_freq=1, plot_rotated_shifts=True, + figax=None, **kwargs, ): """ @@ -2833,31 +3030,58 @@ def show_shifts( Parameters ---------- + shifts_ang: np.ndarray, optional + If None, self._xy_shifts is used scale_arrows: float, optional Scale to multiply shifts by plot_arrow_freq: int, optional Frequency of shifts to plot in quiver plot + plot_rotated_shifts: bool, optional + If True, shifts are plotted with the relative rotation decomposed + figax: optional + Tuple of figure, axes to plot against """ xp = self._xp asnumpy = self._asnumpy color = kwargs.pop("color", (1, 0, 0, 1)) + + if shifts_ang is None: + shifts_px = self._xy_shifts + else: + shifts_px = shifts_ang / xp.array(self._scan_sampling) + + shifts = shifts_px * scale_arrows * xp.array(self._reciprocal_sampling) + if plot_rotated_shifts and hasattr(self, "rotation_Q_to_R_rads"): - figsize = kwargs.pop("figsize", (8, 4)) - fig, ax = plt.subplots(1, 2, figsize=figsize) - scaling_factor = ( - xp.array(self._reciprocal_sampling) - / xp.array(self._scan_sampling) - * scale_arrows + + if figax is None: + figsize = kwargs.pop("figsize", (8, 4)) + fig, ax = plt.subplots(1, 2, figsize=figsize) + else: + fig, ax = figax + + rotated_color = kwargs.pop("rotated_color", (0, 0, 0, 1)) + + if shifts_ang is None: + rotated_shifts_px = self._xy_shifts.copy() + else: + rotated_shifts_px = shifts_ang / xp.array(self._scan_sampling) + + if self.transpose: + rotated_shifts_px = xp.flip(rotated_shifts_px, axis=1) + + rotated_shifts = ( + rotated_shifts_px * scale_arrows * xp.array(self._reciprocal_sampling) ) - rotated_shifts = self._xy_shifts_Ang * scaling_factor else: - figsize = kwargs.pop("figsize", (4, 4)) - fig, ax = plt.subplots(figsize=figsize) - - shifts = self._xy_shifts * scale_arrows * self._reciprocal_sampling[0] + if figax is None: + figsize = kwargs.pop("figsize", (4, 4)) + fig, ax = plt.subplots(figsize=figsize) + else: + fig, ax = figax dp_mask_ind = xp.nonzero(self._dp_mask) yy, xx = xp.meshgrid( @@ -2900,6 +3124,7 @@ def show_shifts( angles="xy", scale_units="xy", scale=1, + color=rotated_color, **kwargs, ) diff --git a/py4DSTEM/process/phase/parameter_optimize.py b/py4DSTEM/process/phase/parameter_optimize.py index fc1e84f11..886fd7972 100644 --- a/py4DSTEM/process/phase/parameter_optimize.py +++ b/py4DSTEM/process/phase/parameter_optimize.py @@ -135,7 +135,7 @@ def grid_search( Parameters ---------- - n_initial_points: int + n_points: int Number of uniformly spaced trial points to run on a grid error_metric: Callable or str Function used to compute the reconstruction error. @@ -233,7 +233,7 @@ def evaluation_callback(ptycho): ax.imshow(res[0], cmap=cmap) title_substrings = [ - f"{param.name}: {val}" + f"{param.name}: {val:.3e}" for param, val in zip(self._parameter_list, params) ] title_substrings.append(f"error: {res[1]:.3e}") @@ -458,13 +458,15 @@ def _split_static_and_optimization_vars(self, argdict): return static_args, optimization_args def _get_scan_positions(self, affine_transform, dataset): - R_pixel_size = dataset.calibration.get_R_pixel_size() - x, y = ( - np.arange(dataset.R_Nx) * R_pixel_size, - np.arange(dataset.R_Ny) * R_pixel_size, - ) - x, y = np.meshgrid(x, y, indexing="ij") - scan_positions = np.stack((x.ravel(), y.ravel()), axis=1) + scan_positions = self._init_static_args.get("initial_scan_positions", None) + if scan_positions is None: + R_pixel_size = dataset.calibration.get_R_pixel_size() + x, y = ( + np.arange(dataset.R_Nx) * R_pixel_size, + np.arange(dataset.R_Ny) * R_pixel_size, + ) + x, y = np.meshgrid(x, y, indexing="ij") + scan_positions = np.stack((x.ravel(), y.ravel()), axis=1) scan_positions = scan_positions @ affine_transform.asarray() return scan_positions @@ -483,8 +485,10 @@ def _get_error_metric(self, error_metric: Union[Callable, str]) -> Callable: "log-converged", "linear-converged", "TV", + "TV-phase", "std", "std-phase", + "entropy", "entropy-phase", ), f"Error metric {error_metric} not recognized." @@ -517,10 +521,20 @@ def f(ptycho): elif error_metric == "TV": def f(ptycho): - gx, gy = np.gradient(ptycho.object_cropped, axis=(-2, -1)) - obj_mag = np.sum(np.abs(ptycho.object_cropped)) + array = np.abs(ptycho.object_cropped) + gx = array[..., 1:, :] - array[..., -1:, :] + gy = array[..., :, 1:] - array[..., :, -1:] + tv = np.sum(np.abs(gx)) + np.sum(np.abs(gy)) + return tv / array.size + + elif error_metric == "TV-phase": + + def f(ptycho): + array = np.angle(ptycho.object_cropped) + gx = array[..., 1:, :] - array[..., -1:, :] + gy = array[..., :, 1:] - array[..., :, -1:] tv = np.sum(np.abs(gx)) + np.sum(np.abs(gy)) - return tv / obj_mag + return tv / array.size elif error_metric == "std": @@ -532,16 +546,30 @@ def f(ptycho): def f(ptycho): return -np.std(np.angle(ptycho.object_cropped)) + elif error_metric == "entropy": + + def f(ptycho): + array = np.abs(ptycho.object_cropped) + normalized_array = (array - np.min(array)) / np.ptp(array) + # gx = normalized_array[..., 1:, :] - normalized_array[..., -1:, :] + # gy = normalized_array[..., :, 1:] - normalized_array[..., :, -1:] + gx, gy = np.gradient(normalized_array, axis=(-2, -1)) + ghist, _, _ = np.histogram2d(gx.ravel(), gy.ravel(), bins=array.shape) + ghist = ghist[ghist > 0] / array.size + S = np.sum(ghist * np.log2(ghist)) + return S + elif error_metric == "entropy-phase": def f(ptycho): - obj = np.angle(ptycho.object_cropped) - gx, gy = np.gradient(obj) - ghist, _, _ = np.histogram2d( - gx.ravel(), gy.ravel(), bins=obj.shape, density=True - ) - nz = ghist > 0 - S = np.sum(ghist[nz] * np.log2(ghist[nz])) + array = np.angle(ptycho.object_cropped) + normalized_array = (array - np.min(array)) / np.ptp(array) + # gx = normalized_array[..., 1:, :] - normalized_array[..., -1:, :] + # gy = normalized_array[..., :, 1:] - normalized_array[..., :, -1:] + gx, gy = np.gradient(normalized_array, axis=(-2, -1)) + ghist, _, _ = np.histogram2d(gx.ravel(), gy.ravel(), bins=array.shape) + ghist = ghist[ghist > 0] / array.size + S = np.sum(ghist * np.log2(ghist)) return S else: diff --git a/py4DSTEM/process/phase/phase_base_class.py b/py4DSTEM/process/phase/phase_base_class.py index 22fda4ac9..c571dcd3d 100644 --- a/py4DSTEM/process/phase/phase_base_class.py +++ b/py4DSTEM/process/phase/phase_base_class.py @@ -686,6 +686,7 @@ def _calculate_intensities_center_of_mass( # calculate CoM if dp_mask is not None: + dp_mask = copy_to_device(dp_mask, device) intensities_mask = intensities * dp_mask else: intensities_mask = intensities @@ -1350,6 +1351,7 @@ def _normalize_diffraction_intensities( com_fitted_y, positions_mask, crop_patterns, + in_place_datacube_modification, ): """ Fix diffraction intensities CoM, shift to origin, and take square root @@ -1362,78 +1364,67 @@ def _normalize_diffraction_intensities( Best fit horizontal center of mass gradient com_fitted_y: (Rx,Ry) xp.ndarray Best fit vertical center of mass gradient - positions_mask: np.ndarray, optional + positions_mask: np.ndarray Boolean real space mask to select positions in datacube to skip for reconstruction crop_patterns: bool - if True, crop patterns to avoid wrap around of patterns - when centering + If True, patterns are cropped to avoid wrap around of patterns + in_place_datacube_modification: bool + If True, the diffraction intensities are modified in-place Returns ------- - amplitudes: (Rx * Ry, Sx, Sy) np.ndarray + diffraction_intensities: (Rx * Ry, Sx, Sy) np.ndarray Flat array of normalized diffraction amplitudes mean_intensity: float Mean intensity value + crop_mask + Mask to crop diffraction patterns with """ # explicit read-only self attributes up-front asnumpy = self._asnumpy mean_intensity = 0 - - diffraction_intensities = asnumpy(diffraction_intensities) com_fitted_x = asnumpy(com_fitted_x) com_fitted_y = asnumpy(com_fitted_y) - if positions_mask is not None: - number_of_patterns = np.count_nonzero(positions_mask.ravel()) + if in_place_datacube_modification: + diff_intensities = diffraction_intensities else: - number_of_patterns = np.prod(diffraction_intensities.shape[:2]) + diff_intensities = diffraction_intensities.copy() # Aggressive cropping for when off-centered high scattering angle data was recorded if crop_patterns: crop_x = int( np.minimum( - diffraction_intensities.shape[2] - com_fitted_x.max(), + diff_intensities.shape[2] - com_fitted_x.max(), com_fitted_x.min(), ) ) crop_y = int( np.minimum( - diffraction_intensities.shape[3] - com_fitted_y.max(), + diff_intensities.shape[3] - com_fitted_y.max(), com_fitted_y.min(), ) ) crop_w = np.minimum(crop_y, crop_x) - diffraction_intensities_shape_crop = (crop_w * 2, crop_w * 2) - amplitudes = np.zeros( - ( - number_of_patterns, - crop_w * 2, - crop_w * 2, - ), - dtype=np.float32, - ) - crop_mask = np.zeros(diffraction_intensities.shape[-2:], dtype=np.bool_) + crop_mask = np.zeros(diff_intensities.shape[-2:], dtype="bool") crop_mask[:crop_w, :crop_w] = True crop_mask[-crop_w:, :crop_w] = True crop_mask[:crop_w:, -crop_w:] = True crop_mask[-crop_w:, -crop_w:] = True + crop_mask_shape = (crop_w * 2, crop_w * 2) + else: crop_mask = None - diffraction_intensities_shape_crop = diffraction_intensities.shape[-2:] - amplitudes = np.zeros( - (number_of_patterns,) + diffraction_intensities_shape_crop, - dtype=np.float32, - ) + crop_mask_shape = diff_intensities.shape[-2:] - counter = 0 for rx, ry in tqdmnd( - diffraction_intensities.shape[0], - diffraction_intensities.shape[1], + diff_intensities.shape[0], + diff_intensities.shape[1], desc="Normalizing amplitudes", unit="probe position", disable=not self._verbose, @@ -1441,28 +1432,32 @@ def _normalize_diffraction_intensities( if positions_mask is not None: if not positions_mask[rx, ry]: continue + intensities = get_shifted_ar( - diffraction_intensities[rx, ry], + diff_intensities[rx, ry], -com_fitted_x[rx, ry], -com_fitted_y[rx, ry], bilinear=True, device="cpu", ) - if crop_patterns: - intensities = intensities[crop_mask].reshape( - diffraction_intensities_shape_crop - ) - mean_intensity += np.sum(intensities) - amplitudes[counter] = np.sqrt(np.maximum(intensities, 0)) - counter += 1 + diff_intensities[rx, ry] = np.sqrt(np.maximum(intensities, 0)) - mean_intensity /= amplitudes.shape[0] + if positions_mask is not None: + diff_intensities = diff_intensities[positions_mask] + else: + qx, qy = diff_intensities.shape[-2:] + diff_intensities = diff_intensities.reshape((-1, qx, qy)) - self._diffraction_intensities_shape_crop = diffraction_intensities_shape_crop + if crop_patterns: + diff_intensities = diff_intensities[:, crop_mask].reshape( + (-1,) + crop_mask_shape + ) - return amplitudes, mean_intensity, crop_mask + mean_intensity /= diff_intensities.shape[0] + + return diff_intensities, mean_intensity, crop_mask, crop_mask_shape def show_complex_CoM( self, @@ -1556,6 +1551,7 @@ def to_h5(self, group): "semiangle_cutoff": self._semiangle_cutoff, "rolloff": self._rolloff, "object_padding_px": self._object_padding_px, + "object_fov_ang": self._object_fov_ang, "object_type": self._object_type, "verbose": self._verbose, "device": self._device, @@ -1864,42 +1860,34 @@ def _calculate_scan_positions_in_pixels( else: raise ValueError() - if transpose: - x = (x - np.ptp(x) / 2) / sampling[1] - y = (y - np.ptp(y) / 2) / sampling[0] - else: - x = (x - np.ptp(x) / 2) / sampling[0] - y = (y - np.ptp(y) / 2) / sampling[1] x, y = np.meshgrid(x, y, indexing="ij") if positions_offset_ang is not None: - if transpose: - x += positions_offset_ang[0] / sampling[1] - y += positions_offset_ang[1] / sampling[0] - else: - x += positions_offset_ang[0] / sampling[0] - y += positions_offset_ang[1] / sampling[1] + x += positions_offset_ang[0] + y += positions_offset_ang[1] if positions_mask is not None: x = x[positions_mask] y = y[positions_mask] - else: - positions -= np.mean(positions, axis=0) - x = positions[:, 0] / sampling[1] - y = positions[:, 1] / sampling[0] + + positions = np.stack((x.ravel(), y.ravel()), axis=-1) if rotation_angle is not None: - x, y = x * np.cos(rotation_angle) + y * np.sin(rotation_angle), -x * np.sin( - rotation_angle - ) + y * np.cos(rotation_angle) + tf = AffineTransform(angle=rotation_angle) + positions = tf(positions, positions.mean(0)) if transpose: - positions = np.array([y.ravel(), x.ravel()]).T - else: - positions = np.array([x.ravel(), y.ravel()]).T + positions = np.flip(positions, 1) + sampling = sampling[::-1] + + # ensure positive + positions -= np.min(positions, axis=0).clip(-np.inf, 0) - positions -= np.min(positions, axis=0) + # finally, switch to pixels + positions[:, 0] /= sampling[0] + positions[:, 1] /= sampling[1] + # top-left padding if object_padding_px is None: float_padding = region_of_interest_shape / 2 object_padding_px = (float_padding, float_padding) diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py index 283ddb1ba..2e47a5e23 100644 --- a/py4DSTEM/process/phase/ptychographic_methods.py +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -48,14 +48,24 @@ def _initialize_object( xp = self._xp object_padding_px = self._object_padding_px + object_fov_ang = self._object_fov_ang region_of_interest_shape = self._region_of_interest_shape if initial_object is None: - pad_x = object_padding_px[0][1] - pad_y = object_padding_px[1][1] - p, q = np.round(np.max(positions_px, axis=0)) - p = np.max([np.round(p + pad_x), region_of_interest_shape[0]]).astype("int") - q = np.max([np.round(q + pad_y), region_of_interest_shape[1]]).astype("int") + if object_fov_ang is None: + pad_x = object_padding_px[0][1] + pad_y = object_padding_px[1][1] + p, q = np.round(np.max(positions_px, axis=0)) + p = np.max([np.round(p + pad_x), region_of_interest_shape[0]]).astype( + "int" + ) + q = np.max([np.round(q + pad_y), region_of_interest_shape[1]]).astype( + "int" + ) + else: + p, q = np.ceil( + np.array(object_fov_ang) / np.array(self.sampling) + ).astype("int") if object_type == "potential": _object = xp.zeros((p, q), dtype=xp.float32) elif object_type == "complex": @@ -402,14 +412,24 @@ def _initialize_object( xp = self._xp object_padding_px = self._object_padding_px + object_fov_ang = self._object_fov_ang region_of_interest_shape = self._region_of_interest_shape if initial_object is None: - pad_x = object_padding_px[0][1] - pad_y = object_padding_px[1][1] - p, q = np.round(np.max(positions_px, axis=0)) - p = np.max([np.round(p + pad_x), region_of_interest_shape[0]]).astype("int") - q = np.max([np.round(q + pad_y), region_of_interest_shape[1]]).astype("int") + if object_fov_ang is None: + pad_x = object_padding_px[0][1] + pad_y = object_padding_px[1][1] + p, q = np.round(np.max(positions_px, axis=0)) + p = np.max([np.round(p + pad_x), region_of_interest_shape[0]]).astype( + "int" + ) + q = np.max([np.round(q + pad_y), region_of_interest_shape[1]]).astype( + "int" + ) + else: + p, q = np.ceil( + np.array(object_fov_ang) / np.array(self.sampling) + ).astype("int") if object_type == "potential": _object = xp.zeros((num_slices, p, q), dtype=xp.float32) elif object_type == "complex": @@ -858,14 +878,24 @@ def _initialize_object( # explicit read-only self attributes up-front xp = self._xp object_padding_px = self._object_padding_px + object_fov_ang = self._object_fov_ang region_of_interest_shape = self._region_of_interest_shape if initial_object is None: - pad_x = object_padding_px[0][1] - pad_y = object_padding_px[1][1] - p, q = np.round(np.max(positions_px, axis=0)) - p = np.max([np.round(p + pad_x), region_of_interest_shape[0]]).astype("int") - q = np.max([np.round(q + pad_y), region_of_interest_shape[1]]).astype("int") + if object_fov_ang is None: + pad_x = object_padding_px[0][1] + pad_y = object_padding_px[1][1] + p, q = np.round(np.max(positions_px, axis=0)) + p = np.max([np.round(p + pad_x), region_of_interest_shape[0]]).astype( + "int" + ) + q = np.max([np.round(q + pad_y), region_of_interest_shape[1]]).astype( + "int" + ) + else: + p, q = np.ceil( + np.array(object_fov_ang) / np.array(self.sampling) + ).astype("int") if main_tilt_axis == "vertical": _object = xp.zeros((q, p, q), dtype=xp.float32) @@ -992,6 +1022,7 @@ def _initialize_probe( device = self._device crop_mask = self._crop_mask + crop_mask_shape = self._crop_mask_shape region_of_interest_shape = self._region_of_interest_shape sampling = self.sampling energy = self._energy @@ -1019,7 +1050,7 @@ def _initialize_probe( if crop_patterns: vacuum_probe_intensity = vacuum_probe_intensity[crop_mask].reshape( - self._diffraction_intensities_shape_crop + crop_mask_shape ) sx, sy = vacuum_probe_intensity.shape diff --git a/py4DSTEM/process/phase/ptychographic_tomography.py b/py4DSTEM/process/phase/ptychographic_tomography.py index 3639096dc..037ef4849 100644 --- a/py4DSTEM/process/phase/ptychographic_tomography.py +++ b/py4DSTEM/process/phase/ptychographic_tomography.py @@ -219,6 +219,7 @@ def preprocess( padded_diffraction_intensities_shape: Tuple[int, int] = None, region_of_interest_shape: Tuple[int, int] = None, dp_mask: np.ndarray = None, + in_place_datacube_modification: bool = False, fit_function: str = "plane", plot_probe_overlaps: bool = True, rotation_real_space_degrees: float = None, @@ -261,6 +262,9 @@ def preprocess( at the diffraction plane to allow comparison with experimental data dp_mask: ndarray, optional Mask for datacube intensities (Qx,Qy) + in_place_datacube_modification: bool, optional + If True, the datacube will be preprocessed in-place. Note this is not possible + when either crop_patterns or positions_mask are used. fit_function: str, optional 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' plot_probe_overlaps: bool, optional @@ -478,6 +482,7 @@ def preprocess( amplitudes, mean_diffraction_intensity_temp, self._crop_mask, + self._crop_mask_shape, ) = self._normalize_diffraction_intensities( intensities, com_fitted_x, diff --git a/py4DSTEM/process/phase/singleslice_ptychography.py b/py4DSTEM/process/phase/singleslice_ptychography.py index cc7a5865d..d391dd293 100644 --- a/py4DSTEM/process/phase/singleslice_ptychography.py +++ b/py4DSTEM/process/phase/singleslice_ptychography.py @@ -88,6 +88,9 @@ class SingleslicePtychography( initial_scan_positions: np.ndarray, optional Probe positions in Ã… for each diffraction intensity If None, initialized to a grid scan + object_fov_ang: Tuple[int,int], optional + Fixed object field of view in Ã…. If None, the fov is initialized using the + probe positions and object_padding_px positions_offset_ang: np.ndarray, optional Offset of positions in A verbose: bool, optional @@ -124,6 +127,7 @@ def __init__( initial_object_guess: np.ndarray = None, initial_probe_guess: np.ndarray = None, initial_scan_positions: np.ndarray = None, + object_fov_ang: Tuple[float, float] = None, positions_offset_ang: np.ndarray = None, object_padding_px: Tuple[int, int] = None, object_type: str = "complex", @@ -177,6 +181,7 @@ def __init__( self._rolloff = rolloff self._object_type = object_type self._object_padding_px = object_padding_px + self._object_fov_ang = object_fov_ang self._positions_mask = positions_mask self._verbose = verbose self._preprocessed = False @@ -190,6 +195,7 @@ def preprocess( padded_diffraction_intensities_shape: Tuple[int, int] = None, region_of_interest_shape: Tuple[int, int] = None, dp_mask: np.ndarray = None, + in_place_datacube_modification: bool = False, fit_function: str = "plane", plot_center_of_mass: str = "default", plot_rotation: bool = True, @@ -206,6 +212,7 @@ def preprocess( force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, crop_patterns: bool = False, + center_positions_in_fov: bool = True, store_initial_arrays: bool = True, device: str = None, clear_fft_cache: bool = None, @@ -230,6 +237,9 @@ def preprocess( at the diffraction plane to allow comparison with experimental data dp_mask: ndarray, optional Mask for datacube intensities (Qx,Qy) + in_place_datacube_modification: bool, optional + If True, the datacube will be preprocessed in-place. Note this is not possible + when either crop_patterns or positions_mask are used. fit_function: str, optional 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' plot_center_of_mass: str, optional @@ -266,6 +276,8 @@ def preprocess( If None, probe_overlap intensity is thresholded crop_patterns: bool if True, crop patterns to avoid wrap around of patterns when centering + center_positions_in_fov: bool + If True (default), probe positions are centered in the fov. store_initial_arrays: bool If True, preprocesed object and probe arrays are stored allowing reset=True in reconstruct. device: str, optional @@ -397,17 +409,21 @@ def preprocess( self._amplitudes, self._mean_diffraction_intensity, self._crop_mask, + self._crop_mask_shape, ) = self._normalize_diffraction_intensities( _intensities, self._com_fitted_x, self._com_fitted_y, self._positions_mask, crop_patterns, + in_place_datacube_modification, ) # explicitly transfer arrays to storage + if not in_place_datacube_modification: + del _intensities + self._amplitudes = copy_to_device(self._amplitudes, storage) - del _intensities self._num_diffraction_patterns = self._amplitudes.shape[0] self._amplitudes_shape = np.array(self._amplitudes.shape[-2:]) @@ -443,15 +459,18 @@ def preprocess( self._object_shape = self._object.shape - # center probe positions self._positions_px = xp_storage.asarray( self._positions_px, dtype=xp_storage.float32 ) self._positions_px_initial_com = self._positions_px.mean(0) - self._positions_px -= ( - self._positions_px_initial_com - xp_storage.array(self._object_shape) / 2 - ) - self._positions_px_initial_com = self._positions_px.mean(0) + + # center probe positions + if center_positions_in_fov: + self._positions_px -= ( + self._positions_px_initial_com + - xp_storage.array(self._object_shape) / 2 + ) + self._positions_px_initial_com = self._positions_px.mean(0) self._positions_px_initial = self._positions_px.copy() self._positions_initial = self._positions_px_initial.copy() diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index 5742ff7e7..bb960da62 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -22,7 +22,6 @@ def get_array_module(*args): from py4DSTEM.process.utils import get_CoM from py4DSTEM.process.utils.cross_correlate import align_and_shift_images from py4DSTEM.process.utils.utils import electron_wavelength_angstrom -from skimage.restoration import unwrap_phase # fmt: off @@ -1505,6 +1504,85 @@ def step_model(radius, sig_0, rad_0, width): return probe_corr, polar_int, polar_int_corr, coefs_all +def calculate_aberration_gradient_basis( + aberrations_mn, + sampling, + gpts, + wavelength, + rotation_angle=0, + xp=np, +): + """ """ + sx, sy = sampling + nx, ny = gpts + qx = xp.fft.fftfreq(nx, sx) + qy = xp.fft.fftfreq(ny, sy) + qx, qy = xp.meshgrid(qx, qy, indexing="ij") + + # passive rotation + qx, qy = qx * xp.cos(-rotation_angle) + qy * xp.sin(-rotation_angle), -qx * xp.sin( + -rotation_angle + ) + qy * xp.cos(-rotation_angle) + + # coordinate system + qr2 = qx**2 + qy**2 + u = qx * wavelength + v = qy * wavelength + alpha = xp.sqrt(qr2) * wavelength + theta = xp.arctan2(qy, qx) + + _aberrations_n = len(aberrations_mn) + _aberrations_basis = xp.zeros((alpha.size, _aberrations_n)) + _aberrations_basis_du = xp.zeros((alpha.size, _aberrations_n)) + _aberrations_basis_dv = xp.zeros((alpha.size, _aberrations_n)) + + for a0 in range(_aberrations_n): + m, n, a = aberrations_mn[a0] + + if n == 0: + # Radially symmetric basis + _aberrations_basis[:, a0] = (alpha ** (m + 1) / (m + 1)).ravel() + _aberrations_basis_du[:, a0] = (u * alpha ** (m - 1)).ravel() + _aberrations_basis_dv[:, a0] = (v * alpha ** (m - 1)).ravel() + + elif a == 0: + # cos coef + _aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.cos(n * theta) / (m + 1) + ).ravel() + _aberrations_basis_du[:, a0] = ( + alpha ** (m - 1) + * ((m + 1) * u * xp.cos(n * theta) + n * v * xp.sin(n * theta)) + / (m + 1) + ).ravel() + _aberrations_basis_dv[:, a0] = ( + alpha ** (m - 1) + * ((m + 1) * v * xp.cos(n * theta) - n * u * xp.sin(n * theta)) + / (m + 1) + ).ravel() + + else: + # sin coef + _aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.sin(n * theta) / (m + 1) + ).ravel() + _aberrations_basis_du[:, a0] = ( + alpha ** (m - 1) + * ((m + 1) * u * xp.sin(n * theta) - n * v * xp.cos(n * theta)) + / (m + 1) + ).ravel() + _aberrations_basis_dv[:, a0] = ( + alpha ** (m - 1) + * ((m + 1) * v * xp.sin(n * theta) + n * u * xp.cos(n * theta)) + / (m + 1) + ).ravel() + + # global scaling + _aberrations_basis *= 2 * np.pi / wavelength + + return _aberrations_basis, _aberrations_basis_du, _aberrations_basis_dv + + def aberrations_basis_function( probe_size, probe_sampling, @@ -1755,6 +1833,8 @@ def unwrap_phase_2d(array, weights=None, gauge=None, corner_centered=True, xp=np def unwrap_phase_2d_skimage(array, corner_centered=True, xp=np): + from skimage.restoration import unwrap_phase + if xp is np: array = array.astype(np.float64) unwrapped_array = unwrap_phase(array, wrap_around=corner_centered).astype( diff --git a/py4DSTEM/process/phase/xray_magnetic_ptychography.py b/py4DSTEM/process/phase/xray_magnetic_ptychography.py new file mode 100644 index 000000000..b1b8a5862 --- /dev/null +++ b/py4DSTEM/process/phase/xray_magnetic_ptychography.py @@ -0,0 +1,1969 @@ +""" +Module for reconstructing phase objects from 4DSTEM datasets using iterative methods, +namely x-ray magnetic ptychography. +""" + +import warnings +from typing import Mapping, Sequence, Tuple + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.gridspec import GridSpec +from mpl_toolkits.axes_grid1 import make_axes_locatable +from py4DSTEM.visualize.vis_special import ( + Complex2RGB, + add_colorbar_arg, + return_scaled_histogram_ordering, +) + +try: + import cupy as cp +except (ImportError, ModuleNotFoundError): + cp = np + +from emdfile import Custom, tqdmnd +from py4DSTEM import DataCube +from py4DSTEM.process.phase.phase_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.ptychographic_constraints import ( + ObjectNDConstraintsMixin, + PositionsConstraintsMixin, + ProbeConstraintsMixin, +) +from py4DSTEM.process.phase.ptychographic_methods import ( + MultipleMeasurementsMethodsMixin, + ObjectNDMethodsMixin, + ObjectNDProbeMethodsMixin, + ProbeMethodsMixin, +) +from py4DSTEM.process.phase.ptychographic_visualizations import VisualizationsMixin +from py4DSTEM.process.phase.utils import ( + ComplexProbe, + copy_to_device, + fft_shift, + generate_batches, + polar_aliases, + polar_symbols, +) + + +class XRayMagneticPtychography( + VisualizationsMixin, + PositionsConstraintsMixin, + ProbeConstraintsMixin, + ObjectNDConstraintsMixin, + MultipleMeasurementsMethodsMixin, + ObjectNDProbeMethodsMixin, + ProbeMethodsMixin, + ObjectNDMethodsMixin, + PtychographicReconstruction, +): + """ + Iterative X-Ray Magnetic Ptychographic Reconstruction Class. + + Diffraction intensities dimensions : (Rx,Ry,Qx,Qy) (for each measurement) + Reconstructed probe dimensions : (Sx,Sy) + Reconstructed electrostatic dimensions : (Px,Py) + Reconstructed magnetic dimensions : (Px,Py) + + such that (Sx,Sy) is the region-of-interest (ROI) size of our probe + and (Px,Py) is the padded-object size we position our ROI around in. + + Parameters + ---------- + datacube: Sequence[DataCube] + Tuple of input 4D diffraction pattern intensities + energy: float + The electron energy of the wave functions in eV + magnetic_contribution_sign: str, optional + One of '-+', '-0+', '0+' + semiangle_cutoff: float, optional + Semiangle cutoff for the initial probe guess in mrad + semiangle_cutoff_pixels: float, optional + Semiangle cutoff for the initial probe guess in pixels + rolloff: float, optional + Semiangle rolloff for the initial probe guess + vacuum_probe_intensity: np.ndarray, optional + Vacuum probe to use as intensity aperture for initial probe guess + polar_parameters: dict, optional + Mapping from aberration symbols to their corresponding values. All aberration + magnitudes should be given in Ã… and angles should be given in radians. + object_padding_px: Tuple[int,int], optional + Pixel dimensions to pad objects with + If None, the padding is set to half the probe ROI dimensions + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction + initial_object_guess: np.ndarray, optional + Initial guess for complex-valued object of dimensions (2,Px,Py) + If None, initialized to 1.0j for complex objects and 0.0 for potential objects + initial_probe_guess: np.ndarray, optional + Initial guess for complex-valued probe of dimensions (Sx,Sy). If None, + initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations + initial_scan_positions: np.ndarray, optional + Probe positions in Ã… for each diffraction intensity + If None, initialized to a grid scan + object_fov_ang: Tuple[int,int], optional + Fixed object field of view in Ã…. If None, the fov is initialized using the + probe positions and object_padding_px + positions_offset_ang: np.ndarray, optional + Offset of positions in A + verbose: bool, optional + If True, class methods will inherit this and print additional information + device: str, optional + Calculation device will be perfomed on. Must be 'cpu' or 'gpu' + storage: str, optional + Device non-frequent arrays will be stored on. Must be 'cpu' or 'gpu' + clear_fft_cache: bool, optional + If True, and device = 'gpu', clears the cached fft plan at the end of function calls + object_type: str, optional + The object can be reconstructed as a real potential ('potential') or a complex + object ('complex') + name: str, optional + Class name + kwargs: + Provide the aberration coefficients as keyword arguments. + """ + + # Class-specific Metadata + _class_specific_metadata = ("_magnetic_contribution_sign",) + + def __init__( + self, + energy: float, + datacube: Sequence[DataCube] = None, + magnetic_contribution_sign: str = "-+", + semiangle_cutoff: float = None, + semiangle_cutoff_pixels: float = None, + rolloff: float = 2.0, + vacuum_probe_intensity: np.ndarray = None, + polar_parameters: Mapping[str, float] = None, + object_padding_px: Tuple[int, int] = None, + positions_mask: np.ndarray = None, + initial_object_guess: np.ndarray = None, + initial_probe_guess: np.ndarray = None, + initial_scan_positions: np.ndarray = None, + object_fov_ang: Tuple[float, float] = None, + positions_offset_ang: np.ndarray = None, + object_type: str = "complex", + verbose: bool = True, + device: str = "cpu", + storage: str = None, + clear_fft_cache: bool = True, + name: str = "xray_magnetic_ptychographic_reconstruction", + **kwargs, + ): + Custom.__init__(self, name=name) + + if storage is None: + storage = device + + self.set_device(device, clear_fft_cache) + self.set_storage(storage) + + for key in kwargs.keys(): + if (key not in polar_symbols) and (key not in polar_aliases.keys()): + raise ValueError("{} not a recognized parameter".format(key)) + + self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) + + if polar_parameters is None: + polar_parameters = {} + + polar_parameters.update(kwargs) + self._set_polar_parameters(polar_parameters) + + if object_type != "complex": + raise NotImplementedError() + + self.set_save_defaults() + + # Data + self._datacube = datacube + self._object = initial_object_guess + self._probe_init = initial_probe_guess + + # Common Metadata + self._vacuum_probe_intensity = vacuum_probe_intensity + self._scan_positions = initial_scan_positions + self._positions_offset_ang = positions_offset_ang + self._energy = energy + self._semiangle_cutoff = semiangle_cutoff + self._semiangle_cutoff_pixels = semiangle_cutoff_pixels + self._rolloff = rolloff + self._object_type = object_type + self._object_padding_px = object_padding_px + self._object_fov_ang = object_fov_ang + self._positions_mask = positions_mask + self._verbose = verbose + self._preprocessed = False + + # Class-specific Metadata + self._magnetic_contribution_sign = magnetic_contribution_sign + + def preprocess( + self, + diffraction_intensities_shape: Tuple[int, int] = None, + reshaping_method: str = "bilinear", + padded_diffraction_intensities_shape: Tuple[int, int] = None, + region_of_interest_shape: Tuple[int, int] = None, + dp_mask: np.ndarray = None, + in_place_datacube_modification: bool = False, + fit_function: str = "plane", + plot_rotation: bool = True, + maximize_divergence: bool = False, + rotation_angles_deg: np.ndarray = None, + plot_probe_overlaps: bool = True, + force_com_rotation: float = None, + force_com_transpose: float = None, + force_com_shifts: Sequence[np.ndarray] = None, + force_com_measured: Sequence[np.ndarray] = None, + vectorized_com_calculation: bool = True, + force_scan_sampling: float = None, + force_angular_sampling: float = None, + force_reciprocal_sampling: float = None, + progress_bar: bool = True, + object_fov_mask: np.ndarray = True, + crop_patterns: bool = False, + center_positions_in_fov: bool = True, + store_initial_arrays: bool = True, + device: str = None, + clear_fft_cache: bool = None, + max_batch_size: int = None, + **kwargs, + ): + """ + Ptychographic preprocessing step. + Calls the base class methods: + + _extract_intensities_and_calibrations_from_datacube, + _compute_center_of_mass(), + _solve_CoM_rotation(), + _normalize_diffraction_intensities() + _calculate_scan_positions_in_px() + + Additionally, it initializes an (Px,Py) array of 1.0j + and a complex probe using the specified polar parameters. + + Parameters + ---------- + diffraction_intensities_shape: Tuple[int,int], optional + Pixel dimensions (Qx',Qy') of the resampled diffraction intensities + If None, no resampling of diffraction intenstities is performed + reshaping_method: str, optional + Method to use for reshaping, either 'bin, 'bilinear', or 'fourier' (default) + padded_diffraction_intensities_shape: (int,int), optional + Padded diffraction intensities shape. + If None, no padding is performed + region_of_interest_shape: (int,int), optional + If not None, explicitly sets region_of_interest_shape and resamples exit_waves + at the diffraction plane to allow comparison with experimental data + dp_mask: ndarray, optional + Mask for datacube intensities (Qx,Qy) + in_place_datacube_modification: bool, optional + If True, the datacube will be preprocessed in-place. Note this is not possible + when either crop_patterns or positions_mask are used. + fit_function: str, optional + 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' + plot_rotation: bool, optional + If True, the CoM curl minimization search result will be displayed + maximize_divergence: bool, optional + If True, the divergence of the CoM gradient vector field is maximized + rotation_angles_deg: np.darray, optional + Array of angles in degrees to perform curl minimization over + plot_probe_overlaps: bool, optional + If True, initial probe overlaps scanned over the object will be displayed + force_com_rotation: float (degrees), optional + Force relative rotation angle between real and reciprocal space + force_com_transpose: bool, optional + Force whether diffraction intensities need to be transposed. + force_com_shifts: sequence of tuples of ndarrays (CoMx, CoMy) + Amplitudes come from diffraction patterns shifted with + the CoM in the upper left corner for each probe unless + shift is overwritten. + force_com_measured: tuple of ndarrays (CoMx measured, CoMy measured) + Force CoM measured shifts + vectorized_com_calculation: bool, optional + If True (default), the memory-intensive CoM calculation is vectorized + force_scan_sampling: float, optional + Override DataCube real space scan pixel size calibrations, in Angstrom + force_angular_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in mrad + force_reciprocal_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in A^-1 + object_fov_mask: np.ndarray (boolean) + Boolean mask of FOV. Used to calculate additional shrinkage of object + If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering + center_positions_in_fov: bool + If True (default), probe positions are centered in the fov. + store_initial_arrays: bool + If True, preprocesed object and probe arrays are stored allowing reset=True in reconstruct. + device: str, optional + if not none, overwrites self._device to set device preprocess will be perfomed on. + clear_fft_cache: bool, optional + if true, and device = 'gpu', clears the cached fft plan at the end of function calls + max_batch_size: int, optional + Max number of probes to use at once in computing probe overlaps + + Returns + -------- + self: PtychographicReconstruction + Self to accommodate chaining + """ + # handle device/storage + self.set_device(device, clear_fft_cache) + + xp = self._xp + device = self._device + xp_storage = self._xp_storage + storage = self._storage + asnumpy = self._asnumpy + + # set additional metadata + self._diffraction_intensities_shape = diffraction_intensities_shape + self._reshaping_method = reshaping_method + self._padded_diffraction_intensities_shape = ( + padded_diffraction_intensities_shape + ) + self._dp_mask = dp_mask + + if self._datacube is None: + raise ValueError( + ( + "The preprocess() method requires a DataCube. " + "Please run ptycho.attach_datacube(DataCube) first." + ) + ) + + if self._magnetic_contribution_sign == "-+": + self._recon_mode = 0 + self._num_measurements = 2 + magnetic_contribution_msg = ( + "Magnetic contribution sign in first meaurement assumed to be negative.\n" + "Magnetic contribution sign in second meaurement assumed to be positive." + ) + + elif self._magnetic_contribution_sign == "-0+": + self._recon_mode = 1 + self._num_measurements = 3 + magnetic_contribution_msg = ( + "Magnetic contribution sign in first meaurement assumed to be negative.\n" + "Magnetic contribution assumed to be zero in second meaurement.\n" + "Magnetic contribution sign in third meaurement assumed to be positive." + ) + + elif self._magnetic_contribution_sign == "0+": + self._recon_mode = 2 + self._num_measurements = 2 + magnetic_contribution_msg = ( + "Magnetic contribution assumed to be zero in first meaurement.\n" + "Magnetic contribution sign in second meaurement assumed to be positive." + ) + else: + raise ValueError( + f"magnetic_contribution_sign must be either '-+', '-0+', or '0+', not {self._magnetic_contribution_sign}" + ) + + if self._verbose: + warnings.warn( + magnetic_contribution_msg, + UserWarning, + ) + + if len(self._datacube) != self._num_measurements: + raise ValueError( + f"datacube must be the same length as magnetic_contribution_sign, not length {len(self._datacube)}." + ) + + if self._positions_mask is not None: + self._positions_mask = np.asarray(self._positions_mask, dtype="bool") + + if self._positions_mask.ndim == 2: + warnings.warn( + "2D `positions_mask` assumed the same for all measurements.", + UserWarning, + ) + self._positions_mask = np.tile( + self._positions_mask, (self._num_measurements, 1, 1) + ) + + num_probes_per_measurement = np.insert( + self._positions_mask.sum(axis=(-2, -1)), 0, 0 + ) + + else: + self._positions_mask = [None] * self._num_measurements + num_probes_per_measurement = [0] + [dc.R_N for dc in self._datacube] + num_probes_per_measurement = np.array(num_probes_per_measurement) + + # prepopulate relevant arrays + self._mean_diffraction_intensity = [] + self._num_diffraction_patterns = num_probes_per_measurement.sum() + self._cum_probes_per_measurement = np.cumsum(num_probes_per_measurement) + self._positions_px_all = np.empty((self._num_diffraction_patterns, 2)) + + # calculate roi_shape + roi_shape = self._datacube[0].Qshape + if diffraction_intensities_shape is not None: + roi_shape = diffraction_intensities_shape + if padded_diffraction_intensities_shape is not None: + roi_shape = tuple( + max(q, s) + for q, s in zip(roi_shape, padded_diffraction_intensities_shape) + ) + + self._amplitudes = xp_storage.empty( + (self._num_diffraction_patterns,) + roi_shape + ) + + self._amplitudes_shape = np.array(self._amplitudes.shape[-2:]) + if region_of_interest_shape is not None: + self._resample_exit_waves = True + self._region_of_interest_shape = np.array(region_of_interest_shape) + else: + self._resample_exit_waves = False + self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) + + # TO-DO: generalize this + if force_com_shifts is None: + force_com_shifts = [None] * self._num_measurements + + if force_com_measured is None: + force_com_measured = [None] * self._num_measurements + + if self._scan_positions is None: + self._scan_positions = [None] * self._num_measurements + + if self._positions_offset_ang is None: + self._positions_offset_ang = [None] * self._num_measurements + + # Ensure plot_center_of_mass is not in kwargs + kwargs.pop("plot_center_of_mass", None) + + if progress_bar: + # turn off verbosity to play nice with tqdm + verbose = self._verbose + self._verbose = False + + # loop over DPs for preprocessing + for index in tqdmnd( + self._num_measurements, + desc="Preprocessing data", + unit="measurement", + disable=not progress_bar, + ): + # preprocess datacube, vacuum and masks only for first measurement + if index == 0: + ( + self._datacube[index], + self._vacuum_probe_intensity, + self._dp_mask, + force_com_shifts[index], + force_com_measured[index], + ) = self._preprocess_datacube_and_vacuum_probe( + self._datacube[index], + diffraction_intensities_shape=self._diffraction_intensities_shape, + reshaping_method=self._reshaping_method, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, + vacuum_probe_intensity=self._vacuum_probe_intensity, + dp_mask=self._dp_mask, + com_shifts=force_com_shifts[index], + com_measured=force_com_measured[index], + ) + + else: + ( + self._datacube[index], + _, + _, + force_com_shifts[index], + force_com_measured[index], + ) = self._preprocess_datacube_and_vacuum_probe( + self._datacube[index], + diffraction_intensities_shape=self._diffraction_intensities_shape, + reshaping_method=self._reshaping_method, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, + vacuum_probe_intensity=None, + dp_mask=None, + com_shifts=force_com_shifts[index], + com_measured=force_com_measured[index], + ) + + # calibrations + intensities = self._extract_intensities_and_calibrations_from_datacube( + self._datacube[index], + require_calibrations=True, + force_scan_sampling=force_scan_sampling, + force_angular_sampling=force_angular_sampling, + force_reciprocal_sampling=force_reciprocal_sampling, + ) + + # calculate CoM + ( + com_measured_x, + com_measured_y, + com_fitted_x, + com_fitted_y, + com_normalized_x, + com_normalized_y, + ) = self._calculate_intensities_center_of_mass( + intensities, + dp_mask=self._dp_mask, + fit_function=fit_function, + com_shifts=force_com_shifts[index], + vectorized_calculation=vectorized_com_calculation, + com_measured=force_com_measured[index], + ) + + # estimate rotation / transpose using first measurement + if index == 0: + # silence warnings to play nice with progress bar + verbose = self._verbose + self._verbose = False + + ( + self._rotation_best_rad, + self._rotation_best_transpose, + _com_x, + _com_y, + ) = self._solve_for_center_of_mass_relative_rotation( + com_measured_x, + com_measured_y, + com_normalized_x, + com_normalized_y, + rotation_angles_deg=rotation_angles_deg, + plot_rotation=plot_rotation, + plot_center_of_mass=False, + maximize_divergence=maximize_divergence, + force_com_rotation=force_com_rotation, + force_com_transpose=force_com_transpose, + **kwargs, + ) + self._verbose = verbose + + # corner-center amplitudes + idx_start = self._cum_probes_per_measurement[index] + idx_end = self._cum_probes_per_measurement[index + 1] + + ( + amplitudes, + mean_diffraction_intensity_temp, + self._crop_mask, + self._crop_mask_shape, + ) = self._normalize_diffraction_intensities( + intensities, + com_fitted_x, + com_fitted_y, + self._positions_mask[index], + crop_patterns, + in_place_datacube_modification, + ) + + self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp) + + # explicitly transfer arrays to storage + self._amplitudes[idx_start:idx_end] = copy_to_device(amplitudes, storage) + + del ( + intensities, + amplitudes, + com_measured_x, + com_measured_y, + com_fitted_x, + com_fitted_y, + com_normalized_x, + com_normalized_y, + ) + + # initialize probe positions + ( + self._positions_px_all[idx_start:idx_end], + self._object_padding_px, + ) = self._calculate_scan_positions_in_pixels( + self._scan_positions[index], + self._positions_mask[index], + self._object_padding_px, + self._positions_offset_ang[index], + ) + + if progress_bar: + # reset verbosity + self._verbose = verbose + + # handle semiangle specified in pixels + if self._semiangle_cutoff_pixels: + self._semiangle_cutoff = ( + self._semiangle_cutoff_pixels * self._angular_sampling[0] + ) + + # Object Initialization + obj = self._initialize_object( + self._object, + self._positions_px_all, + self._object_type, + ) + + if self._object is None: + # complex zeros instead of ones, since we store pre-exponential terms + self._object = xp.zeros((2,) + obj.shape, dtype=obj.dtype) + else: + self._object = obj + + if store_initial_arrays: + self._object_initial = self._object.copy() + self._object_type_initial = self._object_type + self._object_shape = self._object.shape[-2:] + + # center probe positions + self._positions_px_all = xp_storage.asarray( + self._positions_px_all, dtype=xp_storage.float32 + ) + + if center_positions_in_fov: + for index in range(self._num_measurements): + idx_start = self._cum_probes_per_measurement[index] + idx_end = self._cum_probes_per_measurement[index + 1] + + positions_px = self._positions_px_all[idx_start:idx_end] + positions_px_com = positions_px.mean(0) + positions_px -= ( + positions_px_com - xp_storage.array(self._object_shape) / 2 + ) + self._positions_px_all[idx_start:idx_end] = positions_px.copy() + + self._positions_px_initial_all = self._positions_px_all.copy() + self._positions_initial_all = self._positions_px_initial_all.copy() + self._positions_initial_all[:, 0] *= self.sampling[0] + self._positions_initial_all[:, 1] *= self.sampling[1] + + self._positions_initial = self._return_average_positions() + if self._positions_initial is not None: + self._positions_initial[:, 0] *= self.sampling[0] + self._positions_initial[:, 1] *= self.sampling[1] + + # initialize probe + self._probes_all = [] + list_Q = isinstance(self._probe_init, (list, tuple)) + + if store_initial_arrays: + self._probes_all_initial = [] + self._probes_all_initial_aperture = [] + else: + self._probes_all_initial_aperture = [None] * self._num_measurements + + for index in range(self._num_measurements): + _probe, self._semiangle_cutoff = self._initialize_probe( + self._probe_init[index] if list_Q else self._probe_init, + self._vacuum_probe_intensity, + self._mean_diffraction_intensity[index], + self._semiangle_cutoff, + crop_patterns, + ) + + self._probes_all.append(_probe) + if store_initial_arrays: + self._probes_all_initial.append(_probe.copy()) + self._probes_all_initial_aperture.append(xp.abs(xp.fft.fft2(_probe))) + + del self._probe_init + + # initialize aberrations + self._known_aberrations_array = ComplexProbe( + energy=self._energy, + gpts=self._region_of_interest_shape, + sampling=self.sampling, + parameters=self._polar_parameters, + device=self._device, + )._evaluate_ctf() + + if object_fov_mask is None or plot_probe_overlaps: + # overlaps + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + + probe_overlap = xp.zeros(self._object_shape, dtype=xp.float32) + + for start, end in generate_batches( + self._cum_probes_per_measurement[1], max_batch=max_batch_size + ): + # batch indices + positions_px = self._positions_px_all[start:end] + positions_px_fractional = positions_px - xp_storage.round(positions_px) + + shifted_probes = fft_shift( + self._probes_all[0], positions_px_fractional, xp + ) + probe_overlap += self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes) ** 2, positions_px + ) + + del shifted_probes + + # initialize object_fov_mask + if object_fov_mask is None: + gaussian_filter = self._scipy.ndimage.gaussian_filter + probe_overlap_blurred = gaussian_filter(probe_overlap, 1.0) + self._object_fov_mask = asnumpy( + probe_overlap_blurred > 0.25 * probe_overlap_blurred.max() + ) + del probe_overlap_blurred + elif object_fov_mask is True: + self._object_fov_mask = np.full(self._object_shape, True) + else: + self._object_fov_mask = np.asarray(object_fov_mask) + self._object_fov_mask_inverse = np.invert(self._object_fov_mask) + + # plot probe overlaps + if plot_probe_overlaps: + probe_overlap = asnumpy(probe_overlap) + figsize = kwargs.pop("figsize", (9, 4)) + chroma_boost = kwargs.pop("chroma_boost", 1) + power = kwargs.pop("power", 2) + + # initial probe + complex_probe_rgb = Complex2RGB( + self.probe_centered[0], + power=power, + chroma_boost=chroma_boost, + ) + + extent = [ + 0, + self.sampling[1] * self._object_shape[1], + self.sampling[0] * self._object_shape[0], + 0, + ] + + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize) + + ax1.imshow( + complex_probe_rgb, + extent=probe_extent, + ) + + divider = make_axes_locatable(ax1) + cax1 = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(cax1, chroma_boost=chroma_boost) + ax1.set_ylabel("x [A]") + ax1.set_xlabel("y [A]") + ax1.set_title("Initial probe intensity") + + ax2.imshow( + probe_overlap, + extent=extent, + cmap="gray", + ) + ax2.scatter( + self.positions[0, :, 1], + self.positions[0, :, 0], + s=2.5, + color=(1, 0, 0, 1), + ) + ax2.set_ylabel("x [A]") + ax2.set_xlabel("y [A]") + ax2.set_xlim((extent[0], extent[1])) + ax2.set_ylim((extent[2], extent[3])) + ax2.set_title("Object field of view") + + fig.tight_layout() + + self._preprocessed = True + self.clear_device_mem(self._device, self._clear_fft_cache) + + return self + + def _overlap_projection( + self, + current_object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + ): + """ + Ptychographic overlap projection method. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + + Returns + -------- + shifted_probes:np.ndarray + fractionally-shifted probes + object_patches: np.ndarray + Patched object view + overlap: np.ndarray + shifted_probes * object_patches + """ + + xp = self._xp + + object_patches = xp.empty( + (self._num_measurements,) + shifted_probes.shape, dtype=current_object.dtype + ) + object_patches[0] = current_object[ + 0, vectorized_patch_indices_row, vectorized_patch_indices_col + ] + object_patches[1] = current_object[ + 1, vectorized_patch_indices_row, vectorized_patch_indices_col + ] + + overlap_base = shifted_probes * xp.exp(1.0j * object_patches[0]) + + match (self._recon_mode, self._active_measurement_index): + case (0, 0) | (1, 0): # reverse + overlap = overlap_base * xp.exp(-1.0j * object_patches[1]) + case (0, 1) | (1, 2) | (2, 1): # forward + overlap = overlap_base * xp.exp(1.0j * object_patches[1]) + case (1, 1) | (2, 0): # neutral + overlap = overlap_base + case _: + raise ValueError() + + return shifted_probes, object_patches, overlap + + def _gradient_descent_adjoint( + self, + current_object, + current_probe, + object_patches, + shifted_probes, + positions_px, + exit_waves, + step_size, + normalization_min, + fix_probe, + ): + """ + Ptychographic adjoint operator for GD method. + Computes object and probe update steps. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + shifted_probes:np.ndarray + fractionally-shifted probes + exit_waves:np.ndarray + Updated exit_waves + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + xp = self._xp + + probe_conj = xp.conj(shifted_probes) # P* + electrostatic_conj = xp.exp(-1.0j * xp.conj(object_patches[0])) # exp[-i c] + + probe_electrostatic_abs = xp.abs(probe_conj * electrostatic_conj) + probe_electrostatic_normalization = self._sum_overlapping_patches_bincounts( + probe_electrostatic_abs**2, + positions_px, + ) + probe_electrostatic_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_electrostatic_normalization) ** 2 + + (normalization_min * xp.max(probe_electrostatic_normalization)) ** 2 + ) + + match (self._recon_mode, self._active_measurement_index): + case (0, 0) | (1, 0): # reverse + + magnetic_conj = xp.exp(1.0j * xp.conj(object_patches[1])) + + probe_magnetic_abs = xp.abs(shifted_probes * magnetic_conj) + probe_magnetic_normalization = self._sum_overlapping_patches_bincounts( + probe_magnetic_abs**2, + positions_px, + ) + probe_magnetic_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_magnetic_normalization) ** 2 + + (normalization_min * xp.max(probe_magnetic_normalization)) ** 2 + ) + + # - i * exp(i m*) * exp(-i c*) * P + electrostatic_update = self._sum_overlapping_patches_bincounts( + -1.0j + * magnetic_conj + * electrostatic_conj + * probe_conj + * exit_waves, + positions_px, + ) + + # i * exp(i m*) * exp(-i c*) * P + magnetic_update = self._sum_overlapping_patches_bincounts( + 1.0j * magnetic_conj * electrostatic_conj * probe_conj * exit_waves, + positions_px, + ) + + current_object[0] += ( + step_size * electrostatic_update * probe_magnetic_normalization + ) + current_object[1] += ( + step_size * magnetic_update * probe_electrostatic_normalization + ) + + if not fix_probe: + + electrostatic_magnetic_abs = xp.abs( + electrostatic_conj * magnetic_conj + ) + electrostatic_magnetic_normalization = xp.sum( + electrostatic_magnetic_abs**2, + axis=0, + ) + electrostatic_magnetic_normalization = 1 / xp.sqrt( + 1e-16 + + ( + (1 - normalization_min) + * electrostatic_magnetic_normalization + ) + ** 2 + + ( + normalization_min + * xp.max(electrostatic_magnetic_normalization) + ) + ** 2 + ) + + # exp(i m*) * exp(-i c*) + current_probe += step_size * ( + xp.sum( + magnetic_conj * electrostatic_conj * exit_waves, + axis=0, + ) + * electrostatic_magnetic_normalization + ) + + case (0, 1) | (1, 2) | (2, 1): # forward + + magnetic_conj = xp.exp(-1.0j * xp.conj(object_patches[1])) + + probe_magnetic_abs = xp.abs(shifted_probes * magnetic_conj) + probe_magnetic_normalization = self._sum_overlapping_patches_bincounts( + probe_magnetic_abs**2, + positions_px, + ) + probe_magnetic_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_magnetic_normalization) ** 2 + + (normalization_min * xp.max(probe_magnetic_normalization)) ** 2 + ) + + # - i * exp(-i m*) * exp(-i c*) * P + update = self._sum_overlapping_patches_bincounts( + -1.0j + * magnetic_conj + * electrostatic_conj + * probe_conj + * exit_waves, + positions_px, + ) + + current_object[0] += step_size * update * probe_magnetic_normalization + current_object[1] += ( + step_size * update * probe_electrostatic_normalization + ) + + if not fix_probe: + + electrostatic_magnetic_abs = xp.abs( + electrostatic_conj * magnetic_conj + ) + electrostatic_magnetic_normalization = xp.sum( + electrostatic_magnetic_abs**2, + axis=0, + ) + electrostatic_magnetic_normalization = 1 / xp.sqrt( + 1e-16 + + ( + (1 - normalization_min) + * electrostatic_magnetic_normalization + ) + ** 2 + + ( + normalization_min + * xp.max(electrostatic_magnetic_normalization) + ) + ** 2 + ) + + # exp(i m*) * exp(-i c*) + current_probe += step_size * ( + xp.sum( + magnetic_conj * electrostatic_conj * exit_waves, + axis=0, + ) + * electrostatic_magnetic_normalization + ) + + case (1, 1) | (2, 0): # neutral + + probe_abs = xp.abs(shifted_probes) + probe_normalization = self._sum_overlapping_patches_bincounts( + probe_abs**2, + positions_px, + ) + probe_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_normalization) ** 2 + + (normalization_min * xp.max(probe_normalization)) ** 2 + ) + + # -i exp(-i c*) * P* + electrostatic_update = self._sum_overlapping_patches_bincounts( + -1.0j * electrostatic_conj * probe_conj * exit_waves, + positions_px, + ) + + current_object[0] += ( + step_size * electrostatic_update * probe_normalization + ) + + if not fix_probe: + + electrostatic_abs = xp.abs(electrostatic_conj) + electrostatic_normalization = xp.sum( + electrostatic_abs**2, + axis=0, + ) + electrostatic_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * electrostatic_normalization) ** 2 + + (normalization_min * xp.max(electrostatic_normalization)) ** 2 + ) + + # exp(-i c*) + current_probe += step_size * ( + xp.sum( + electrostatic_conj * exit_waves, + axis=0, + ) + * electrostatic_normalization + ) + + case _: + raise ValueError() + + return current_object, current_probe + + def _object_constraints( + self, + current_object, + gaussian_filter, + gaussian_filter_sigma_e, + gaussian_filter_sigma_m, + butterworth_filter, + butterworth_order, + q_lowpass_e, + q_lowpass_m, + q_highpass_e, + q_highpass_m, + tv_denoise, + tv_denoise_weight, + tv_denoise_inner_iter, + **kwargs, + ): + """MagneticObjectNDConstraints wrapper function""" + + # smoothness + if gaussian_filter: + current_object[0] = self._object_gaussian_constraint( + current_object[0], gaussian_filter_sigma_e, False + ) + current_object[1] = self._object_gaussian_constraint( + current_object[1], gaussian_filter_sigma_m, False + ) + if butterworth_filter: + current_object[0] = self._object_butterworth_constraint( + current_object[0], + q_lowpass_e, + q_highpass_e, + butterworth_order, + ) + current_object[1] = self._object_butterworth_constraint( + current_object[1], + q_lowpass_m, + q_highpass_m, + butterworth_order, + ) + if tv_denoise: + current_object[0] = self._object_denoise_tv_pylops( + current_object[0], tv_denoise_weight, tv_denoise_inner_iter + ) + + return current_object + + def reconstruct( + self, + num_iter: int = 8, + reconstruction_method: str = "gradient-descent", + reconstruction_parameter: float = 1.0, + reconstruction_parameter_a: float = None, + reconstruction_parameter_b: float = None, + reconstruction_parameter_c: float = None, + max_batch_size: int = None, + seed_random: int = None, + step_size: float = 0.5, + normalization_min: float = 1, + positions_step_size: float = 0.9, + fix_probe_com: bool = True, + fix_probe: bool = False, + fix_probe_aperture: bool = False, + constrain_probe_amplitude: bool = False, + constrain_probe_amplitude_relative_radius: float = 0.5, + constrain_probe_amplitude_relative_width: float = 0.05, + constrain_probe_fourier_amplitude: bool = False, + constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, + constrain_probe_fourier_amplitude_constant_intensity: bool = False, + fix_positions: bool = True, + fix_positions_com: bool = True, + max_position_update_distance: float = None, + max_position_total_distance: float = None, + global_affine_transformation: bool = False, + gaussian_filter_sigma_e: float = None, + gaussian_filter_sigma_m: float = None, + gaussian_filter: bool = True, + fit_probe_aberrations: bool = False, + fit_probe_aberrations_max_angular_order: int = 4, + fit_probe_aberrations_max_radial_order: int = 4, + fit_probe_aberrations_remove_initial: bool = False, + fit_probe_aberrations_using_scikit_image: bool = True, + butterworth_filter: bool = True, + q_lowpass_e: float = None, + q_lowpass_m: float = None, + q_highpass_e: float = None, + q_highpass_m: float = None, + butterworth_order: float = 2, + tv_denoise: bool = True, + tv_denoise_weight: float = None, + tv_denoise_inner_iter: float = 40, + detector_fourier_mask: np.ndarray = None, + store_iterations: bool = False, + collective_measurement_updates: bool = True, + progress_bar: bool = True, + reset: bool = None, + device: str = None, + clear_fft_cache: bool = None, + object_type: str = None, + ): + """ + Ptychographic reconstruction main method. + + Parameters + -------- + num_iter: int, optional + Number of iterations to run + reconstruction_method: str, optional + Specifies which reconstruction algorithm to use, one of: + "generalized-projections", + "DM_AP" (or "difference-map_alternating-projections"), + "RAAR" (or "relaxed-averaged-alternating-reflections"), + "RRR" (or "relax-reflect-reflect"), + "SUPERFLIP" (or "charge-flipping"), or + "GD" (or "gradient_descent") + reconstruction_parameter: float, optional + Reconstruction parameter for various reconstruction methods above. + reconstruction_parameter_a: float, optional + Reconstruction parameter a for reconstruction_method='generalized-projections'. + reconstruction_parameter_b: float, optional + Reconstruction parameter b for reconstruction_method='generalized-projections'. + reconstruction_parameter_c: float, optional + Reconstruction parameter c for reconstruction_method='generalized-projections'. + max_batch_size: int, optional + Max number of probes to update at once + seed_random: int, optional + Seeds the random number generator, only applicable when max_batch_size is not None + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + positions_step_size: float, optional + Positions update step size + pure_phase_object: bool, optional + If True, object amplitude is set to unity + fix_probe_com: bool, optional + If True, fixes center of mass of probe + fix_probe: bool, optional + If True, probe is fixed + fix_probe_aperture: bool, optional + If True, vaccum probe is used to fix Fourier amplitude + constrain_probe_amplitude: bool, optional + If True, real-space probe is constrained with a top-hat support. + constrain_probe_amplitude_relative_radius: float + Relative location of top-hat inflection point, between 0 and 0.5 + constrain_probe_amplitude_relative_width: float + Relative width of top-hat sigmoid, between 0 and 0.5 + constrain_probe_fourier_amplitude: bool, optional + If True, Fourier-probe is constrained by fitting a sigmoid for each angular frequency + constrain_probe_fourier_amplitude_max_width_pixels: float + Maximum pixel width of fitted sigmoid functions. + constrain_probe_fourier_amplitude_constant_intensity: bool + If True, the probe aperture is additionally constrained to a constant intensity. + fix_positions: bool, optional + If True, probe-positions are fixed + fix_positions_com: bool, optional + If True, fixes the positions CoM to the middle of the fov + max_position_update_distance: float, optional + Maximum allowed distance for update in A + max_position_total_distance: float, optional + Maximum allowed distance from initial positions + global_affine_transformation: bool, optional + If True, positions are assumed to be a global affine transform from initial scan + gaussian_filter_sigma_e: float + Standard deviation of gaussian kernel for electrostatic object in A + gaussian_filter_sigma_m: float + Standard deviation of gaussian kernel for magnetic object in A + gaussian_filter: bool, optional + If True and gaussian_filter_sigma is not None, object is smoothed using gaussian filtering + fit_probe_aberrations: bool, optional + If True, probe aberrations are fitted to a low-order expansion + fit_probe_aberrations_max_angular_order: bool + Max angular order of probe aberrations basis functions + fit_probe_aberrations_max_radial_order: bool + Max radial order of probe aberrations basis functions + fit_probe_aberrations_remove_initial: bool + If true, initial probe aberrations are removed before fitting + fit_probe_aberrations_using_scikit_image: bool + If true, the necessary phase unwrapping is performed using scikit-image. This is more stable, but occasionally leads + to a documented bug where the kernel hangs.. + If false, a poisson-based solver is used for phase unwrapping. This won't hang, but tends to underestimate aberrations. + butterworth_filter: bool, optional + If True and q_lowpass or q_highpass is not None, object is smoothed using butterworth filtering + q_lowpass_e: float + Cut-off frequency in A^-1 for low-pass filtering electrostatic object + q_lowpass_m: float + Cut-off frequency in A^-1 for low-pass filtering magnetic object + q_highpass_e: float + Cut-off frequency in A^-1 for high-pass filtering electrostatic object + q_highpass_m: float + Cut-off frequency in A^-1 for high-pass filtering magnetic object + butterworth_order: float + Butterworth filter order. Smaller gives a smoother filter + tv_denoise: bool, optional + If True and tv_denoise_weight is not None, object is smoothed using TV denoising + tv_denoise_weight: float + Denoising weight. The greater `weight`, the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising + object_positivity: bool, optional + If True, forces object to be positive + shrinkage_rad: float + Phase shift in radians to be subtracted from the potential at each iteration + fix_potential_baseline: bool + If true, the potential mean outside the FOV is forced to zero at each iteration + detector_fourier_mask: np.ndarray + Corner-centered mask to multiply the detector-plane gradients with (a value of zero supresses those pixels). + Useful when detector has artifacts such as dead-pixels. Usually binary. + store_iterations: bool, optional + If True, reconstructed objects and probes are stored at each iteration + collective_measurement_updates: bool + if True perform collective updates for all measurements + progress_bar: bool, optional + If True, reconstruction progress is displayed + reset: bool, optional + If True, previous reconstructions are ignored + device: str, optional + if not none, overwrites self._device to set device preprocess will be perfomed on. + clear_fft_cache: bool, optional + if true, and device = 'gpu', clears the cached fft plan at the end of function calls + object_type: str, optional + Overwrites self._object_type + + Returns + -------- + self: PtychographicReconstruction + Self to accommodate chaining + """ + # handle device/storage + self.set_device(device, clear_fft_cache) + + if device is not None: + attrs = [ + "_known_aberrations_array", + "_object", + "_object_initial", + "_probes_all", + "_probes_all_initial", + "_probes_all_initial_aperture", + ] + self.copy_attributes_to_device(attrs, device) + + xp = self._xp + xp_storage = self._xp_storage + device = self._device + asnumpy = self._asnumpy + + if not collective_measurement_updates and self._verbose: + warnings.warn( + "Magnetic ptychography is much more robust with `collective_measurement_updates=True`.", + UserWarning, + ) + + # set and report reconstruction method + ( + use_projection_scheme, + projection_a, + projection_b, + projection_c, + reconstruction_parameter, + step_size, + ) = self._set_reconstruction_method_parameters( + reconstruction_method, + reconstruction_parameter, + reconstruction_parameter_a, + reconstruction_parameter_b, + reconstruction_parameter_c, + step_size, + ) + + if use_projection_scheme: + raise NotImplementedError( + "Magnetic ptychography is currently only implemented for gradient descent." + ) + + # initialization + self._reset_reconstruction(store_iterations, reset, use_projection_scheme) + + if object_type is not None: + self._switch_object_type(object_type) + + if self._verbose: + self._report_reconstruction_summary( + num_iter, + use_projection_scheme, + reconstruction_method, + reconstruction_parameter, + projection_a, + projection_b, + projection_c, + normalization_min, + step_size, + max_batch_size, + ) + + if max_batch_size is not None: + np.random.seed(seed_random) + else: + max_batch_size = self._num_diffraction_patterns + + if detector_fourier_mask is not None: + detector_fourier_mask = xp.asarray(detector_fourier_mask) + + if gaussian_filter_sigma_m is None: + gaussian_filter_sigma_m = gaussian_filter_sigma_e + + if q_lowpass_m is None: + q_lowpass_m = q_lowpass_e + + # main loop + for a0 in tqdmnd( + num_iter, + desc="Reconstructing object and probe", + unit=" iter", + disable=not progress_bar, + ): + error = 0.0 + + if collective_measurement_updates: + collective_object = xp.zeros_like(self._object) + + # randomize + measurement_indices = np.arange(self._num_measurements) + np.random.shuffle(measurement_indices) + + for measurement_index in measurement_indices: + self._active_measurement_index = measurement_index + + measurement_error = 0.0 + + _probe = self._probes_all[self._active_measurement_index] + _probe_initial_aperture = self._probes_all_initial_aperture[ + self._active_measurement_index + ] + + start_idx = self._cum_probes_per_measurement[ + self._active_measurement_index + ] + end_idx = self._cum_probes_per_measurement[ + self._active_measurement_index + 1 + ] + + num_diffraction_patterns = end_idx - start_idx + shuffled_indices = np.arange(start_idx, end_idx) + + # randomize + if not use_projection_scheme: + np.random.shuffle(shuffled_indices) + + for start, end in generate_batches( + num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + batch_indices = shuffled_indices[start:end] + positions_px = self._positions_px_all[batch_indices] + positions_px_initial = self._positions_px_initial_all[batch_indices] + positions_px_fractional = positions_px - xp_storage.round( + positions_px + ) + + ( + vectorized_patch_indices_row, + vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices(positions_px) + + amplitudes_device = copy_to_device( + self._amplitudes[batch_indices], device + ) + + # forward operator + ( + shifted_probes, + object_patches, + overlap, + self._exit_waves, + batch_error, + ) = self._forward( + self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + _probe, + positions_px_fractional, + amplitudes_device, + self._exit_waves, + detector_fourier_mask, + use_projection_scheme=use_projection_scheme, + projection_a=projection_a, + projection_b=projection_b, + projection_c=projection_c, + ) + + # adjoint operator + object_update, _probe = self._adjoint( + self._object.copy(), + _probe, + object_patches, + shifted_probes, + positions_px, + self._exit_waves, + use_projection_scheme=use_projection_scheme, + step_size=step_size, + normalization_min=normalization_min, + fix_probe=fix_probe, + ) + + object_update -= self._object + + # position correction + if not fix_positions and a0 > 0: + self._positions_px_all[batch_indices] = ( + self._position_correction( + self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + overlap, + amplitudes_device, + positions_px, + positions_px_initial, + positions_step_size, + max_position_update_distance, + max_position_total_distance, + ) + ) + + measurement_error += batch_error + + if collective_measurement_updates: + collective_object += object_update + else: + self._object += object_update + + # Normalize Error + measurement_error /= ( + self._mean_diffraction_intensity[self._active_measurement_index] + * num_diffraction_patterns + ) + error += measurement_error + + # constraints + + if collective_measurement_updates: + # probe and positions + _probe = self._probe_constraints( + _probe, + fix_probe_com=fix_probe_com and not fix_probe, + constrain_probe_amplitude=constrain_probe_amplitude + and not fix_probe, + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=constrain_probe_fourier_amplitude + and not fix_probe, + constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=fit_probe_aberrations and not fix_probe, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, + fix_probe_aperture=fix_probe_aperture and not fix_probe, + initial_probe_aperture=_probe_initial_aperture, + ) + + self._positions_px_all[batch_indices] = self._positions_constraints( + self._positions_px_all[batch_indices], + self._positions_px_initial_all[batch_indices], + fix_positions=fix_positions, + fix_positions_com=fix_positions_com and not fix_positions, + global_affine_transformation=global_affine_transformation, + ) + + else: + # object, probe, and positions + ( + self._object, + _probe, + self._positions_px_all[batch_indices], + ) = self._constraints( + self._object, + _probe, + self._positions_px_all[batch_indices], + self._positions_px_initial_all[batch_indices], + fix_probe_com=fix_probe_com and not fix_probe, + constrain_probe_amplitude=constrain_probe_amplitude + and not fix_probe, + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=constrain_probe_fourier_amplitude + and not fix_probe, + constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=fit_probe_aberrations and not fix_probe, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, + fix_probe_aperture=fix_probe_aperture and not fix_probe, + initial_probe_aperture=_probe_initial_aperture, + fix_positions=fix_positions, + fix_positions_com=fix_positions_com and not fix_positions, + global_affine_transformation=global_affine_transformation, + gaussian_filter=gaussian_filter + and gaussian_filter_sigma_m is not None, + gaussian_filter_sigma_e=gaussian_filter_sigma_e, + gaussian_filter_sigma_m=gaussian_filter_sigma_m, + butterworth_filter=butterworth_filter + and (q_lowpass_m is not None or q_highpass_m is not None), + q_lowpass_e=q_lowpass_e, + q_lowpass_m=q_lowpass_m, + q_highpass_e=q_highpass_e, + q_highpass_m=q_highpass_m, + butterworth_order=butterworth_order, + tv_denoise=tv_denoise and tv_denoise_weight is not None, + tv_denoise_weight=tv_denoise_weight, + tv_denoise_inner_iter=tv_denoise_inner_iter, + ) + + # Normalize Error Over Tilts + error /= self._num_measurements + + if collective_measurement_updates: + self._object += collective_object / self._num_measurements + + # object only + self._object = self._object_constraints( + self._object, + gaussian_filter=gaussian_filter + and gaussian_filter_sigma_m is not None, + gaussian_filter_sigma_e=gaussian_filter_sigma_e, + gaussian_filter_sigma_m=gaussian_filter_sigma_m, + butterworth_filter=butterworth_filter + and (q_lowpass_m is not None or q_highpass_m is not None), + q_lowpass_e=q_lowpass_e, + q_lowpass_m=q_lowpass_m, + q_highpass_e=q_highpass_e, + q_highpass_m=q_highpass_m, + butterworth_order=butterworth_order, + tv_denoise=tv_denoise and tv_denoise_weight is not None, + tv_denoise_weight=tv_denoise_weight, + tv_denoise_inner_iter=tv_denoise_inner_iter, + ) + + self.error_iterations.append(error.item()) + + if store_iterations: + self.object_iterations.append(asnumpy(self._object.copy())) + self.probe_iterations.append(self.probe_centered) + + # store result + self.object = asnumpy(self._object) + self.probe = self.probe_centered + self.error = error.item() + + # remove _exit_waves attr from self for GD + if not use_projection_scheme: + self._exit_waves = None + + self.clear_device_mem(self._device, self._clear_fft_cache) + + return self + + def _visualize_all_iterations(self, **kwargs): + raise NotImplementedError() + + def _visualize_last_iteration( + self, + fig, + cbar: bool, + plot_convergence: bool, + plot_probe: bool, + plot_fourier_probe: bool, + remove_initial_probe_aberrations: bool, + **kwargs, + ): + """ + Displays last reconstructed object and probe iterations. + + Parameters + -------- + fig: Figure + Matplotlib figure to place Gridspec in + plot_convergence: bool, optional + If true, the normalized mean squared error (NMSE) plot is displayed + cbar: bool, optional + If true, displays a colorbar + plot_probe: bool, optional + If true, the reconstructed complex probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes + """ + + asnumpy = self._asnumpy + + figsize = kwargs.pop("figsize", (12, 8)) + cmap_e_real = kwargs.pop("cmap_e_real", "cividis") + cmap_e_imag = kwargs.pop("cmap_e_imag", "magma") + cmap_m_real = kwargs.pop("cmap_m_real", "PuOr") + cmap_m_imag = kwargs.pop("cmap_m_imag", "PiYG") + chroma_boost = kwargs.pop("chroma_boost", 1) + + # get scaled arrays + obj = self.object_cropped + + vmin_e_real = kwargs.pop("vmin_e_real", None) + vmax_e_real = kwargs.pop("vmax_e_real", None) + vmin_e_imag = kwargs.pop("vmin_e_imag", None) + vmax_e_imag = kwargs.pop("vmax_e_imag", None) + _, vmin_e_real, vmax_e_real = return_scaled_histogram_ordering( + obj[0].real, vmin_e_real, vmax_e_real + ) + _, vmin_e_imag, vmax_e_iamg = return_scaled_histogram_ordering( + obj[0].imag, vmin_e_imag, vmax_e_imag + ) + + _, _, _vmax_m_real = return_scaled_histogram_ordering(obj[1].real) + vmin_m_real = kwargs.pop("vmin_m_real", -_vmax_m_real) + vmax_m_real = kwargs.pop("vmax_m_real", _vmax_m_real) + + _, _, _vmax_m_imag = return_scaled_histogram_ordering(obj[1].imag) + vmin_m_imag = kwargs.pop("vmin_m_imag", -_vmax_m_imag) + vmax_m_imag = kwargs.pop("vmax_m_imag", _vmax_m_imag) + + extent = [ + 0, + self.sampling[1] * obj.shape[2], + self.sampling[0] * obj.shape[1], + 0, + ] + + if plot_fourier_probe: + probe_extent = [ + -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + ] + elif plot_probe: + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + if plot_convergence: + if plot_probe: + spec = GridSpec( + ncols=3, + nrows=3, + height_ratios=[4, 4, 1], + hspace=0.15, + width_ratios=[ + (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), + (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), + 1, + ], + wspace=0.35, + ) + + else: + spec = GridSpec(ncols=2, nrows=3, height_ratios=[4, 4, 1], hspace=0.15) + + else: + if plot_probe: + spec = GridSpec( + ncols=3, + nrows=2, + width_ratios=[ + (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), + (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), + 1, + ], + wspace=0.35, + ) + + else: + spec = GridSpec(ncols=2, nrows=2, wspace=0.35) + + if fig is None: + fig = plt.figure(figsize=figsize) + + # Electronic real + ax = fig.add_subplot(spec[0, 0]) + im = ax.imshow( + obj[0].real, + extent=extent, + cmap=cmap_e_real, + vmin=vmin_e_real, + vmax=vmax_e_real, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + ax.set_title("Real elec. optical index") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + # Electronic imag + ax = fig.add_subplot(spec[0, 1]) + im = ax.imshow( + obj[0].imag, + extent=extent, + cmap=cmap_e_imag, + vmin=vmin_e_imag, + vmax=vmax_e_imag, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + ax.set_title("Imag elec. optical index") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + # Magnetic real + ax = fig.add_subplot(spec[1, 0]) + im = ax.imshow( + obj[1].real, + extent=extent, + cmap=cmap_m_real, + vmin=vmin_m_real, + vmax=vmax_m_real, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + ax.set_title("Real mag. optical index") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + # Magnetic imag + ax = fig.add_subplot(spec[1, 1]) + im = ax.imshow( + obj[1].imag, + extent=extent, + cmap=cmap_m_imag, + vmin=vmin_m_imag, + vmax=vmax_m_imag, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + ax.set_title("Imag mag. optical index") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + if plot_fourier_probe: + # Fourier probe + intensities = self._return_probe_intensities(None) + titles = [ + f"{sign}ve Fourier probe: {ratio*100:.1f}%" + for sign, ratio in zip(self._magnetic_contribution_sign, intensities) + ] + ax = fig.add_subplot(spec[0, 2]) + + probe_fourier = asnumpy( + self._return_fourier_probe( + self._probes_all[0], + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) + ) + + probe_array = Complex2RGB( + probe_fourier, + chroma_boost=chroma_boost, + ) + + ax.set_title(titles[0]) + ax.set_ylabel("kx [mrad]") + ax.set_xlabel("ky [mrad]") + + im = ax.imshow( + probe_array, + extent=probe_extent, + ) + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) + + ax = fig.add_subplot(spec[1, 2]) + + probe_fourier = asnumpy( + self._return_fourier_probe( + self._probes_all[-1], + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) + ) + + probe_array = Complex2RGB( + probe_fourier, + chroma_boost=chroma_boost, + ) + + ax.set_title(titles[-1]) + ax.set_ylabel("kx [mrad]") + ax.set_xlabel("ky [mrad]") + + im = ax.imshow( + probe_array, + extent=probe_extent, + ) + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) + + elif plot_probe: + # Real probe + intensities = self._return_probe_intensities(None) + titles = [ + f"{sign}ve probe intensity: {ratio*100:.1f}%" + for sign, ratio in zip(self._magnetic_contribution_sign, intensities) + ] + ax = fig.add_subplot(spec[0, 2]) + + probe_array = Complex2RGB( + asnumpy(self._return_centered_probe(self._probes_all[0])), + power=2, + chroma_boost=chroma_boost, + ) + ax.set_title(titles[0]) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + im = ax.imshow( + probe_array, + extent=probe_extent, + ) + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) + + ax = fig.add_subplot(spec[1, 2]) + + probe_array = Complex2RGB( + asnumpy(self._return_centered_probe(self._probes_all[-1])), + power=2, + chroma_boost=chroma_boost, + ) + ax.set_title(titles[-1]) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + im = ax.imshow( + probe_array, + extent=probe_extent, + ) + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) + + if plot_convergence and hasattr(self, "error_iterations"): + errors = np.array(self.error_iterations) + + ax = fig.add_subplot(spec[2, :]) + ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) + ax.set_ylabel("NMSE") + ax.set_xlabel("Iteration number") + ax.yaxis.tick_right() + + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") + spec.tight_layout(fig) + + @property + def object_cropped(self): + """Cropped and rotated object""" + avg_pos = self._return_average_positions() + cropped_e = self._crop_rotate_object_fov(self._object[0], positions_px=avg_pos) + cropped_m = self._crop_rotate_object_fov(self._object[1], positions_px=avg_pos) + + return np.array([cropped_e, cropped_m]) diff --git a/py4DSTEM/process/polar/polar_analysis.py b/py4DSTEM/process/polar/polar_analysis.py index 1366b5ecf..0104e7ba8 100644 --- a/py4DSTEM/process/polar/polar_analysis.py +++ b/py4DSTEM/process/polar/polar_analysis.py @@ -983,6 +983,8 @@ def background_pca( radial PCA component selected """ + from sklearn.decomposition import PCA + # PCA decomposition shape = self.radial_all.shape A = np.reshape(self.radial_all, (shape[0] * shape[1], shape[2])) diff --git a/py4DSTEM/process/polar/polar_fits.py b/py4DSTEM/process/polar/polar_fits.py index 7760726a7..34630dbc2 100644 --- a/py4DSTEM/process/polar/polar_fits.py +++ b/py4DSTEM/process/polar/polar_fits.py @@ -16,6 +16,9 @@ def fit_amorphous_ring( show_fit_mask=False, fit_all_images=False, maxfev=None, + robust=False, + robust_steps=3, + robust_thresh=1.0, verbose=False, plot_result=True, plot_log_scale=False, @@ -50,6 +53,13 @@ def fit_amorphous_ring( Fit the elliptic parameters to all images maxfev: int Max number of fitting evaluations for curve_fit. + robust: bool + Set to True to use robust fitting. + robust_steps: int + Number of robust fitting steps. + robust_thresh: float + Threshold for relative errors for outlier detection. Setting to 1.0 means all points beyond + one standard deviation of the median error will be excluded from the next fit. verbose: bool Print fit results plot_result: bool @@ -206,6 +216,38 @@ def fit_amorphous_ring( maxfev=maxfev, )[0] coefs[4] = np.mod(coefs[4], 2 * np.pi) + + if robust: + for a0 in range(robust_steps): + # find outliers + int_fit = amorphous_model(basis, *coefs) + int_diff = vals / int_mean - int_fit + int_diff /= np.median(np.abs(int_diff)) + sub_fit = int_diff**2 < robust_thresh**2 + + # redo fits excluding the outliers + if maxfev is None: + coefs = curve_fit( + amorphous_model, + basis[:, sub_fit], + vals[sub_fit] / int_mean, + p0=coefs, + xtol=1e-8, + bounds=(lb, ub), + )[0] + else: + coefs = curve_fit( + amorphous_model, + basis[:, sub_fit], + vals[sub_fit] / int_mean, + p0=coefs, + xtol=1e-8, + bounds=(lb, ub), + maxfev=maxfev, + )[0] + coefs[4] = np.mod(coefs[4], 2 * np.pi) + + # Scale intensity coefficients coefs[5:8] *= int_mean # Perform the fit on each individual diffration pattern diff --git a/py4DSTEM/process/polar/polar_peaks.py b/py4DSTEM/process/polar/polar_peaks.py index cb89c58ad..fdd0a6a2e 100644 --- a/py4DSTEM/process/polar/polar_peaks.py +++ b/py4DSTEM/process/polar/polar_peaks.py @@ -3,7 +3,6 @@ from scipy.ndimage import gaussian_filter, gaussian_filter1d from scipy.signal import peak_prominences -from skimage.feature import peak_local_max from scipy.optimize import curve_fit, leastsq import warnings @@ -114,6 +113,8 @@ def find_peaks_single_pattern( """ + from skimage.feature import peak_local_max + # if needed, generate mask from Bragg peaks if bragg_peaks is not None: mask_bragg = self._datacube.get_braggmask( @@ -705,11 +706,28 @@ def plot_radial_peaks( qstep=None, label_y_axis=False, figsize=(8, 4), + v_lines=None, returnfig=False, ): """ Calculate and plot the total peak signal as a function of the radial coordinate. + q_pixel_units + If True, plot in reciprocal units instead of pixels. + qmin + The minimum q for plotting. + qmax + The maximum q for plotting. + qstep + The bin width. + label_y_axis + If True, label y axis. + figsize + Plot size. + v_lines: tuple + x coordinates for plotting vertical lines. + returnfig + If True, returns figure. """ # Get all peak data @@ -795,6 +813,15 @@ def plot_radial_peaks( # bottom=True, # ) + if v_lines is not None: + y_min, y_max = ax.get_ylim() + + if np.isscalar(v_lines): + ax.vlines(v_lines, y_min, y_max, color="g") + else: + for a0 in range(len(v_lines)): + ax.vlines(v_lines[a0], y_min, y_max, color="g") + if returnfig: return fig, ax diff --git a/py4DSTEM/process/utils/single_atom_scatter.py b/py4DSTEM/process/utils/single_atom_scatter.py index 54443b68f..90397560f 100644 --- a/py4DSTEM/process/utils/single_atom_scatter.py +++ b/py4DSTEM/process/utils/single_atom_scatter.py @@ -57,8 +57,6 @@ def projected_potential(self, Z, R): me = 9.10938356e-31 # Electron charge in Coulomb qe = 1.60217662e-19 - # Electron charge in V-Angstroms - # qe = 14.4 # Permittivity of vacuum eps_0 = 8.85418782e-12 # Bohr's constant @@ -66,21 +64,16 @@ def projected_potential(self, Z, R): fe = np.zeros_like(R) for i in range(5): + # fe += ai[i] * (2 + bi[i] * gsq) / (1 + bi[i] * gsq) ** 2 pre = 2 * np.pi / bi[i] ** 0.5 fe += (ai[i] / bi[i] ** 1.5) * (kn(0, pre * R) + R * kn(1, pre * R)) - # Scale output units - # kappa = (4*np.pi*eps_0) / (2*np.pi*a_0*qe) - # fe *= 2*np.pi**2 / kappa - # # # kappa = (4*np.pi*eps_0) / (2*np.pi*a_0*me) - - # # kappa = (4*np.pi*eps_0) / (2*np.pi*a_0*me) - # # return fe * 2 * np.pi**2 # / kappa - # # if units == "VA": - # return h**2 / (2 * np.pi * me * qe) * 1e18 * fe - # # elif units == "A": - # # return fe * 2 * np.pi**2 / kappa - return fe + # kappa = (4*np.pi*eps_0) / (2*np.pi*a_0*me) + return fe * 2 * np.pi**2 # / kappa + # if units == "VA": + # return h**2 / (2 * np.pi * me * qe) * 1e18 * fe + # elif units == "A": + # return fe * 2 * np.pi**2 / kappa def get_scattering_factor( self, elements=None, composition=None, q_coords=None, units=None diff --git a/py4DSTEM/utils/configuration_checker.py b/py4DSTEM/utils/configuration_checker.py index b50a21de2..2d292618e 100644 --- a/py4DSTEM/utils/configuration_checker.py +++ b/py4DSTEM/utils/configuration_checker.py @@ -5,13 +5,19 @@ import re from importlib.util import find_spec +from py4DSTEM import is_package_lite + # need a mapping of pypi/conda names to import names -import_mapping_dict = { - "scikit-image": "skimage", - "scikit-learn": "sklearn", - "scikit-optimize": "skopt", - "mp-api": "mp_api", -} +import_mapping_dict = ( + {} + if is_package_lite + else { + "scikit-image": "skimage", + "scikit-learn": "sklearn", + "scikit-optimize": "skopt", + "mp-api": "mp_api", + } +) # programatically get all possible requirements in the import name style @@ -88,7 +94,8 @@ def get_modules_dict(): # module_depenencies = get_modules_dict() -modules = get_modules_list() +# modules = get_modules_list() +modules = [] if is_package_lite else get_modules_list() #### Class and Functions to Create Coloured Strings #### @@ -526,9 +533,13 @@ def print_no_extra_checks(m: str): # dict of extra check functions -funcs_dict = { - "cupy": check_cupy_gpu, -} +funcs_dict = ( + {} + if is_package_lite + else { + "cupy": check_cupy_gpu, + } +) #### main function used to check the configuration of the installation diff --git a/py4DSTEM/version.py b/py4DSTEM/version.py index bbe70be03..b9decac47 100644 --- a/py4DSTEM/version.py +++ b/py4DSTEM/version.py @@ -1 +1 @@ -__version__ = "0.14.15" +__version__ = "0.14.17" diff --git a/setup.py b/setup.py index 3a853cc9d..28b16692a 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,7 @@ author="Benjamin H. Savitzky", author_email="ben.savitzky@gmail.com", license="GNU GPLv3", - keywords="STEM 4DSTEM", + keywords="STEM,4DSTEM", python_requires=">=3.10", install_requires=[ "numpy >= 1.19",