From 3905e8d12c633290e9d552116ca24880cf1175c2 Mon Sep 17 00:00:00 2001 From: Colin Date: Thu, 7 Mar 2024 21:37:10 -0800 Subject: [PATCH 01/74] Adding robust fitting to ACOM strain mapping --- py4DSTEM/process/diffraction/crystal_ACOM.py | 102 ++++++++++++++----- 1 file changed, 78 insertions(+), 24 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index 83dff29e1..f4a309aa8 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -2025,6 +2025,9 @@ def calculate_strain( tol_intensity: float = 1e-4, k_max: Optional[float] = None, min_num_peaks=5, + intensity_weighting = False, + robust = True, + robust_thresh = 3.0, rotation_range=None, mask_from_corr=True, corr_range=(0, 2), @@ -2039,24 +2042,46 @@ def calculate_strain( TODO: add robust fitting? - Args: - bragg_peaks_array (PointListArray): All Bragg peaks - orientation_map (OrientationMap): Orientation map generated from ACOM - corr_kernel_size (float): Correlation kernel size - if user does - not specify, uses self.corr_kernel_size. - sigma_excitation_error (float): sigma value for envelope applied to s_g (excitation errors) in units of inverse Angstroms - tol_excitation_error_mult (float): tolerance in units of sigma for s_g inclusion - tol_intensity (np float): tolerance in intensity units for inclusion of diffraction spots - k_max (float): Maximum scattering vector - min_num_peaks (int): Minimum number of peaks required. - rotation_range (float): Maximum rotation range in radians (for symmetry reduction). - progress_bar (bool): Show progress bar - mask_from_corr (bool): Use ACOM correlation signal for mask - corr_range (np.ndarray): Range of correlation signals for mask - corr_normalize (bool): Normalize correlation signal before masking + Parameters + ---------- + bragg_peaks_array (PointListArray): + All Bragg peaks + orientation_map (OrientationMap): + Orientation map generated from ACOM + corr_kernel_size (float): + Correlation kernel size - if user does + not specify, uses self.corr_kernel_size. + sigma_excitation_error (float): + sigma value for envelope applied to s_g (excitation errors) in units of inverse Angstroms + tol_excitation_error_mult (float): + tolerance in units of sigma for s_g inclusion + tol_intensity (np float): + tolerance in intensity units for inclusion of diffraction spots + k_max (float): + Maximum scattering vector + min_num_peaks (int): + Minimum number of peaks required. + intensity_weighting: bool + Set to True to weight least squares by experimental peak intensity. + robust_fitting: bool + Set to True to use robust fitting, which performs outlier rejection. + robust_thresh: float + Threshold for robust fitting weights. + rotation_range (float): + Maximum rotation range in radians (for symmetry reduction). + progress_bar (bool): + Show progress bar + mask_from_corr (bool): + Use ACOM correlation signal for mask + corr_range (np.ndarray): + Range of correlation signals for mask + corr_normalize (bool): + Normalize correlation signal before masking - Returns: - strain_map (RealSlice): strain tensor + Returns + -------- + strain_map (RealSlice): + strain tensor """ @@ -2098,6 +2123,8 @@ def calculate_strain( # Loop over all probe positions for rx, ry in tqdmnd( + # range(220,221), + # range(40,41), *bragg_peaks_array.shape, desc="Calculating strains", unit=" PointList", @@ -2143,16 +2170,43 @@ def calculate_strain( (p_ref.data["qx"][inds_match[keep]], p_ref.data["qy"][inds_match[keep]]) ).T - # Apply intensity weighting from experimental measurements - qxy *= p.data["intensity"][keep, None] - qxy_ref *= p.data["intensity"][keep, None] # Fit transformation matrix # Note - not sure about transpose here # (though it might not matter if rotation isn't included) - m = lstsq(qxy_ref, qxy, rcond=None)[0].T - - # Get the infinitesimal strain matrix + if intensity_weighting: + weights = np.sqrt(p.data["intensity"][keep, None])*0+1 + m = lstsq( + qxy_ref * weights, + qxy * weights, + rcond=None, + )[0].T + else: + m = lstsq( + qxy_ref, + qxy, + rcond=None, + )[0].T + + # Robust fitting + if robust: + for a0 in range(5): + # calculate new weights + qxy_fit = qxy_ref @ m + diff2 = np.sum((qxy_fit - qxy)**2,axis=1) + + weights = np.exp(diff2 / ((-2*robust_thresh**2)*np.median(diff2)))[:,None] + if intensity_weighting: + weights *= np.sqrt(p.data["intensity"][keep, None]) + + # calculate new fits + m = lstsq( + qxy_ref * weights, + qxy * weights, + rcond=None, + )[0].T + + # Set values into the infinitesimal strain matrix strain_map.get_slice("e_xx").data[rx, ry] = 1 - m[0, 0] strain_map.get_slice("e_yy").data[rx, ry] = 1 - m[1, 1] strain_map.get_slice("e_xy").data[rx, ry] = -(m[0, 1] + m[1, 0]) / 2.0 @@ -2160,7 +2214,7 @@ def calculate_strain( # Add finite rotation from ACOM orientation map. # I am not sure about the relative signs here. - # Also, I need to add in the mirror operator. + # Also, maybe I need to add in the mirror operator? if orientation_map.mirror[rx, ry, 0]: strain_map.get_slice("theta").data[rx, ry] += ( orientation_map.angles[rx, ry, 0, 0] From 771ab1f5626c2c8cdba8573050f7379d8c3a06af Mon Sep 17 00:00:00 2001 From: Colin Date: Wed, 13 Mar 2024 13:12:48 -0700 Subject: [PATCH 02/74] Updating matching --- py4DSTEM/process/diffraction/crystal_calibrate.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_calibrate.py b/py4DSTEM/process/diffraction/crystal_calibrate.py index c068bf79e..b15015c62 100644 --- a/py4DSTEM/process/diffraction/crystal_calibrate.py +++ b/py4DSTEM/process/diffraction/crystal_calibrate.py @@ -21,7 +21,7 @@ def calibrate_pixel_size( k_max=None, k_step=0.002, k_broadening=0.002, - fit_all_intensities=True, + fit_all_intensities=False, set_calibration_in_place=False, verbose=True, plot_result=False, @@ -50,7 +50,7 @@ def calibrate_pixel_size( k_broadening (float): Initial guess for Gaussian broadening of simulated pattern (Å^-1) fit_all_intensities (bool): Set to true to allow all peak intensities to - change independently False forces a single intensity scaling. + change independently. False forces a single intensity scaling for all peaks. set_calibration (bool): if True, set the fit pixel size to the calibration metadata, and calibrate bragg_peaks verbose (bool): Output the calibrated pixel size. @@ -138,7 +138,7 @@ def fit_profile(k, *coefs): if returnfig: fig, ax = self.plot_scattering_intensity( - bragg_peaks=bragg_peaks, + bragg_peaks=bragg_peaks_cali, figsize=figsize, k_broadening=k_broadening, int_power_scale=1.0, @@ -151,7 +151,7 @@ def fit_profile(k, *coefs): ) else: self.plot_scattering_intensity( - bragg_peaks=bragg_peaks, + bragg_peaks=bragg_peaks_cali, figsize=figsize, k_broadening=k_broadening, int_power_scale=1.0, From 938c965535c044cb2adc9e111753137a58417d68 Mon Sep 17 00:00:00 2001 From: Colin Date: Wed, 20 Mar 2024 15:07:40 -0700 Subject: [PATCH 03/74] Fixing thickness projection in 2D potentials --- py4DSTEM/process/diffraction/crystal.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index 8b59c3c92..85f8160b7 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -1036,6 +1036,7 @@ def generate_projected_potential( """ + # Determine image size in Angstroms im_size = np.array(im_size) im_size_Ang = im_size * pixel_size_angstroms @@ -1123,11 +1124,15 @@ def generate_projected_potential( # Thickness if thickness_angstroms > 0: - thickness_proj = thickness_angstroms / m_proj[2] - vec_proj = thickness_proj / pixel_size_angstroms * m_proj[:2] - num_proj = (np.ceil(np.linalg.norm(vec_proj))+1).astype('int') - x_proj = np.linspace(-0.5,0.5,num_proj)*vec_proj[0] - y_proj = np.linspace(-0.5,0.5,num_proj)*vec_proj[1] + num_proj = np.round(thickness_angstroms / np.abs(m_proj[2])).astype('int') + if num_proj > 1: + vec_proj = m_proj[:2] / pixel_size_angstroms + shifts = np.arange(num_proj).astype('float') + shifts -= np.mean(shifts) + x_proj = shifts * vec_proj[0] + y_proj = shifts * vec_proj[1] + else: + num_proj = 1 # initialize potential im_potential = np.zeros(im_size) @@ -1136,7 +1141,7 @@ def generate_projected_potential( for a0 in range(atoms_ID_all.shape[0]): ind = np.argmin(np.abs(atoms_ID - atoms_ID_all[a0])) - if thickness_angstroms > 0: + if num_proj > 1: for a1 in range(num_proj): x_ind = np.round(x[a0]+x_proj[a1]).astype('int') + R_ind y_ind = np.round(y[a0]+y_proj[a1]).astype('int') + R_ind From 6938fe375a60f4e8c99ddee550120673cba433fb Mon Sep 17 00:00:00 2001 From: Colin Date: Wed, 20 Mar 2024 15:08:23 -0700 Subject: [PATCH 04/74] black --- py4DSTEM/process/diffraction/crystal.py | 168 +++++++++--------- py4DSTEM/process/diffraction/crystal_ACOM.py | 55 +++--- py4DSTEM/process/utils/single_atom_scatter.py | 10 +- 3 files changed, 118 insertions(+), 115 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index 85f8160b7..dc50be154 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -968,19 +968,16 @@ def generate_ring_pattern( if return_calc is True: return radii_unique, intensity_unique - - - def generate_projected_potential( self, - im_size = (256,256), - pixel_size_angstroms = 0.1, - potential_radius_angstroms = 3.0, - sigma_image_blur_angstroms = 0.1, - thickness_angstroms = 100, - power_scale = 1.0, - plot_result = False, - figsize = (6,6), + im_size=(256, 256), + pixel_size_angstroms=0.1, + potential_radius_angstroms=3.0, + sigma_image_blur_angstroms=0.1, + thickness_angstroms=100, + power_scale=1.0, + plot_result=False, + figsize=(6, 6), orientation: Optional[Orientation] = None, ind_orientation: Optional[int] = 0, orientation_matrix: Optional[np.ndarray] = None, @@ -1004,7 +1001,7 @@ def generate_projected_potential( sigma_image_blur_angstroms: float Image blurring in Angstroms. thickness_angstroms: float - Thickness of the sample in Angstroms. + Thickness of the sample in Angstroms. Set thickness_thickness_angstroms = 0 to skip thickness projection. power_scale: float Power law scaling of potentials. Set to 2.0 to approximate Z^2 images. @@ -1012,31 +1009,30 @@ def generate_projected_potential( Plot the projected potential image. figsize: (2,) vector giving the size of the output. - - orientation: Orientation + + orientation: Orientation An Orientation class object ind_orientation: int If input is an Orientation class object with multiple orientations, this input can be used to select a specific orientation. - orientation_matrix: array + orientation_matrix: array (3,3) orientation matrix, where columns represent projection directions. - zone_axis_lattice: array + zone_axis_lattice: array (3,) projection direction in lattice indices - proj_x_lattice: array) + proj_x_lattice: array) (3,) x-axis direction in lattice indices - zone_axis_cartesian: array + zone_axis_cartesian: array (3,) cartesian projection direction - proj_x_cartesian: array + proj_x_cartesian: array (3,) cartesian projection direction Returns -------- - im_potential: (np.array) + im_potential: (np.array) Output image of the projected potential. """ - # Determine image size in Angstroms im_size = np.array(im_size) im_size_Ang = im_size * pixel_size_angstroms @@ -1056,78 +1052,78 @@ def generate_projected_potential( lat_real = self.lat_real.copy() @ orientation_matrix # Determine unit cell axes to tile over, by selecting 2/3 with largest in-plane component - inds_tile = np.argsort( - np.linalg.norm(lat_real[:,0:2],axis=1) - )[1:3] - m_tile = lat_real[inds_tile,:] + inds_tile = np.argsort(np.linalg.norm(lat_real[:, 0:2], axis=1))[1:3] + m_tile = lat_real[inds_tile, :] # Vector projected along optic axis - m_proj = np.squeeze(np.delete(lat_real,inds_tile,axis=0)) + m_proj = np.squeeze(np.delete(lat_real, inds_tile, axis=0)) # Determine tiling range - p_corners = np.array([ - [-im_size_Ang[0]*0.5,-im_size_Ang[1]*0.5, 0.0], - [ im_size_Ang[0]*0.5,-im_size_Ang[1]*0.5, 0.0], - [-im_size_Ang[0]*0.5, im_size_Ang[1]*0.5, 0.0], - [ im_size_Ang[0]*0.5, im_size_Ang[1]*0.5, 0.0], - ]) - ab = np.linalg.lstsq( - m_tile[:,:2].T, - p_corners[:,:2].T, - rcond=None - )[0] + p_corners = np.array( + [ + [-im_size_Ang[0] * 0.5, -im_size_Ang[1] * 0.5, 0.0], + [im_size_Ang[0] * 0.5, -im_size_Ang[1] * 0.5, 0.0], + [-im_size_Ang[0] * 0.5, im_size_Ang[1] * 0.5, 0.0], + [im_size_Ang[0] * 0.5, im_size_Ang[1] * 0.5, 0.0], + ] + ) + ab = np.linalg.lstsq(m_tile[:, :2].T, p_corners[:, :2].T, rcond=None)[0] ab = np.floor(ab) - a_range = np.array((np.min(ab[0])-1,np.max(ab[0])+2)) - b_range = np.array((np.min(ab[1])-1,np.max(ab[1])+2)) + a_range = np.array((np.min(ab[0]) - 1, np.max(ab[0]) + 2)) + b_range = np.array((np.min(ab[1]) - 1, np.max(ab[1]) + 2)) # Tile unit cell a_ind, b_ind, atoms_ind = np.meshgrid( - np.arange(a_range[0],a_range[1]), - np.arange(b_range[0],b_range[1]), + np.arange(a_range[0], a_range[1]), + np.arange(b_range[0], b_range[1]), np.arange(self.positions.shape[0]), ) - abc_atoms = self.positions[atoms_ind.ravel(),:] - abc_atoms[:,inds_tile[0]] += a_ind.ravel() - abc_atoms[:,inds_tile[1]] += b_ind.ravel() + abc_atoms = self.positions[atoms_ind.ravel(), :] + abc_atoms[:, inds_tile[0]] += a_ind.ravel() + abc_atoms[:, inds_tile[1]] += b_ind.ravel() xyz_atoms_ang = abc_atoms @ lat_real atoms_ID_all = self.numbers[atoms_ind.ravel()] # Center atoms on image plane - x = xyz_atoms_ang[:,0] / pixel_size_angstroms + im_size[0]/2.0 - y = xyz_atoms_ang[:,1] / pixel_size_angstroms + im_size[1]/2.0 - atoms_del = np.logical_or.reduce(( - x <= -potential_radius_angstroms/2, - y <= -potential_radius_angstroms/2, - x >= im_size[0] + potential_radius_angstroms/2, - y >= im_size[1] + potential_radius_angstroms/2, - )) + x = xyz_atoms_ang[:, 0] / pixel_size_angstroms + im_size[0] / 2.0 + y = xyz_atoms_ang[:, 1] / pixel_size_angstroms + im_size[1] / 2.0 + atoms_del = np.logical_or.reduce( + ( + x <= -potential_radius_angstroms / 2, + y <= -potential_radius_angstroms / 2, + x >= im_size[0] + potential_radius_angstroms / 2, + y >= im_size[1] + potential_radius_angstroms / 2, + ) + ) x = np.delete(x, atoms_del) y = np.delete(y, atoms_del) atoms_ID_all = np.delete(atoms_ID_all, atoms_del) # Coordinate system for atomic projected potentials potential_radius = np.ceil(potential_radius_angstroms / pixel_size_angstroms) - R = np.arange(0.5-potential_radius,potential_radius+0.5) - R_ind = R.astype('int') - R_2D = np.sqrt(R[:,None]**2 + R[None,:]**2) + R = np.arange(0.5 - potential_radius, potential_radius + 0.5) + R_ind = R.astype("int") + R_2D = np.sqrt(R[:, None] ** 2 + R[None, :] ** 2) # Lookup table for atomic projected potentials atoms_ID = np.unique(self.numbers) - atoms_lookup = np.zeros(( - atoms_ID.shape[0], - R_2D.shape[0], - R_2D.shape[1], - )) + atoms_lookup = np.zeros( + ( + atoms_ID.shape[0], + R_2D.shape[0], + R_2D.shape[1], + ) + ) for a0 in range(atoms_ID.shape[0]): atom_sf = single_atom_scatter([atoms_ID[a0]]) - atoms_lookup[a0,:,:] = atom_sf.projected_potential(atoms_ID[a0], R_2D) + atoms_lookup[a0, :, :] = atom_sf.projected_potential(atoms_ID[a0], R_2D) atoms_lookup **= power_scale - # Thickness + # Thickness if thickness_angstroms > 0: - num_proj = np.round(thickness_angstroms / np.abs(m_proj[2])).astype('int') + num_proj = np.round(thickness_angstroms / np.abs(m_proj[2])).astype("int") if num_proj > 1: vec_proj = m_proj[:2] / pixel_size_angstroms - shifts = np.arange(num_proj).astype('float') + shifts = np.arange(num_proj).astype("float") shifts -= np.mean(shifts) x_proj = shifts * vec_proj[0] y_proj = shifts * vec_proj[1] @@ -1143,8 +1139,8 @@ def generate_projected_potential( if num_proj > 1: for a1 in range(num_proj): - x_ind = np.round(x[a0]+x_proj[a1]).astype('int') + R_ind - y_ind = np.round(y[a0]+y_proj[a1]).astype('int') + R_ind + x_ind = np.round(x[a0] + x_proj[a1]).astype("int") + R_ind + y_ind = np.round(y[a0] + y_proj[a1]).astype("int") + R_ind x_sub = np.logical_and( x_ind >= 0, x_ind < im_size[0], @@ -1154,11 +1150,13 @@ def generate_projected_potential( y_ind < im_size[1], ) - im_potential[x_ind[x_sub][:,None],y_ind[y_sub][None,:]] += atoms_lookup[ind][x_sub,:][:,y_sub] - + im_potential[ + x_ind[x_sub][:, None], y_ind[y_sub][None, :] + ] += atoms_lookup[ind][x_sub, :][:, y_sub] + else: - x_ind = np.round(x[a0]).astype('int') + R_ind - y_ind = np.round(y[a0]).astype('int') + R_ind + x_ind = np.round(x[a0]).astype("int") + R_ind + y_ind = np.round(y[a0]).astype("int") + R_ind x_sub = np.logical_and( x_ind >= 0, x_ind < im_size[0], @@ -1168,7 +1166,9 @@ def generate_projected_potential( y_ind < im_size[1], ) - im_potential[x_ind[x_sub][:,None],y_ind[y_sub][None,:]] += atoms_lookup[ind][x_sub,:][:,y_sub] + im_potential[ + x_ind[x_sub][:, None], y_ind[y_sub][None, :] + ] += atoms_lookup[ind][x_sub, :][:, y_sub] if thickness_angstroms > 0: im_potential /= num_proj @@ -1179,26 +1179,28 @@ def generate_projected_potential( im_potential = gaussian_filter( im_potential, sigma_image_blur, - mode = 'nearest', - ) + mode="nearest", + ) if plot_result: # quick plotting of the result int_vals = np.sort(im_potential.ravel()) - int_range = np.array(( - int_vals[np.round(0.02*int_vals.size).astype('int')], - int_vals[np.round(0.98*int_vals.size).astype('int')], - )) + int_range = np.array( + ( + int_vals[np.round(0.02 * int_vals.size).astype("int")], + int_vals[np.round(0.98 * int_vals.size).astype("int")], + ) + ) - fig,ax = plt.subplots(figsize = figsize) + fig, ax = plt.subplots(figsize=figsize) ax.imshow( im_potential, - cmap = 'turbo', - vmin = int_range[0], - vmax = int_range[1], - ) + cmap="turbo", + vmin=int_range[0], + vmax=int_range[1], + ) ax.set_axis_off() - ax.set_aspect('equal') + ax.set_aspect("equal") return im_potential diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index f4a309aa8..a0f8468ed 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -2025,9 +2025,9 @@ def calculate_strain( tol_intensity: float = 1e-4, k_max: Optional[float] = None, min_num_peaks=5, - intensity_weighting = False, - robust = True, - robust_thresh = 3.0, + intensity_weighting=False, + robust=True, + robust_thresh=3.0, rotation_range=None, mask_from_corr=True, corr_range=(0, 2), @@ -2044,22 +2044,22 @@ def calculate_strain( Parameters ---------- - bragg_peaks_array (PointListArray): + bragg_peaks_array (PointListArray): All Bragg peaks - orientation_map (OrientationMap): + orientation_map (OrientationMap): Orientation map generated from ACOM - corr_kernel_size (float): + corr_kernel_size (float): Correlation kernel size - if user does not specify, uses self.corr_kernel_size. - sigma_excitation_error (float): + sigma_excitation_error (float): sigma value for envelope applied to s_g (excitation errors) in units of inverse Angstroms - tol_excitation_error_mult (float): + tol_excitation_error_mult (float): tolerance in units of sigma for s_g inclusion - tol_intensity (np float): + tol_intensity (np float): tolerance in intensity units for inclusion of diffraction spots - k_max (float): + k_max (float): Maximum scattering vector - min_num_peaks (int): + min_num_peaks (int): Minimum number of peaks required. intensity_weighting: bool Set to True to weight least squares by experimental peak intensity. @@ -2067,20 +2067,20 @@ def calculate_strain( Set to True to use robust fitting, which performs outlier rejection. robust_thresh: float Threshold for robust fitting weights. - rotation_range (float): + rotation_range (float): Maximum rotation range in radians (for symmetry reduction). - progress_bar (bool): + progress_bar (bool): Show progress bar - mask_from_corr (bool): + mask_from_corr (bool): Use ACOM correlation signal for mask - corr_range (np.ndarray): + corr_range (np.ndarray): Range of correlation signals for mask - corr_normalize (bool): + corr_normalize (bool): Normalize correlation signal before masking Returns -------- - strain_map (RealSlice): + strain_map (RealSlice): strain tensor """ @@ -2170,21 +2170,20 @@ def calculate_strain( (p_ref.data["qx"][inds_match[keep]], p_ref.data["qy"][inds_match[keep]]) ).T - # Fit transformation matrix # Note - not sure about transpose here # (though it might not matter if rotation isn't included) if intensity_weighting: - weights = np.sqrt(p.data["intensity"][keep, None])*0+1 + weights = np.sqrt(p.data["intensity"][keep, None]) * 0 + 1 m = lstsq( - qxy_ref * weights, - qxy * weights, + qxy_ref * weights, + qxy * weights, rcond=None, )[0].T else: m = lstsq( - qxy_ref, - qxy, + qxy_ref, + qxy, rcond=None, )[0].T @@ -2193,16 +2192,18 @@ def calculate_strain( for a0 in range(5): # calculate new weights qxy_fit = qxy_ref @ m - diff2 = np.sum((qxy_fit - qxy)**2,axis=1) + diff2 = np.sum((qxy_fit - qxy) ** 2, axis=1) - weights = np.exp(diff2 / ((-2*robust_thresh**2)*np.median(diff2)))[:,None] + weights = np.exp( + diff2 / ((-2 * robust_thresh**2) * np.median(diff2)) + )[:, None] if intensity_weighting: weights *= np.sqrt(p.data["intensity"][keep, None]) # calculate new fits m = lstsq( - qxy_ref * weights, - qxy * weights, + qxy_ref * weights, + qxy * weights, rcond=None, )[0].T diff --git a/py4DSTEM/process/utils/single_atom_scatter.py b/py4DSTEM/process/utils/single_atom_scatter.py index 9085afd93..90397560f 100644 --- a/py4DSTEM/process/utils/single_atom_scatter.py +++ b/py4DSTEM/process/utils/single_atom_scatter.py @@ -2,6 +2,7 @@ import os from scipy.special import kn + class single_atom_scatter(object): """ This class calculates the composition averaged single atom scattering factor for a @@ -58,18 +59,17 @@ def projected_potential(self, Z, R): qe = 1.60217662e-19 # Permittivity of vacuum eps_0 = 8.85418782e-12 - # Bohr's constant + # Bohr's constant a_0 = 5.29177210903e-11 fe = np.zeros_like(R) for i in range(5): # fe += ai[i] * (2 + bi[i] * gsq) / (1 + bi[i] * gsq) ** 2 - pre = 2*np.pi/bi[i]**0.5 - fe += (ai[i] / bi[i]**1.5) * \ - (kn(0, pre * R) + R * kn(1, pre * R)) + pre = 2 * np.pi / bi[i] ** 0.5 + fe += (ai[i] / bi[i] ** 1.5) * (kn(0, pre * R) + R * kn(1, pre * R)) # kappa = (4*np.pi*eps_0) / (2*np.pi*a_0*me) - return fe * 2 * np.pi**2# / kappa + return fe * 2 * np.pi**2 # / kappa # if units == "VA": # return h**2 / (2 * np.pi * me * qe) * 1e18 * fe # elif units == "A": From 26508318ebc2b39610bab24f56db8767f7167203 Mon Sep 17 00:00:00 2001 From: Colin Date: Thu, 21 Mar 2024 11:27:45 -0700 Subject: [PATCH 05/74] Remove testing lines. --- py4DSTEM/process/diffraction/crystal_ACOM.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index a0f8468ed..c2bac6424 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -2123,8 +2123,6 @@ def calculate_strain( # Loop over all probe positions for rx, ry in tqdmnd( - # range(220,221), - # range(40,41), *bragg_peaks_array.shape, desc="Calculating strains", unit=" PointList", From 1e700803684514b5bb8589b32600016812b8d0af Mon Sep 17 00:00:00 2001 From: Colin Date: Thu, 21 Mar 2024 13:53:44 -0700 Subject: [PATCH 06/74] Trying (and failing) to figure out the potential units --- py4DSTEM/process/diffraction/crystal.py | 3 +++ py4DSTEM/process/utils/single_atom_scatter.py | 21 ++++++++++++------- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index dc50be154..97d596ec4 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -989,6 +989,9 @@ def generate_projected_potential( """ Generate an image of the projected potential of crystal in real space, using cell tiling, and a lookup table of the atomic potentials. + Note that we round atomic positions to the nearest pixel for speed. + + TODO - fix scattering prefactor so that output units are sensible. Parameters ---------- diff --git a/py4DSTEM/process/utils/single_atom_scatter.py b/py4DSTEM/process/utils/single_atom_scatter.py index 90397560f..54443b68f 100644 --- a/py4DSTEM/process/utils/single_atom_scatter.py +++ b/py4DSTEM/process/utils/single_atom_scatter.py @@ -57,6 +57,8 @@ def projected_potential(self, Z, R): me = 9.10938356e-31 # Electron charge in Coulomb qe = 1.60217662e-19 + # Electron charge in V-Angstroms + # qe = 14.4 # Permittivity of vacuum eps_0 = 8.85418782e-12 # Bohr's constant @@ -64,16 +66,21 @@ def projected_potential(self, Z, R): fe = np.zeros_like(R) for i in range(5): - # fe += ai[i] * (2 + bi[i] * gsq) / (1 + bi[i] * gsq) ** 2 pre = 2 * np.pi / bi[i] ** 0.5 fe += (ai[i] / bi[i] ** 1.5) * (kn(0, pre * R) + R * kn(1, pre * R)) - # kappa = (4*np.pi*eps_0) / (2*np.pi*a_0*me) - return fe * 2 * np.pi**2 # / kappa - # if units == "VA": - # return h**2 / (2 * np.pi * me * qe) * 1e18 * fe - # elif units == "A": - # return fe * 2 * np.pi**2 / kappa + # Scale output units + # kappa = (4*np.pi*eps_0) / (2*np.pi*a_0*qe) + # fe *= 2*np.pi**2 / kappa + # # # kappa = (4*np.pi*eps_0) / (2*np.pi*a_0*me) + + # # kappa = (4*np.pi*eps_0) / (2*np.pi*a_0*me) + # # return fe * 2 * np.pi**2 # / kappa + # # if units == "VA": + # return h**2 / (2 * np.pi * me * qe) * 1e18 * fe + # # elif units == "A": + # # return fe * 2 * np.pi**2 / kappa + return fe def get_scattering_factor( self, elements=None, composition=None, q_coords=None, units=None From c8a81f557d794f0083d9c09bcd8d3d0def52c74f Mon Sep 17 00:00:00 2001 From: Colin Date: Thu, 21 Mar 2024 13:54:16 -0700 Subject: [PATCH 07/74] Black formatting --- .../diffraction/WK_scattering_factors.py | 4 ++- py4DSTEM/process/diffraction/crystal_viz.py | 3 +- py4DSTEM/process/fit/fit.py | 3 +- .../magnetic_ptychographic_tomography.py | 28 +++++++++---------- .../process/phase/magnetic_ptychography.py | 28 +++++++++---------- py4DSTEM/process/phase/parallax.py | 24 ++++++++-------- .../process/phase/ptychographic_methods.py | 4 ++- .../process/phase/ptychographic_tomography.py | 28 +++++++++---------- py4DSTEM/process/phase/utils.py | 4 ++- py4DSTEM/process/utils/utils.py | 18 ++++++++++-- 10 files changed, 82 insertions(+), 62 deletions(-) diff --git a/py4DSTEM/process/diffraction/WK_scattering_factors.py b/py4DSTEM/process/diffraction/WK_scattering_factors.py index eb964de96..70110a977 100644 --- a/py4DSTEM/process/diffraction/WK_scattering_factors.py +++ b/py4DSTEM/process/diffraction/WK_scattering_factors.py @@ -221,7 +221,9 @@ def RI1(BI, BJ, G): ri1[sub] = np.pi * (BI * np.log((BI + BJ) / BI) + BJ * np.log((BI + BJ) / BJ)) sub = np.logical_and(eps <= 0.1, G > 0.0) - temp = 0.5 * BI**2 * np.log(BI / (BI + BJ)) + 0.5 * BJ**2 * np.log(BJ / (BI + BJ)) + temp = 0.5 * BI**2 * np.log(BI / (BI + BJ)) + 0.5 * BJ**2 * np.log( + BJ / (BI + BJ) + ) temp += 0.75 * (BI**2 + BJ**2) - 0.25 * (BI + BJ) ** 2 temp -= 0.5 * (BI - BJ) ** 2 ri1[sub] += np.pi * G[sub] ** 2 * temp diff --git a/py4DSTEM/process/diffraction/crystal_viz.py b/py4DSTEM/process/diffraction/crystal_viz.py index 47df2e6ca..da016b3ed 100644 --- a/py4DSTEM/process/diffraction/crystal_viz.py +++ b/py4DSTEM/process/diffraction/crystal_viz.py @@ -454,7 +454,8 @@ def plot_scattering_intensity( int_sf_plot = calc_1D_profile( k, self.g_vec_leng, - (self.struct_factors_int**int_power_scale) * (self.g_vec_leng**k_power_scale), + (self.struct_factors_int**int_power_scale) + * (self.g_vec_leng**k_power_scale), remove_origin=True, k_broadening=k_broadening, int_scale=int_scale, diff --git a/py4DSTEM/process/fit/fit.py b/py4DSTEM/process/fit/fit.py index 5c2d56a3c..9973ff79f 100644 --- a/py4DSTEM/process/fit/fit.py +++ b/py4DSTEM/process/fit/fit.py @@ -169,7 +169,8 @@ def polar_gaussian_2D( # t2 = np.min(np.vstack([t,1-t])) t2 = np.square(t - mu_t) return ( - I0 * np.exp(-(t2 / (2 * sigma_t**2) + (q - mu_q) ** 2 / (2 * sigma_q**2))) + C + I0 * np.exp(-(t2 / (2 * sigma_t**2) + (q - mu_q) ** 2 / (2 * sigma_q**2))) + + C ) diff --git a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py index c9efae806..0a58a514b 100644 --- a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py +++ b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py @@ -1188,20 +1188,20 @@ def reconstruct( # position correction if not fix_positions and a0 > 0: - self._positions_px_all[batch_indices] = ( - self._position_correction( - object_sliced, - vectorized_patch_indices_row, - vectorized_patch_indices_col, - shifted_probes, - overlap, - amplitudes_device, - positions_px, - positions_px_initial, - positions_step_size, - max_position_update_distance, - max_position_total_distance, - ) + self._positions_px_all[ + batch_indices + ] = self._position_correction( + object_sliced, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + overlap, + amplitudes_device, + positions_px, + positions_px_initial, + positions_step_size, + max_position_update_distance, + max_position_total_distance, ) measurement_error += batch_error diff --git a/py4DSTEM/process/phase/magnetic_ptychography.py b/py4DSTEM/process/phase/magnetic_ptychography.py index 2e887739f..67e315234 100644 --- a/py4DSTEM/process/phase/magnetic_ptychography.py +++ b/py4DSTEM/process/phase/magnetic_ptychography.py @@ -1490,20 +1490,20 @@ def reconstruct( # position correction if not fix_positions and a0 > 0: - self._positions_px_all[batch_indices] = ( - self._position_correction( - self._object, - vectorized_patch_indices_row, - vectorized_patch_indices_col, - shifted_probes, - overlap, - amplitudes_device, - positions_px, - positions_px_initial, - positions_step_size, - max_position_update_distance, - max_position_total_distance, - ) + self._positions_px_all[ + batch_indices + ] = self._position_correction( + self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + overlap, + amplitudes_device, + positions_px, + positions_px_initial, + positions_step_size, + max_position_update_distance, + max_position_total_distance, ) measurement_error += batch_error diff --git a/py4DSTEM/process/phase/parallax.py b/py4DSTEM/process/phase/parallax.py index aa9323900..79392f6e6 100644 --- a/py4DSTEM/process/phase/parallax.py +++ b/py4DSTEM/process/phase/parallax.py @@ -2195,16 +2195,16 @@ def score_CTF(coefs): measured_shifts_sx = xp.zeros( self._region_of_interest_shape, dtype=xp.float32 ) - measured_shifts_sx[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( - self._xy_shifts_Ang[:, 0] - ) + measured_shifts_sx[ + self._xy_inds[:, 0], self._xy_inds[:, 1] + ] = self._xy_shifts_Ang[:, 0] measured_shifts_sy = xp.zeros( self._region_of_interest_shape, dtype=xp.float32 ) - measured_shifts_sy[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( - self._xy_shifts_Ang[:, 1] - ) + measured_shifts_sy[ + self._xy_inds[:, 0], self._xy_inds[:, 1] + ] = self._xy_shifts_Ang[:, 1] fitted_shifts = ( xp.tensordot(gradients, xp.array(self._aberrations_coefs), axes=1) @@ -2215,16 +2215,16 @@ def score_CTF(coefs): fitted_shifts_sx = xp.zeros( self._region_of_interest_shape, dtype=xp.float32 ) - fitted_shifts_sx[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( - fitted_shifts[:, 0] - ) + fitted_shifts_sx[ + self._xy_inds[:, 0], self._xy_inds[:, 1] + ] = fitted_shifts[:, 0] fitted_shifts_sy = xp.zeros( self._region_of_interest_shape, dtype=xp.float32 ) - fitted_shifts_sy[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( - fitted_shifts[:, 1] - ) + fitted_shifts_sy[ + self._xy_inds[:, 0], self._xy_inds[:, 1] + ] = fitted_shifts[:, 1] max_shift = xp.max( xp.array( diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py index fa0b1db9f..cb109073e 100644 --- a/py4DSTEM/process/phase/ptychographic_methods.py +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -351,7 +351,9 @@ def _precompute_propagator_arrays( propagators[i] = xp.exp( 1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz) ) - propagators[i] *= xp.exp(1.0j * (-(ky**2)[None] * np.pi * wavelength * dz)) + propagators[i] *= xp.exp( + 1.0j * (-(ky**2)[None] * np.pi * wavelength * dz) + ) if theta_x is not None: propagators[i] *= xp.exp( diff --git a/py4DSTEM/process/phase/ptychographic_tomography.py b/py4DSTEM/process/phase/ptychographic_tomography.py index f3b2991ab..8b4d4d984 100644 --- a/py4DSTEM/process/phase/ptychographic_tomography.py +++ b/py4DSTEM/process/phase/ptychographic_tomography.py @@ -1080,20 +1080,20 @@ def reconstruct( # position correction if not fix_positions: - self._positions_px_all[batch_indices] = ( - self._position_correction( - object_sliced, - vectorized_patch_indices_row, - vectorized_patch_indices_col, - shifted_probes, - overlap, - amplitudes_device, - positions_px, - positions_px_initial, - positions_step_size, - max_position_update_distance, - max_position_total_distance, - ) + self._positions_px_all[ + batch_indices + ] = self._position_correction( + object_sliced, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + overlap, + amplitudes_device, + positions_px, + positions_px_initial, + positions_step_size, + max_position_update_distance, + max_position_total_distance, ) measurement_error += batch_error diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index a5a541795..f53637184 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -203,7 +203,9 @@ def evaluate_gaussian_envelope( self, alpha: Union[float, np.ndarray] ) -> Union[float, np.ndarray]: xp = self._xp - return xp.exp(-0.5 * self._gaussian_spread**2 * alpha**2 / self._wavelength**2) + return xp.exp( + -0.5 * self._gaussian_spread**2 * alpha**2 / self._wavelength**2 + ) def evaluate_spatial_envelope( self, alpha: Union[float, np.ndarray], phi: Union[float, np.ndarray] diff --git a/py4DSTEM/process/utils/utils.py b/py4DSTEM/process/utils/utils.py index ddeeb2c36..60da616d1 100644 --- a/py4DSTEM/process/utils/utils.py +++ b/py4DSTEM/process/utils/utils.py @@ -93,7 +93,12 @@ def electron_wavelength_angstrom(E_eV): c = 299792458 h = 6.62607 * 10**-34 - lam = h / ma.sqrt(2 * m * e * E_eV) / ma.sqrt(1 + e * E_eV / 2 / m / c**2) * 10**10 + lam = ( + h + / ma.sqrt(2 * m * e * E_eV) + / ma.sqrt(1 + e * E_eV / 2 / m / c**2) + * 10**10 + ) return lam @@ -102,8 +107,15 @@ def electron_interaction_parameter(E_eV): e = 1.602177 * 10**-19 c = 299792458 h = 6.62607 * 10**-34 - lam = h / ma.sqrt(2 * m * e * E_eV) / ma.sqrt(1 + e * E_eV / 2 / m / c**2) * 10**10 - sigma = (2 * np.pi / lam / E_eV) * (m * c**2 + e * E_eV) / (2 * m * c**2 + e * E_eV) + lam = ( + h + / ma.sqrt(2 * m * e * E_eV) + / ma.sqrt(1 + e * E_eV / 2 / m / c**2) + * 10**10 + ) + sigma = ( + (2 * np.pi / lam / E_eV) * (m * c**2 + e * E_eV) / (2 * m * c**2 + e * E_eV) + ) return sigma From e362b6b197cfd75c12da30a70d27e719704a4c0e Mon Sep 17 00:00:00 2001 From: Colin Date: Thu, 21 Mar 2024 14:03:20 -0700 Subject: [PATCH 08/74] Black again --- .../diffraction/WK_scattering_factors.py | 4 +-- py4DSTEM/process/diffraction/crystal_viz.py | 3 +- py4DSTEM/process/fit/fit.py | 3 +- .../magnetic_ptychographic_tomography.py | 28 +++++++++---------- .../process/phase/magnetic_ptychography.py | 28 +++++++++---------- py4DSTEM/process/phase/parallax.py | 24 ++++++++-------- .../process/phase/ptychographic_methods.py | 4 +-- .../process/phase/ptychographic_tomography.py | 28 +++++++++---------- py4DSTEM/process/phase/utils.py | 4 +-- py4DSTEM/process/utils/utils.py | 18 ++---------- 10 files changed, 62 insertions(+), 82 deletions(-) diff --git a/py4DSTEM/process/diffraction/WK_scattering_factors.py b/py4DSTEM/process/diffraction/WK_scattering_factors.py index 70110a977..eb964de96 100644 --- a/py4DSTEM/process/diffraction/WK_scattering_factors.py +++ b/py4DSTEM/process/diffraction/WK_scattering_factors.py @@ -221,9 +221,7 @@ def RI1(BI, BJ, G): ri1[sub] = np.pi * (BI * np.log((BI + BJ) / BI) + BJ * np.log((BI + BJ) / BJ)) sub = np.logical_and(eps <= 0.1, G > 0.0) - temp = 0.5 * BI**2 * np.log(BI / (BI + BJ)) + 0.5 * BJ**2 * np.log( - BJ / (BI + BJ) - ) + temp = 0.5 * BI**2 * np.log(BI / (BI + BJ)) + 0.5 * BJ**2 * np.log(BJ / (BI + BJ)) temp += 0.75 * (BI**2 + BJ**2) - 0.25 * (BI + BJ) ** 2 temp -= 0.5 * (BI - BJ) ** 2 ri1[sub] += np.pi * G[sub] ** 2 * temp diff --git a/py4DSTEM/process/diffraction/crystal_viz.py b/py4DSTEM/process/diffraction/crystal_viz.py index da016b3ed..47df2e6ca 100644 --- a/py4DSTEM/process/diffraction/crystal_viz.py +++ b/py4DSTEM/process/diffraction/crystal_viz.py @@ -454,8 +454,7 @@ def plot_scattering_intensity( int_sf_plot = calc_1D_profile( k, self.g_vec_leng, - (self.struct_factors_int**int_power_scale) - * (self.g_vec_leng**k_power_scale), + (self.struct_factors_int**int_power_scale) * (self.g_vec_leng**k_power_scale), remove_origin=True, k_broadening=k_broadening, int_scale=int_scale, diff --git a/py4DSTEM/process/fit/fit.py b/py4DSTEM/process/fit/fit.py index 9973ff79f..5c2d56a3c 100644 --- a/py4DSTEM/process/fit/fit.py +++ b/py4DSTEM/process/fit/fit.py @@ -169,8 +169,7 @@ def polar_gaussian_2D( # t2 = np.min(np.vstack([t,1-t])) t2 = np.square(t - mu_t) return ( - I0 * np.exp(-(t2 / (2 * sigma_t**2) + (q - mu_q) ** 2 / (2 * sigma_q**2))) - + C + I0 * np.exp(-(t2 / (2 * sigma_t**2) + (q - mu_q) ** 2 / (2 * sigma_q**2))) + C ) diff --git a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py index 0a58a514b..c9efae806 100644 --- a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py +++ b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py @@ -1188,20 +1188,20 @@ def reconstruct( # position correction if not fix_positions and a0 > 0: - self._positions_px_all[ - batch_indices - ] = self._position_correction( - object_sliced, - vectorized_patch_indices_row, - vectorized_patch_indices_col, - shifted_probes, - overlap, - amplitudes_device, - positions_px, - positions_px_initial, - positions_step_size, - max_position_update_distance, - max_position_total_distance, + self._positions_px_all[batch_indices] = ( + self._position_correction( + object_sliced, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + overlap, + amplitudes_device, + positions_px, + positions_px_initial, + positions_step_size, + max_position_update_distance, + max_position_total_distance, + ) ) measurement_error += batch_error diff --git a/py4DSTEM/process/phase/magnetic_ptychography.py b/py4DSTEM/process/phase/magnetic_ptychography.py index 67e315234..2e887739f 100644 --- a/py4DSTEM/process/phase/magnetic_ptychography.py +++ b/py4DSTEM/process/phase/magnetic_ptychography.py @@ -1490,20 +1490,20 @@ def reconstruct( # position correction if not fix_positions and a0 > 0: - self._positions_px_all[ - batch_indices - ] = self._position_correction( - self._object, - vectorized_patch_indices_row, - vectorized_patch_indices_col, - shifted_probes, - overlap, - amplitudes_device, - positions_px, - positions_px_initial, - positions_step_size, - max_position_update_distance, - max_position_total_distance, + self._positions_px_all[batch_indices] = ( + self._position_correction( + self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + overlap, + amplitudes_device, + positions_px, + positions_px_initial, + positions_step_size, + max_position_update_distance, + max_position_total_distance, + ) ) measurement_error += batch_error diff --git a/py4DSTEM/process/phase/parallax.py b/py4DSTEM/process/phase/parallax.py index d366a01df..060a151aa 100644 --- a/py4DSTEM/process/phase/parallax.py +++ b/py4DSTEM/process/phase/parallax.py @@ -2195,16 +2195,16 @@ def score_CTF(coefs): measured_shifts_sx = xp.zeros( self._region_of_interest_shape, dtype=xp.float32 ) - measured_shifts_sx[ - self._xy_inds[:, 0], self._xy_inds[:, 1] - ] = self._xy_shifts_Ang[:, 0] + measured_shifts_sx[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( + self._xy_shifts_Ang[:, 0] + ) measured_shifts_sy = xp.zeros( self._region_of_interest_shape, dtype=xp.float32 ) - measured_shifts_sy[ - self._xy_inds[:, 0], self._xy_inds[:, 1] - ] = self._xy_shifts_Ang[:, 1] + measured_shifts_sy[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( + self._xy_shifts_Ang[:, 1] + ) fitted_shifts = ( xp.tensordot(gradients, xp.array(self._aberrations_coefs), axes=1) @@ -2215,16 +2215,16 @@ def score_CTF(coefs): fitted_shifts_sx = xp.zeros( self._region_of_interest_shape, dtype=xp.float32 ) - fitted_shifts_sx[ - self._xy_inds[:, 0], self._xy_inds[:, 1] - ] = fitted_shifts[:, 0] + fitted_shifts_sx[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( + fitted_shifts[:, 0] + ) fitted_shifts_sy = xp.zeros( self._region_of_interest_shape, dtype=xp.float32 ) - fitted_shifts_sy[ - self._xy_inds[:, 0], self._xy_inds[:, 1] - ] = fitted_shifts[:, 1] + fitted_shifts_sy[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( + fitted_shifts[:, 1] + ) max_shift = xp.max( xp.array( diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py index cb109073e..fa0b1db9f 100644 --- a/py4DSTEM/process/phase/ptychographic_methods.py +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -351,9 +351,7 @@ def _precompute_propagator_arrays( propagators[i] = xp.exp( 1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz) ) - propagators[i] *= xp.exp( - 1.0j * (-(ky**2)[None] * np.pi * wavelength * dz) - ) + propagators[i] *= xp.exp(1.0j * (-(ky**2)[None] * np.pi * wavelength * dz)) if theta_x is not None: propagators[i] *= xp.exp( diff --git a/py4DSTEM/process/phase/ptychographic_tomography.py b/py4DSTEM/process/phase/ptychographic_tomography.py index 8b4d4d984..f3b2991ab 100644 --- a/py4DSTEM/process/phase/ptychographic_tomography.py +++ b/py4DSTEM/process/phase/ptychographic_tomography.py @@ -1080,20 +1080,20 @@ def reconstruct( # position correction if not fix_positions: - self._positions_px_all[ - batch_indices - ] = self._position_correction( - object_sliced, - vectorized_patch_indices_row, - vectorized_patch_indices_col, - shifted_probes, - overlap, - amplitudes_device, - positions_px, - positions_px_initial, - positions_step_size, - max_position_update_distance, - max_position_total_distance, + self._positions_px_all[batch_indices] = ( + self._position_correction( + object_sliced, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + overlap, + amplitudes_device, + positions_px, + positions_px_initial, + positions_step_size, + max_position_update_distance, + max_position_total_distance, + ) ) measurement_error += batch_error diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index f53637184..a5a541795 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -203,9 +203,7 @@ def evaluate_gaussian_envelope( self, alpha: Union[float, np.ndarray] ) -> Union[float, np.ndarray]: xp = self._xp - return xp.exp( - -0.5 * self._gaussian_spread**2 * alpha**2 / self._wavelength**2 - ) + return xp.exp(-0.5 * self._gaussian_spread**2 * alpha**2 / self._wavelength**2) def evaluate_spatial_envelope( self, alpha: Union[float, np.ndarray], phi: Union[float, np.ndarray] diff --git a/py4DSTEM/process/utils/utils.py b/py4DSTEM/process/utils/utils.py index 60da616d1..ddeeb2c36 100644 --- a/py4DSTEM/process/utils/utils.py +++ b/py4DSTEM/process/utils/utils.py @@ -93,12 +93,7 @@ def electron_wavelength_angstrom(E_eV): c = 299792458 h = 6.62607 * 10**-34 - lam = ( - h - / ma.sqrt(2 * m * e * E_eV) - / ma.sqrt(1 + e * E_eV / 2 / m / c**2) - * 10**10 - ) + lam = h / ma.sqrt(2 * m * e * E_eV) / ma.sqrt(1 + e * E_eV / 2 / m / c**2) * 10**10 return lam @@ -107,15 +102,8 @@ def electron_interaction_parameter(E_eV): e = 1.602177 * 10**-19 c = 299792458 h = 6.62607 * 10**-34 - lam = ( - h - / ma.sqrt(2 * m * e * E_eV) - / ma.sqrt(1 + e * E_eV / 2 / m / c**2) - * 10**10 - ) - sigma = ( - (2 * np.pi / lam / E_eV) * (m * c**2 + e * E_eV) / (2 * m * c**2 + e * E_eV) - ) + lam = h / ma.sqrt(2 * m * e * E_eV) / ma.sqrt(1 + e * E_eV / 2 / m / c**2) * 10**10 + sigma = (2 * np.pi / lam / E_eV) * (m * c**2 + e * E_eV) / (2 * m * c**2 + e * E_eV) return sigma From 0d772ad91a5fbfbe5ecc4e512b1a3e829288ceaf Mon Sep 17 00:00:00 2001 From: Colin Date: Mon, 8 Apr 2024 21:21:28 -0700 Subject: [PATCH 09/74] updating phase mapping --- py4DSTEM/process/diffraction/crystal_phase.py | 1353 ++++++++++++++--- 1 file changed, 1122 insertions(+), 231 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_phase.py b/py4DSTEM/process/diffraction/crystal_phase.py index d28616aa9..508a01c79 100644 --- a/py4DSTEM/process/diffraction/crystal_phase.py +++ b/py4DSTEM/process/diffraction/crystal_phase.py @@ -3,12 +3,14 @@ from scipy.optimize import nnls import matplotlib as mpl import matplotlib.pyplot as plt +from scipy.ndimage import gaussian_filter from emdfile import tqdmnd, PointListArray from py4DSTEM.visualize import show, show_image_grid from py4DSTEM.process.diffraction.crystal_viz import plot_diffraction_pattern + class Crystal_Phase: """ A class storing multiple crystal structures, and associated diffraction data. @@ -19,8 +21,9 @@ class Crystal_Phase: def __init__( self, crystals, - orientation_maps, - name, + crystal_names = None, + orientation_maps = None, + name = None, ): """ Args: @@ -33,76 +36,964 @@ def __init__( self.num_crystals = len(crystals) else: raise TypeError("crystals must be a list of crystal instances.") - if isinstance(orientation_maps, list): + + # List of orientation maps + if orientation_maps is None: + self.orientation_maps = [crystals[ind].orientation_map for ind in range(self.num_crystals)] + else: if len(self.crystals) != len(orientation_maps): raise ValueError( "Orientation maps must have the same number of entries as crystals." ) self.orientation_maps = orientation_maps + + # Names of all crystal phases + if crystal_names is None: + self.crystal_names = ['crystal' + str(ind) for ind in range(self.num_crystals)] else: - raise TypeError("orientation_maps must be a list of orientation maps.") - self.name = name - return + self.crystal_names = crystal_names + + # Name of the phase map + if name is None: + self.name = 'phase map' + else: + self.name = name + + # Get some attributes from crystals + self.k_max = np.zeros(self.num_crystals) + self.num_matches = np.zeros(self.num_crystals, dtype='int') + self.crystal_identity = np.zeros((0,2), dtype='int') + for a0 in range(self.num_crystals): + self.k_max[a0] = self.crystals[a0].k_max + self.num_matches[a0] = self.crystals[a0].orientation_map.num_matches + for a1 in range(self.num_matches[a0]): + self.crystal_identity = np.append(self.crystal_identity,np.array((a0,a1),dtype='int')[None,:], axis=0) - def plot_all_phase_maps(self, map_scale_values=None, index=0): + self.num_fits = np.sum(self.num_matches) + + + + def quantify_single_pattern( + self, + pointlistarray: PointListArray, + xy_position = (0,0), + corr_kernel_size = 0.04, + allow_strain = True, + # corr_distance_scale = 1.0, + include_false_positives = True, + weight_false_positives = 1.0, + max_number_phases = 1, + sigma_excitation_error = 0.04, + power_experiment = 0.25, + power_calculated = 0.25, + plot_result = True, + plot_only_nonzero_phases = True, + plot_unmatched_peaks = False, + plot_correlation_radius = False, + scale_markers_experiment = 40, + scale_markers_calculated = 500, + crystal_inds_plot = None, + phase_colors = None, + figsize = (10,7), + verbose = True, + returnfig = False, + ): + """ + Quantify the phase for a single diffraction pattern. """ - Visualize phase maps of dataset. - Args: - map_scale_values (float): Value to scale correlations by + # tolerance + tol2 = 4e-4 + + # calibrations + center = pointlistarray.calstate['center'] + ellipse = pointlistarray.calstate['ellipse'] + pixel = pointlistarray.calstate['pixel'] + rotate = pointlistarray.calstate['rotate'] + if center is False: + raise ValueError('Bragg peaks must be center calibration') + if pixel is False: + raise ValueError('Bragg peaks must have pixel size calibration') + # TODO - potentially warn the user if ellipse / rotate calibration not available + + if phase_colors is None: + phase_colors = np.array(( + (1.0,0.0,0.0,1.0), + (0.0,0.8,1.0,1.0), + (0.0,0.6,0.0,1.0), + (1.0,0.0,1.0,1.0), + (0.0,0.2,1.0,1.0), + (1.0,0.8,0.0,1.0), + )) + + # Experimental values + bragg_peaks = pointlistarray.get_vectors( + xy_position[0], + xy_position[1], + center = center, + ellipse = ellipse, + pixel = pixel, + rotate = rotate) + # bragg_peaks = pointlistarray.get_pointlist(xy_position[0],xy_position[1]).copy() + keep = bragg_peaks.data["qx"]**2 + bragg_peaks.data["qy"]**2 > tol2 + # ind_center_beam = np.argmin( + # bragg_peaks.data["qx"]**2 + bragg_peaks.data["qy"]**2) + # mask = np.ones_like(bragg_peaks.data["qx"], dtype='bool') + # mask[ind_center_beam] = False + # bragg_peaks.remove(ind_center_beam) + qx = bragg_peaks.data["qx"][keep] + qy = bragg_peaks.data["qy"][keep] + qx0 = bragg_peaks.data["qx"][np.logical_not(keep)] + qy0 = bragg_peaks.data["qy"][np.logical_not(keep)] + if power_experiment == 0: + intensity = np.ones_like(qx) + intensity0 = np.ones_like(qx0) + else: + intensity = bragg_peaks.data["intensity"][keep]**power_experiment + intensity0 = bragg_peaks.data["intensity"][np.logical_not(keep)]**power_experiment + int_total = np.sum(intensity) + + # init basis array + if include_false_positives: + basis = np.zeros((intensity.shape[0], self.num_fits)) + unpaired_peaks = [] + else: + basis = np.zeros((intensity.shape[0], self.num_fits)) + if allow_strain: + m_strains = np.zeros((self.num_fits,2,2)) + + # kernel radius squared + radius_max_2 = corr_kernel_size**2 + + # init for plotting + if plot_result: + library_peaks = [] + library_int = [] + library_matches = [] + + # Generate point list data, match to experimental peaks + for a0 in range(self.num_fits): + c = self.crystal_identity[a0,0] + m = self.crystal_identity[a0,1] + # for c in range(self.num_crystals): + # for m in range(self.num_matches[c]): + # ind_match += 1 + + # Generate simulated peaks + bragg_peaks_fit = self.crystals[c].generate_diffraction_pattern( + self.crystals[c].orientation_map.get_orientation( + xy_position[0], xy_position[1] + ), + ind_orientation = m, + sigma_excitation_error = sigma_excitation_error, + ) + del_peak = bragg_peaks_fit.data["qx"]**2 \ + + bragg_peaks_fit.data["qy"]**2 < tol2 + bragg_peaks_fit.remove(del_peak) + + # peak intensities + if power_calculated == 0: + int_fit = np.ones_like(bragg_peaks_fit.data["qx"]) + else: + int_fit = bragg_peaks_fit.data['intensity']**power_calculated + + # Pair peaks to experiment + if plot_result: + matches = np.zeros((bragg_peaks_fit.data.shape[0]),dtype='bool') + + if allow_strain: + # Initial peak pairing to find best-fit strain distortion + pair_sub = np.zeros(bragg_peaks_fit.data.shape[0],dtype='bool') + pair_inds = np.zeros(bragg_peaks_fit.data.shape[0],dtype='int') + for a1 in range(bragg_peaks_fit.data.shape[0]): + dist2 = (bragg_peaks_fit.data['qx'][a1] - qx)**2 \ + + (bragg_peaks_fit.data['qy'][a1] - qy)**2 + ind_min = np.argmin(dist2) + val_min = dist2[ind_min] + + if val_min < radius_max_2: + pair_sub[a1] = True + pair_inds[a1] = ind_min + + # calculate best-fit strain tensor, weighted by the intensities. + # requires at least 4 peak pairs + if np.sum(pair_sub) >= 4: + pair_basis = np.vstack(( + qx[pair_inds[pair_sub]], + qy[pair_inds[pair_sub]], + )).T + pair_obs = np.vstack(( + bragg_peaks_fit.data['qx'][pair_sub], + bragg_peaks_fit.data['qy'][pair_sub], + )).T + + # weights + dists = np.sqrt( + (bragg_peaks_fit.data['qx'][pair_sub] - qx[pair_inds[pair_sub]])**2 + \ + (bragg_peaks_fit.data['qx'][pair_sub] - qx[pair_inds[pair_sub]])**2) + weights = np.sqrt( + int_fit[pair_sub] * intensity[pair_inds[pair_sub]] + ) * (1 - dists / corr_kernel_size) + + # wtrain tensor + m_strain = np.linalg.lstsq( + pair_basis * weights[:,None], + pair_obs * weights[:,None], + rcond = None, + )[0] + m_strains[a0] = m_strain + + # Transformed peak positions + qx_copy = qx.copy() + qy_copy = qy.copy() + qx = qx_copy*m_strain[0,0] + qy_copy*m_strain[1,0] + qy = qx_copy*m_strain[0,1] + qy_copy*m_strain[1,1] + + # dist = np.mean(np.sqrt((bragg_peaks_fit.data['qx'][pair_sub] - qx[pair_inds[pair_sub]])**2 + \ + # (bragg_peaks_fit.data['qx'][pair_sub] - qx[pair_inds[pair_sub]])**2)) + # print(np.round(dist,4)) + + # Loop over all peaks, pair experiment to library + for a1 in range(bragg_peaks_fit.data.shape[0]): + dist2 = (bragg_peaks_fit.data['qx'][a1] - qx)**2 \ + + (bragg_peaks_fit.data['qy'][a1] - qy)**2 + ind_min = np.argmin(dist2) + val_min = dist2[ind_min] + + if val_min < radius_max_2: + # weight = 1 - np.sqrt(dist2[ind_min]) / corr_kernel_size + # weight = 1 + corr_distance_scale * \ + # np.sqrt(dist2[ind_min]) / corr_kernel_size + # basis[ind_min,a0] = weight * int_fit[a1] + basis[ind_min,a0] = int_fit[a1] + if plot_result: + matches[a1] = True + elif include_false_positives: + # unpaired_peaks.append([a0,int_fit[a1]*(1 + corr_distance_scale)]) + unpaired_peaks.append([a0,int_fit[a1]]) + + if plot_result: + library_peaks.append(bragg_peaks_fit) + library_int.append(int_fit) + library_matches.append(matches) + + # If needed, augment basis and observations with false positives + if include_false_positives: + basis_aug = np.zeros((len(unpaired_peaks),self.num_fits)) + for a0 in range(len(unpaired_peaks)): + basis_aug[a0,unpaired_peaks[a0][0]] = unpaired_peaks[a0][1] + + basis = np.vstack((basis, basis_aug * weight_false_positives)) + obs = np.hstack((intensity, np.zeros(len(unpaired_peaks)))) + + else: + obs = intensity + + # Solve for phase coefficients + try: + phase_weights = np.zeros(self.num_fits) + inds_solve = np.ones(self.num_fits,dtype='bool') + + search = True + while search is True: + phase_weights_cand, phase_residual_cand = nnls( + basis[:,inds_solve], + obs, + ) + + if np.count_nonzero(phase_weights_cand > 0.0) <= max_number_phases: + phase_weights[inds_solve] = phase_weights_cand + phase_residual = phase_residual_cand + search = False + else: + inds = np.where(inds_solve)[0] + inds_solve[inds[np.argmin(phase_weights_cand)]] = False + + except: + phase_weights = np.zeros(self.num_fits) + phase_residual = np.sqrt(np.sum(intensity**2)) + + if verbose: + ind_max = np.argmax(phase_weights) + # print() + print('\033[1m' + 'phase_weight or_ind name' + '\033[0m') + # print() + for a0 in range(self.num_fits): + c = self.crystal_identity[a0,0] + m = self.crystal_identity[a0,1] + line = '{:>12} {:>8} {:<12}'.format( + np.round(phase_weights[a0],decimals=2), + m, + self.crystal_names[c] + ) + if a0 == ind_max: + print('\033[1m' + line + '\033[0m') + else: + print(line) + # print() + + # Plotting + if plot_result: + # fig, ax = plt.subplots(figsize=figsize) + fig = plt.figure(figsize=figsize) + # if plot_layout == 0: + # ax_x = fig.add_axes( + # [0.0+figbound[0], 0.0, 0.4-2*+figbound[0], 1.0]) + ax = fig.add_axes([0.0, 0.0, 0.66, 1.0]) + ax_leg = fig.add_axes([0.68, 0.0, 0.3, 1.0]) + + if plot_correlation_radius: + # plot the experimental radii + t = np.linspace(0,2*np.pi,91,endpoint=True) + ct = np.cos(t) * corr_kernel_size + st = np.sin(t) * corr_kernel_size + for a0 in range(qx.shape[0]): + ax.plot( + qy[a0] + st, + qx[a0] + ct, + color = 'k', + linewidth = 1, + ) + + # plot the experimental peaks + ax.scatter( + qy0, + qx0, + s = scale_markers_experiment * intensity0, + marker = "o", + facecolor = [0.7, 0.7, 0.7], + ) + ax.scatter( + qy, + qx, + s = scale_markers_experiment * intensity, + marker = "o", + facecolor = [0.7, 0.7, 0.7], + ) + # legend + k_max = np.max(self.k_max) + dx_leg = -0.05*k_max + dy_leg = 0.04*k_max + text_params = { + "va": "center", + "ha": "left", + "family": "sans-serif", + "fontweight": "normal", + "color": "k", + "size": 14, + } + if plot_correlation_radius: + ax_leg.plot( + 0 + st*0.5, + -dx_leg + ct*0.5, + color = 'k', + linewidth = 1, + ) + ax_leg.scatter( + 0, + 0, + s = 200, + marker = "o", + facecolor = [0.7, 0.7, 0.7], + ) + ax_leg.text( + dy_leg, + 0, + 'Experimental peaks', + **text_params) + if plot_correlation_radius: + ax_leg.text( + dy_leg, + -dx_leg, + 'Correlation radius', + **text_params) + + + # plot calculated diffraction patterns + uvals = phase_colors.copy() + uvals[:,3] = 0.3 + # uvals = np.array(( + # (1.0,0.0,0.0,0.2), + # (0.0,0.8,1.0,0.2), + # (0.0,0.6,0.0,0.2), + # (1.0,0.0,1.0,0.2), + # (0.0,0.2,1.0,0.2), + # (1.0,0.8,0.0,0.2), + # )) + mvals = ['v','^','<','>','d','s',] + + for a0 in range(self.num_fits): + c = self.crystal_identity[a0,0] + m = self.crystal_identity[a0,1] + + if crystal_inds_plot == None or np.min(np.abs(c - crystal_inds_plot)) == 0: + + qx_fit = library_peaks[a0].data['qx'] + qy_fit = library_peaks[a0].data['qy'] + + # if allow_strain: + # m_strain = m_strains[a0] + # # Transformed peak positions + # qx_copy = qx_fit.copy() + # qy_copy = qy_fit.copy() + # qx_fit = qx_copy*m_strain[0,0] + qy_copy*m_strain[1,0] + # qy_fit = qx_copy*m_strain[0,1] + qy_copy*m_strain[1,1] + + int_fit = library_int[a0] + matches_fit = library_matches[a0] + + if plot_only_nonzero_phases is False or phase_weights[a0] > 0: + + if np.mod(m,2) == 0: + ax.scatter( + qy_fit[matches_fit], + qx_fit[matches_fit], + s = scale_markers_calculated * int_fit[matches_fit], + marker = mvals[c], + facecolor = phase_colors[c,:], + ) + if plot_unmatched_peaks: + ax.scatter( + qy_fit[np.logical_not(matches_fit)], + qx_fit[np.logical_not(matches_fit)], + s = scale_markers_calculated * int_fit[np.logical_not(matches_fit)], + marker = mvals[c], + facecolor = phase_colors[c,:], + ) + + # legend + ax_leg.scatter( + 0, + dx_leg*(a0+1), + s = 200, + marker = mvals[c], + facecolor = phase_colors[c,:], + ) + else: + ax.scatter( + qy_fit[matches_fit], + qx_fit[matches_fit], + s = scale_markers_calculated * int_fit[matches_fit], + marker = mvals[c], + edgecolors = uvals[c,:], + facecolors = (uvals[c,0],uvals[c,1],uvals[c,2],0.3), + # facecolors = (1,1,1,0.5), + linewidth = 2, + ) + if plot_unmatched_peaks: + ax.scatter( + qy_fit[np.logical_not(matches_fit)], + qx_fit[np.logical_not(matches_fit)], + s = scale_markers_calculated * int_fit[np.logical_not(matches_fit)], + marker = mvals[c], + edgecolors = uvals[c,:], + facecolors = (1,1,1,0.5), + linewidth = 2, + ) + + # legend + ax_leg.scatter( + 0, + dx_leg*(a0+1), + s = 200, + marker = mvals[c], + edgecolors = uvals[c,:], + facecolors = (uvals[c,0],uvals[c,1],uvals[c,2],0.3), + # facecolors = (1,1,1,0.5), + ) + + # legend text + ax_leg.text( + dy_leg, + (a0+1)*dx_leg, + self.crystal_names[c], + **text_params) + + + # appearance + ax.set_xlim((-k_max, k_max)) + ax.set_ylim((-k_max, k_max)) + + ax_leg.set_xlim((-0.1*k_max, 0.4*k_max)) + ax_leg.set_ylim((-0.5*k_max, 0.5*k_max)) + ax_leg.set_axis_off() + + if returnfig: + return phase_weights, phase_residual, int_total, fig, ax + else: + return phase_weights, phase_residual, int_total + + def quantify_phase( + self, + pointlistarray: PointListArray, + corr_kernel_size = 0.04, + # corr_distance_scale = 1.0, + allow_strain = True, + include_false_positives = True, + weight_false_positives = 1.0, + max_number_phases = 2, + sigma_excitation_error = 0.02, + power_experiment = 0.25, + power_calculated = 0.25, + progress_bar = True, + ): """ - phase_maps = [] - if map_scale_values is None: - map_scale_values = [1] * len(self.orientation_maps) - corr_sum = np.sum( - [ - (self.orientation_maps[m].corr[:, :, index] * map_scale_values[m]) - for m in range(len(self.orientation_maps)) - ] - ) - for m in range(len(self.orientation_maps)): - phase_maps.append(self.orientation_maps[m].corr[:, :, index] / corr_sum) - show_image_grid(lambda i: phase_maps[i], 1, len(phase_maps), cmap="inferno") - return - - def plot_phase_map(self, index=0, cmap=None): - corr_array = np.dstack( - [maps.corr[:, :, index] for maps in self.orientation_maps] + Quantify phase of all diffraction patterns. + """ + + # init results arrays + self.phase_weights = np.zeros(( + pointlistarray.shape[0], + pointlistarray.shape[1], + self.num_fits, + )) + self.phase_residuals = np.zeros(( + pointlistarray.shape[0], + pointlistarray.shape[1], + )) + self.int_total = np.zeros(( + pointlistarray.shape[0], + pointlistarray.shape[1], + )) + + for rx, ry in tqdmnd( + *pointlistarray.shape, + desc="Matching Orientations", + unit=" PointList", + disable=not progress_bar, + ): + # calculate phase weights + phase_weights, phase_residual, int_peaks = self.quantify_single_pattern( + pointlistarray = pointlistarray, + xy_position = (rx,ry), + corr_kernel_size = corr_kernel_size, + allow_strain = allow_strain, + # corr_distance_scale = corr_distance_scale, + include_false_positives = include_false_positives, + weight_false_positives = weight_false_positives, + max_number_phases = max_number_phases, + sigma_excitation_error = sigma_excitation_error, + power_experiment = power_experiment, + power_calculated = power_calculated, + plot_result = False, + verbose = False, + returnfig = False, + ) + self.phase_weights[rx,ry] = phase_weights + self.phase_residuals[rx,ry] = phase_residual + self.int_total[rx,ry] = int_peaks + + + def plot_phase_weights( + self, + weight_range = (0.0,1.0), + weight_normalize = False, + total_intensity_normalize = True, + cmap = 'gray', + show_ticks = False, + show_axes = True, + layout = 0, + figsize = (6,6), + returnfig = False, + ): + """ + Plot the individual phase weight maps and residuals. + """ + + # Normalization if required to total DF peak intensity + phase_weights = self.phase_weights.copy() + phase_residuals = self.phase_residuals.copy() + if total_intensity_normalize: + sub = self.int_total > 0.0 + for a0 in range(self.num_fits): + phase_weights[:,:,a0][sub] /= self.int_total[sub] + phase_residuals[sub] /= self.int_total[sub] + + # intensity range for plotting + if weight_normalize: + scale = np.median(np.max(phase_weights,axis=2)) + else: + scale = 1 + weight_range = np.array(weight_range) * scale + + # plotting + if layout == 0: + fig,ax = plt.subplots( + 1, + self.num_crystals + 1, + figsize=(figsize[0],(self.num_fits+1)*figsize[1])) + elif layout == 1: + fig,ax = plt.subplots( + self.num_crystals + 1, + 1, + figsize=(figsize[0],(self.num_fits+1)*figsize[1])) + + for a0 in range(self.num_crystals): + sub = self.crystal_identity[:,0] == a0 + im = np.sum(phase_weights[:,:,sub],axis=2) + im = np.clip( + (im - weight_range[0]) / (weight_range[1] - weight_range[0]), + 0,1) + ax[a0].imshow( + im, + vmin = 0, + vmax = 1, + cmap = cmap, + ) + ax[a0].set_title( + self.crystal_names[a0], + fontsize = 16, + ) + if not show_ticks: + ax[a0].set_xticks([]) + ax[a0].set_yticks([]) + if not show_axes: + ax[a0].set_axis_off() + + # plot residuals + im = np.clip( + (phase_residuals - weight_range[0]) \ + / (weight_range[1] - weight_range[0]), + 0,1) + ax[self.num_crystals].imshow( + im, + vmin = 0, + vmax = 1, + cmap = cmap, ) - best_corr_score = np.max(corr_array, axis=2) - best_match_phase = [ - np.where(corr_array[:, :, p] == best_corr_score, True, False) - for p in range(len(self.orientation_maps)) - ] - - if cmap is None: - cm = plt.get_cmap("rainbow") - cmap = [ - cm(1.0 * i / len(self.orientation_maps)) - for i in range(len(self.orientation_maps)) - ] - - fig, (ax) = plt.subplots(figsize=(6, 6)) - ax.matshow( - np.zeros((self.orientation_maps[0].num_x, self.orientation_maps[0].num_y)), - cmap="gray", + ax[self.num_crystals].set_title( + 'Residuals', + fontsize = 16, ) - ax.axis("off") - - for m in range(len(self.orientation_maps)): - c0, c1 = (cmap[m][0] * 0.35, cmap[m][1] * 0.35, cmap[m][2] * 0.35, 1), cmap[ - m - ] - cm = mpl.colors.LinearSegmentedColormap.from_list("cmap", [c0, c1], N=10) - ax.matshow( - np.ma.array( - self.orientation_maps[m].corr[:, :, index], mask=best_match_phase[m] - ), - cmap=cm, + if not show_ticks: + ax[self.num_crystals].set_xticks([]) + ax[self.num_crystals].set_yticks([]) + if not show_axes: + ax[self.num_crystals].set_axis_off() + + if returnfig: + return fig, ax + + + def plot_phase_maps( + self, + weight_threshold = 0.5, + weight_normalize = True, + total_intensity_normalize = True, + plot_combine = False, + crystal_inds_plot = None, + phase_colors = None, + show_ticks = False, + show_axes = True, + layout = 0, + figsize = (6,6), + return_phase_estimate = False, + return_rgb_images = False, + returnfig = False, + ): + """ + Plot the individual phase weight maps and residuals. + """ + + if phase_colors is None: + phase_colors = np.array(( + (1.0,0.0,0.0), + (0.0,0.8,1.0), + (0.0,0.8,0.0), + (1.0,0.0,1.0), + (0.0,0.4,1.0), + (1.0,0.8,0.0), + )) + + phase_weights = self.phase_weights.copy() + if total_intensity_normalize: + sub = self.int_total > 0.0 + for a0 in range(self.num_fits): + phase_weights[:,:,a0][sub] /= self.int_total[sub] + + # intensity range for plotting + if weight_normalize: + scale = np.median(np.max(phase_weights,axis=2)) + else: + scale = 1 + weight_threshold = weight_threshold * scale + + # init + im_all = np.zeros(( + self.num_crystals, + self.phase_weights.shape[0], + self.phase_weights.shape[1])) + im_rgb_all = np.zeros(( + self.num_crystals, + self.phase_weights.shape[0], + self.phase_weights.shape[1], + 3)) + + # phase weights over threshold + for a0 in range(self.num_crystals): + sub = self.crystal_identity[:,0] == a0 + im = np.sum(phase_weights[:,:,sub],axis=2) + im_all[a0] = np.maximum(im - weight_threshold, 0) + + # estimate compositions + im_sum = np.sum(im_all, axis = 0) + sub = im_sum > 0.0 + for a0 in range(self.num_crystals): + im_all[a0][sub] /= im_sum[sub] + + for a1 in range(3): + im_rgb_all[a0,:,:,a1] = im_all[a0] * phase_colors[a0,a1] + + if plot_combine: + if crystal_inds_plot is None: + im_rgb = np.sum(im_rgb_all, axis = 0) + else: + im_rgb = np.sum(im_rgb_all[np.array(crystal_inds_plot)], axis = 0) + + im_rgb = np.clip(im_rgb,0,1) + + fig,ax = plt.subplots(1,1,figsize=figsize) + ax.imshow( + im_rgb, + ) + ax.set_title( + 'Phase Maps', + fontsize = 16, ) - plt.show() + if not show_ticks: + ax.set_xticks([]) + ax.set_yticks([]) + if not show_axes: + ax.set_axis_off() + + else: + # plotting + if layout == 0: + fig,ax = plt.subplots( + 1, + self.num_crystals, + figsize=(figsize[0],(self.num_fits+1)*figsize[1])) + elif layout == 1: + fig,ax = plt.subplots( + self.num_crystals, + 1, + figsize=(figsize[0],(self.num_fits+1)*figsize[1])) + + for a0 in range(self.num_crystals): + + ax[a0].imshow( + im_rgb_all[a0], + ) + ax[a0].set_title( + self.crystal_names[a0], + fontsize = 16, + ) + if not show_ticks: + ax[a0].set_xticks([]) + ax[a0].set_yticks([]) + if not show_axes: + ax[a0].set_axis_off() + + # All possible returns + if return_phase_estimate: + if returnfig: + return im_all, fig, ax + else: + return im_all + elif return_rgb_images: + if plot_combine: + if returnfig: + return im_rgb, fig, ax + else: + return im_rgb + else: + if returnfig: + return im_rgb_all, fig, ax + else: + return im_rgb_all + else: + if returnfig: + return fig, ax + + + def plot_dominant_phase( + self, + rel_range = (0.0,1.0), + sigma = 0.0, + phase_colors = None, + figsize = (6,6), + ): + """ + Plot a combined figure showing the primary phase at each probe position. + Mask by the reliability index (best match minus 2nd best match). + """ + + if phase_colors is None: + phase_colors = np.array([ + [1.0,0.9,0.6], + [1,0,0], + [0,0.7,0], + [0,0.7,1], + [1,0,1], + ]) + + + # init arrays + scan_shape = self.phase_weights.shape[:2] + phase_map = np.zeros(scan_shape) + phase_corr = np.zeros(scan_shape) + phase_corr_2nd = np.zeros(scan_shape) + phase_sig = np.zeros((self.num_crystals,scan_shape[0],scan_shape[1])) + + # sum up phase weights by crystal type + for a0 in range(self.num_fits): + ind = self.crystal_identity[a0,0] + phase_sig[ind] += self.phase_weights[:,:,a0] + + # smoothing of the outputs + if sigma > 0.0: + for a0 in range(self.num_crystals): + phase_sig[a0] = gaussian_filter( + phase_sig[a0], + sigma = sigma, + mode = 'nearest', + ) + + + # find highest correlation score for each crystal and match index + for a0 in range(self.num_crystals): + sub = phase_sig[a0] > phase_corr + phase_map[sub] = a0 + phase_corr[sub] = phase_sig[a0][sub] + + # find the second correlation score for each crystal and match index + for a0 in range(self.num_crystals): + corr = phase_sig[a0].copy() + corr[phase_map==a0] = 0.0 + sub = corr > phase_corr_2nd + phase_corr_2nd[sub] = corr[sub] + + + + + + # for a1 in range(crystals[a0].orientation_map.corr.shape[2]): + # sub = crystals[a0].orientation_map.corr[:,:,a1] > phase_corr + # phase_map[sub] = a0 + # phase_corr[sub] = crystals[a0].orientation_map.corr[:,:,a1][sub] + + # # find the second correlation score for each crystal and match index + # for a0 in range(len(crystals)): + # for a1 in range(crystals[a0].orientation_map.corr.shape[2]): + # corr = crystals[a0].orientation_map.corr[:,:,a1].copy() + # corr[phase_map==a0] = 0.0 + # sub = corr > phase_corr_2nd + # phase_corr_2nd[sub] = corr[sub] + + # Estimate the reliability + phase_rel = phase_corr - phase_corr_2nd + + # # Generate a color plotting image + # c_vals = np.array([ + # [0.7,0.5,0.3], + # [1,0,0], + # [0,0.7,0], + # [0,0.7,1], + # [0,0.7,1], + # ]) + phase_scale = np.clip( + (phase_rel - rel_range[0]) / (rel_range[1] - rel_range[0]), + 0, + 1) + + + phase_rgb = np.zeros((scan_shape[0],scan_shape[1],3)) + for a0 in range(self.num_crystals): + sub = phase_map==a0 + for a1 in range(3): + phase_rgb[:,:,a1][sub] = phase_colors[a0,a1] * phase_scale[sub] + # normalize + # phase_rgb = np.clip( + # (phase_rgb - rel_range[0]) / (rel_range[1] - rel_range[0]), + # 0,1) + + + + fig,ax = plt.subplots(figsize=figsize) + ax.imshow( + phase_rgb, + # vmin = 0, + # vmax = 5, + # phase_rgb, + # phase_corr - phase_corr_2nd, + # cmap = 'turbo', + # vmin = 0, + # vmax = 3, + # cmap = 'gray', + ) + + return fig,ax + + + # def plot_all_phase_maps(self, map_scale_values=None, index=0): + # """ + # Visualize phase maps of dataset. + + # Args: + # map_scale_values (float): Value to scale correlations by + # """ + # phase_maps = [] + # if map_scale_values is None: + # map_scale_values = [1] * len(self.orientation_maps) + # corr_sum = np.sum( + # [ + # (self.orientation_maps[m].corr[:, :, index] * map_scale_values[m]) + # for m in range(len(self.orientation_maps)) + # ] + # ) + # for m in range(len(self.orientation_maps)): + # phase_maps.append(self.orientation_maps[m].corr[:, :, index] / corr_sum) + # show_image_grid(lambda i: phase_maps[i], 1, len(phase_maps), cmap="inferno") + # return + + # def plot_phase_map(self, index=0, cmap=None): + # corr_array = np.dstack( + # [maps.corr[:, :, index] for maps in self.orientation_maps] + # ) + # best_corr_score = np.max(corr_array, axis=2) + # best_match_phase = [ + # np.where(corr_array[:, :, p] == best_corr_score, True, False) + # for p in range(len(self.orientation_maps)) + # ] - return + # if cmap is None: + # cm = plt.get_cmap("rainbow") + # cmap = [ + # cm(1.0 * i / len(self.orientation_maps)) + # for i in range(len(self.orientation_maps)) + # ] + + # fig, (ax) = plt.subplots(figsize=(6, 6)) + # ax.matshow( + # np.zeros((self.orientation_maps[0].num_x, self.orientation_maps[0].num_y)), + # cmap="gray", + # ) + # ax.axis("off") + + # for m in range(len(self.orientation_maps)): + # c0, c1 = (cmap[m][0] * 0.35, cmap[m][1] * 0.35, cmap[m][2] * 0.35, 1), cmap[ + # m + # ] + # cm = mpl.colors.LinearSegmentedColormap.from_list("cmap", [c0, c1], N=10) + # ax.matshow( + # np.ma.array( + # self.orientation_maps[m].corr[:, :, index], mask=best_match_phase[m] + # ), + # cmap=cm, + # ) + # plt.show() + + # return # Potentially introduce a way to check best match out of all orientations in phase plan and plug into model # to quantify phase @@ -124,187 +1015,187 @@ def plot_phase_map(self, index=0, cmap=None): # ): # return - def quantify_phase( - self, - pointlistarray, - tolerance_distance=0.08, - method="nnls", - intensity_power=0, - mask_peaks=None, - ): - """ - Quantification of the phase of a crystal based on the crystal instances and the pointlistarray. + # def quantify_phase( + # self, + # pointlistarray, + # tolerance_distance=0.08, + # method="nnls", + # intensity_power=0, + # mask_peaks=None, + # ): + # """ + # Quantification of the phase of a crystal based on the crystal instances and the pointlistarray. - Args: - pointlisarray (pointlistarray): Pointlistarray to quantify phase of - tolerance_distance (float): Distance allowed between a peak and match - method (str): Numerical method used to quantify phase - intensity_power (float): ... - mask_peaks (list, optional): A pointer of which positions to mask peaks from + # Args: + # pointlisarray (pointlistarray): Pointlistarray to quantify phase of + # tolerance_distance (float): Distance allowed between a peak and match + # method (str): Numerical method used to quantify phase + # intensity_power (float): ... + # mask_peaks (list, optional): A pointer of which positions to mask peaks from - Details: - """ - if isinstance(pointlistarray, PointListArray): - phase_weights = np.zeros( - ( - pointlistarray.shape[0], - pointlistarray.shape[1], - np.sum([map.num_matches for map in self.orientation_maps]), - ) - ) - phase_residuals = np.zeros(pointlistarray.shape) - for Rx, Ry in tqdmnd(pointlistarray.shape[0], pointlistarray.shape[1]): - ( - _, - phase_weight, - phase_residual, - crystal_identity, - ) = self.quantify_phase_pointlist( - pointlistarray, - position=[Rx, Ry], - tolerance_distance=tolerance_distance, - method=method, - intensity_power=intensity_power, - mask_peaks=mask_peaks, - ) - phase_weights[Rx, Ry, :] = phase_weight - phase_residuals[Rx, Ry] = phase_residual - self.phase_weights = phase_weights - self.phase_residuals = phase_residuals - self.crystal_identity = crystal_identity - return - else: - return TypeError("pointlistarray must be of type pointlistarray.") - return + # Details: + # """ + # if isinstance(pointlistarray, PointListArray): + # phase_weights = np.zeros( + # ( + # pointlistarray.shape[0], + # pointlistarray.shape[1], + # np.sum([map.num_matches for map in self.orientation_maps]), + # ) + # ) + # phase_residuals = np.zeros(pointlistarray.shape) + # for Rx, Ry in tqdmnd(pointlistarray.shape[0], pointlistarray.shape[1]): + # ( + # _, + # phase_weight, + # phase_residual, + # crystal_identity, + # ) = self.quantify_phase_pointlist( + # pointlistarray, + # position=[Rx, Ry], + # tolerance_distance=tolerance_distance, + # method=method, + # intensity_power=intensity_power, + # mask_peaks=mask_peaks, + # ) + # phase_weights[Rx, Ry, :] = phase_weight + # phase_residuals[Rx, Ry] = phase_residual + # self.phase_weights = phase_weights + # self.phase_residuals = phase_residuals + # self.crystal_identity = crystal_identity + # return + # else: + # return TypeError("pointlistarray must be of type pointlistarray.") + # return - def quantify_phase_pointlist( - self, - pointlistarray, - position, - method="nnls", - tolerance_distance=0.08, - intensity_power=0, - mask_peaks=None, - ): - """ - Args: - pointlisarray (pointlistarray): Pointlistarray to quantify phase of - position (tuple/list): Position of pointlist in pointlistarray - tolerance_distance (float): Distance allowed between a peak and match - method (str): Numerical method used to quantify phase - intensity_power (float): ... - mask_peaks (list, optional): A pointer of which positions to mask peaks from - - Returns: - pointlist_peak_intensity_matches (np.ndarray): Peak matches in the rows of array and the crystals in the columns - phase_weights (np.ndarray): Weights of each phase - phase_residuals (np.ndarray): Residuals - crystal_identity (list): List of lists, where the each entry represents the position in the - crystal and orientation match that is associated with the phase - weights. for example, if the output was [[0,0], [0,1], [1,0], [0,1]], - the first entry [0,0] in phase weights is associated with the first crystal - the first match within that crystal. [0,1] is the first crystal and the - second match within that crystal. - """ - # Things to add: - # 1. Better cost for distance from peaks in pointlists - # 2. Iterate through multiple tolerance_distance values to find best value. Cost function residuals, or something else? + # def quantify_phase_pointlist( + # self, + # pointlistarray, + # position, + # method="nnls", + # tolerance_distance=0.08, + # intensity_power=0, + # mask_peaks=None, + # ): + # """ + # Args: + # pointlisarray (pointlistarray): Pointlistarray to quantify phase of + # position (tuple/list): Position of pointlist in pointlistarray + # tolerance_distance (float): Distance allowed between a peak and match + # method (str): Numerical method used to quantify phase + # intensity_power (float): ... + # mask_peaks (list, optional): A pointer of which positions to mask peaks from - pointlist = pointlistarray.get_pointlist(position[0], position[1]) - pl_mask = np.where((pointlist["qx"] == 0) & (pointlist["qy"] == 0), 1, 0) - pointlist.remove(pl_mask) - # False Negatives (exp peak with no match in crystal instances) will appear here, already coded in + # Returns: + # pointlist_peak_intensity_matches (np.ndarray): Peak matches in the rows of array and the crystals in the columns + # phase_weights (np.ndarray): Weights of each phase + # phase_residuals (np.ndarray): Residuals + # crystal_identity (list): List of lists, where the each entry represents the position in the + # crystal and orientation match that is associated with the phase + # weights. for example, if the output was [[0,0], [0,1], [1,0], [0,1]], + # the first entry [0,0] in phase weights is associated with the first crystal + # the first match within that crystal. [0,1] is the first crystal and the + # second match within that crystal. + # """ + # # Things to add: + # # 1. Better cost for distance from peaks in pointlists + # # 2. Iterate through multiple tolerance_distance values to find best value. Cost function residuals, or something else? - if intensity_power == 0: - pl_intensities = np.ones(pointlist["intensity"].shape) - else: - pl_intensities = pointlist["intensity"] ** intensity_power - # Prepare matches for modeling - pointlist_peak_matches = [] - crystal_identity = [] - - for c in range(len(self.crystals)): - for m in range(self.orientation_maps[c].num_matches): - crystal_identity.append([c, m]) - phase_peak_match_intensities = np.zeros((pointlist["intensity"].shape)) - bragg_peaks_fit = self.crystals[c].generate_diffraction_pattern( - self.orientation_maps[c].get_orientation(position[0], position[1]), - ind_orientation=m, - ) - # Find the best match peak within tolerance_distance and add value in the right position - for d in range(pointlist["qx"].shape[0]): - distances = [] - for p in range(bragg_peaks_fit["qx"].shape[0]): - distances.append( - np.sqrt( - (pointlist["qx"][d] - bragg_peaks_fit["qx"][p]) ** 2 - + (pointlist["qy"][d] - bragg_peaks_fit["qy"][p]) ** 2 - ) - ) - ind = np.where(distances == np.min(distances))[0][0] - - # Potentially for-loop over multiple values for 'tolerance_distance' to find best tolerance_distance value - if distances[ind] <= tolerance_distance: - ## Somewhere in this if statement is probably where better distances from the peak should be coded in - if ( - intensity_power == 0 - ): # This could potentially be a different intensity_power arg - phase_peak_match_intensities[d] = 1 ** ( - (tolerance_distance - distances[ind]) - / tolerance_distance - ) - else: - phase_peak_match_intensities[d] = bragg_peaks_fit[ - "intensity" - ][ind] ** ( - (tolerance_distance - distances[ind]) - / tolerance_distance - ) - else: - ## This is probably where the false positives (peaks in crystal but not in experiment) should be handled - continue - - pointlist_peak_matches.append(phase_peak_match_intensities) - pointlist_peak_intensity_matches = np.dstack(pointlist_peak_matches) - pointlist_peak_intensity_matches = ( - pointlist_peak_intensity_matches.reshape( - pl_intensities.shape[0], - pointlist_peak_intensity_matches.shape[-1], - ) - ) + # pointlist = pointlistarray.get_pointlist(position[0], position[1]) + # pl_mask = np.where((pointlist["qx"] == 0) & (pointlist["qy"] == 0), 1, 0) + # pointlist.remove(pl_mask) + # # False Negatives (exp peak with no match in crystal instances) will appear here, already coded in - if len(pointlist["qx"]) > 0: - if mask_peaks is not None: - for i in range(len(mask_peaks)): - if mask_peaks[i] == None: # noqa: E711 - continue - inds_mask = np.where( - pointlist_peak_intensity_matches[:, mask_peaks[i]] != 0 - )[0] - for mask in range(len(inds_mask)): - pointlist_peak_intensity_matches[inds_mask[mask], i] = 0 + # if intensity_power == 0: + # pl_intensities = np.ones(pointlist["intensity"].shape) + # else: + # pl_intensities = pointlist["intensity"] ** intensity_power + # # Prepare matches for modeling + # pointlist_peak_matches = [] + # crystal_identity = [] - if method == "nnls": - phase_weights, phase_residuals = nnls( - pointlist_peak_intensity_matches, pl_intensities - ) + # for c in range(len(self.crystals)): + # for m in range(self.orientation_maps[c].num_matches): + # crystal_identity.append([c, m]) + # phase_peak_match_intensities = np.zeros((pointlist["intensity"].shape)) + # bragg_peaks_fit = self.crystals[c].generate_diffraction_pattern( + # self.orientation_maps[c].get_orientation(position[0], position[1]), + # ind_orientation=m, + # ) + # # Find the best match peak within tolerance_distance and add value in the right position + # for d in range(pointlist["qx"].shape[0]): + # distances = [] + # for p in range(bragg_peaks_fit["qx"].shape[0]): + # distances.append( + # np.sqrt( + # (pointlist["qx"][d] - bragg_peaks_fit["qx"][p]) ** 2 + # + (pointlist["qy"][d] - bragg_peaks_fit["qy"][p]) ** 2 + # ) + # ) + # ind = np.where(distances == np.min(distances))[0][0] - elif method == "lstsq": - phase_weights, phase_residuals, rank, singluar_vals = lstsq( - pointlist_peak_intensity_matches, pl_intensities, rcond=-1 - ) - phase_residuals = np.sum(phase_residuals) - else: - raise ValueError(method + " Not yet implemented. Try nnls or lstsq.") - else: - phase_weights = np.zeros((pointlist_peak_intensity_matches.shape[1],)) - phase_residuals = np.NaN - return ( - pointlist_peak_intensity_matches, - phase_weights, - phase_residuals, - crystal_identity, - ) + # # Potentially for-loop over multiple values for 'tolerance_distance' to find best tolerance_distance value + # if distances[ind] <= tolerance_distance: + # ## Somewhere in this if statement is probably where better distances from the peak should be coded in + # if ( + # intensity_power == 0 + # ): # This could potentially be a different intensity_power arg + # phase_peak_match_intensities[d] = 1 ** ( + # (tolerance_distance - distances[ind]) + # / tolerance_distance + # ) + # else: + # phase_peak_match_intensities[d] = bragg_peaks_fit[ + # "intensity" + # ][ind] ** ( + # (tolerance_distance - distances[ind]) + # / tolerance_distance + # ) + # else: + # ## This is probably where the false positives (peaks in crystal but not in experiment) should be handled + # continue + + # pointlist_peak_matches.append(phase_peak_match_intensities) + # pointlist_peak_intensity_matches = np.dstack(pointlist_peak_matches) + # pointlist_peak_intensity_matches = ( + # pointlist_peak_intensity_matches.reshape( + # pl_intensities.shape[0], + # pointlist_peak_intensity_matches.shape[-1], + # ) + # ) + + # if len(pointlist["qx"]) > 0: + # if mask_peaks is not None: + # for i in range(len(mask_peaks)): + # if mask_peaks[i] == None: # noqa: E711 + # continue + # inds_mask = np.where( + # pointlist_peak_intensity_matches[:, mask_peaks[i]] != 0 + # )[0] + # for mask in range(len(inds_mask)): + # pointlist_peak_intensity_matches[inds_mask[mask], i] = 0 + + # if method == "nnls": + # phase_weights, phase_residuals = nnls( + # pointlist_peak_intensity_matches, pl_intensities + # ) + + # elif method == "lstsq": + # phase_weights, phase_residuals, rank, singluar_vals = lstsq( + # pointlist_peak_intensity_matches, pl_intensities, rcond=-1 + # ) + # phase_residuals = np.sum(phase_residuals) + # else: + # raise ValueError(method + " Not yet implemented. Try nnls or lstsq.") + # else: + # phase_weights = np.zeros((pointlist_peak_intensity_matches.shape[1],)) + # phase_residuals = np.NaN + # return ( + # pointlist_peak_intensity_matches, + # phase_weights, + # phase_residuals, + # crystal_identity, + # ) # def plot_peak_matches( # self, From 304e35b8688c45c16bc4edf02aff81f59b3ab3d9 Mon Sep 17 00:00:00 2001 From: Colin Date: Wed, 10 Apr 2024 08:06:21 -0700 Subject: [PATCH 10/74] adding single phase method --- py4DSTEM/process/diffraction/crystal_phase.py | 341 ++++++++++-------- 1 file changed, 200 insertions(+), 141 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_phase.py b/py4DSTEM/process/diffraction/crystal_phase.py index 508a01c79..bc5f10721 100644 --- a/py4DSTEM/process/diffraction/crystal_phase.py +++ b/py4DSTEM/process/diffraction/crystal_phase.py @@ -78,20 +78,22 @@ def quantify_single_pattern( pointlistarray: PointListArray, xy_position = (0,0), corr_kernel_size = 0.04, + sigma_excitation_error = 0.02, + power_experiment = 0.25, + power_calculated = 0.25, + max_number_patterns = 3, + single_phase = False, allow_strain = True, - # corr_distance_scale = 1.0, + strain_iterations = 5, + strain_max = 0.02, include_false_positives = True, weight_false_positives = 1.0, - max_number_phases = 1, - sigma_excitation_error = 0.04, - power_experiment = 0.25, - power_calculated = 0.25, plot_result = True, plot_only_nonzero_phases = True, plot_unmatched_peaks = False, plot_correlation_radius = False, scale_markers_experiment = 40, - scale_markers_calculated = 500, + scale_markers_calculated = 200, crystal_inds_plot = None, phase_colors = None, figsize = (10,7), @@ -161,6 +163,8 @@ def quantify_single_pattern( basis = np.zeros((intensity.shape[0], self.num_fits)) if allow_strain: m_strains = np.zeros((self.num_fits,2,2)) + m_strains[:,0,0] = 1.0 + m_strains[:,1,1] = 1.0 # kernel radius squared radius_max_2 = corr_kernel_size**2 @@ -202,56 +206,71 @@ def quantify_single_pattern( matches = np.zeros((bragg_peaks_fit.data.shape[0]),dtype='bool') if allow_strain: - # Initial peak pairing to find best-fit strain distortion - pair_sub = np.zeros(bragg_peaks_fit.data.shape[0],dtype='bool') - pair_inds = np.zeros(bragg_peaks_fit.data.shape[0],dtype='int') - for a1 in range(bragg_peaks_fit.data.shape[0]): - dist2 = (bragg_peaks_fit.data['qx'][a1] - qx)**2 \ - + (bragg_peaks_fit.data['qy'][a1] - qy)**2 - ind_min = np.argmin(dist2) - val_min = dist2[ind_min] - - if val_min < radius_max_2: - pair_sub[a1] = True - pair_inds[a1] = ind_min - - # calculate best-fit strain tensor, weighted by the intensities. - # requires at least 4 peak pairs - if np.sum(pair_sub) >= 4: - pair_basis = np.vstack(( - qx[pair_inds[pair_sub]], - qy[pair_inds[pair_sub]], - )).T - pair_obs = np.vstack(( - bragg_peaks_fit.data['qx'][pair_sub], - bragg_peaks_fit.data['qy'][pair_sub], - )).T - - # weights - dists = np.sqrt( - (bragg_peaks_fit.data['qx'][pair_sub] - qx[pair_inds[pair_sub]])**2 + \ - (bragg_peaks_fit.data['qx'][pair_sub] - qx[pair_inds[pair_sub]])**2) - weights = np.sqrt( - int_fit[pair_sub] * intensity[pair_inds[pair_sub]] - ) * (1 - dists / corr_kernel_size) - - # wtrain tensor - m_strain = np.linalg.lstsq( - pair_basis * weights[:,None], - pair_obs * weights[:,None], - rcond = None, - )[0] - m_strains[a0] = m_strain - - # Transformed peak positions - qx_copy = qx.copy() - qy_copy = qy.copy() - qx = qx_copy*m_strain[0,0] + qy_copy*m_strain[1,0] - qy = qx_copy*m_strain[0,1] + qy_copy*m_strain[1,1] - - # dist = np.mean(np.sqrt((bragg_peaks_fit.data['qx'][pair_sub] - qx[pair_inds[pair_sub]])**2 + \ - # (bragg_peaks_fit.data['qx'][pair_sub] - qx[pair_inds[pair_sub]])**2)) - # print(np.round(dist,4)) + for a1 in range(strain_iterations): + # Initial peak pairing to find best-fit strain distortion + pair_sub = np.zeros(bragg_peaks_fit.data.shape[0],dtype='bool') + pair_inds = np.zeros(bragg_peaks_fit.data.shape[0],dtype='int') + for a1 in range(bragg_peaks_fit.data.shape[0]): + dist2 = (bragg_peaks_fit.data['qx'][a1] - qx)**2 \ + + (bragg_peaks_fit.data['qy'][a1] - qy)**2 + ind_min = np.argmin(dist2) + val_min = dist2[ind_min] + + if val_min < radius_max_2: + pair_sub[a1] = True + pair_inds[a1] = ind_min + + # calculate best-fit strain tensor, weighted by the intensities. + # requires at least 4 peak pairs + if np.sum(pair_sub) >= 4: + # pair_obs = bragg_peaks_fit.data[['qx','qy']][pair_sub] + pair_basis = np.vstack(( + bragg_peaks_fit.data['qx'][pair_sub], + bragg_peaks_fit.data['qy'][pair_sub], + )).T + pair_obs = np.vstack(( + qx[pair_inds[pair_sub]], + qy[pair_inds[pair_sub]], + )).T + + # weights + dists = np.sqrt( + (bragg_peaks_fit.data['qx'][pair_sub] - qx[pair_inds[pair_sub]])**2 + \ + (bragg_peaks_fit.data['qx'][pair_sub] - qx[pair_inds[pair_sub]])**2) + weights = np.sqrt( + int_fit[pair_sub] * intensity[pair_inds[pair_sub]] + ) * (1 - dists / corr_kernel_size) + # weights = 1 - dists / corr_kernel_size + + # strain tensor + m_strain = np.linalg.lstsq( + pair_basis * weights[:,None], + pair_obs * weights[:,None], + rcond = None, + )[0] + + # Clamp strains to be within the user-specified limit + m_strain = np.clip( + m_strain, + np.eye(2) - strain_max, + np.eye(2) + strain_max, + ) + m_strains[a0] *= m_strain + + # Transformed peak positions + qx_copy = bragg_peaks_fit.data['qx'] + qy_copy = bragg_peaks_fit.data['qy'] + bragg_peaks_fit.data['qx'] = qx_copy*m_strain[0,0] + qy_copy*m_strain[1,0] + bragg_peaks_fit.data['qy'] = qx_copy*m_strain[0,1] + qy_copy*m_strain[1,1] + + # qx_copy = qx.copy() + # qy_copy = qy.copy() + # qx = qx_copy*m_strain[0,0] + qy_copy*m_strain[1,0] + # qy = qx_copy*m_strain[0,1] + qy_copy*m_strain[1,1] + + # dist = np.mean(np.sqrt((bragg_peaks_fit.data['qx'][pair_sub] - qx[pair_inds[pair_sub]])**2 + \ + # (bragg_peaks_fit.data['qx'][pair_sub] - qx[pair_inds[pair_sub]])**2)) + # print(np.round(dist,4)) # Loop over all peaks, pair experiment to library for a1 in range(bragg_peaks_fit.data.shape[0]): @@ -289,25 +308,53 @@ def quantify_single_pattern( else: obs = intensity - # Solve for phase coefficients + # Solve for phase weight coefficients try: phase_weights = np.zeros(self.num_fits) - inds_solve = np.ones(self.num_fits,dtype='bool') - search = True - while search is True: - phase_weights_cand, phase_residual_cand = nnls( - basis[:,inds_solve], - obs, + if single_phase: + # loop through each crystal structure and determine the best fit structure, + # which can contain multiple orienations. + crystal_res = np.zeros(self.num_crystals) + + for a0 in range(self.num_crystals): + sub = self.crystal_identity[:,0] == a0 + + phase_weights_cand, phase_residual_cand = nnls( + basis[:,sub], + obs, + ) + phase_weights[sub] = phase_weights_cand + crystal_res[a0] = phase_residual_cand + + ind_best_fit = np.argmin(crystal_res) + phase_residual = crystal_res[ind_best_fit] + sub = np.logical_not( + self.crystal_identity[:,0] == ind_best_fit ) + phase_weights[sub] = 0.0 - if np.count_nonzero(phase_weights_cand > 0.0) <= max_number_phases: - phase_weights[inds_solve] = phase_weights_cand - phase_residual = phase_residual_cand - search = False - else: - inds = np.where(inds_solve)[0] - inds_solve[inds[np.argmin(phase_weights_cand)]] = False + # Estimate reliability as difference between best fit and 2nd best fit + crystal_res = np.sort(crystal_res) + phase_reliability = crystal_res[1] - crystal_res[0] + + else: + # Allow all crystals and orientation matches in the pattern + inds_solve = np.ones(self.num_fits,dtype='bool') + search = True + while search is True: + phase_weights_cand, phase_residual_cand = nnls( + basis[:,inds_solve], + obs, + ) + + if np.count_nonzero(phase_weights_cand > 0.0) <= max_number_patterns: + phase_weights[inds_solve] = phase_weights_cand + phase_residual = phase_residual_cand + search = False + else: + inds = np.where(inds_solve)[0] + inds_solve[inds[np.argmin(phase_weights_cand)]] = False except: phase_weights = np.zeros(self.num_fits) @@ -322,7 +369,7 @@ def quantify_single_pattern( c = self.crystal_identity[a0,0] m = self.crystal_identity[a0,1] line = '{:>12} {:>8} {:<12}'.format( - np.round(phase_weights[a0],decimals=2), + f'{phase_weights[a0]:.2f}', m, self.crystal_names[c] ) @@ -330,7 +377,17 @@ def quantify_single_pattern( print('\033[1m' + line + '\033[0m') else: print(line) - # print() + print('----------------------------') + line = '{:>12} {:>15}'.format( + f'{sum(phase_weights):.2f}', + 'fit total' + ) + print('\033[1m' + line + '\033[0m') + line = '{:>12} {:>15}'.format( + f'{phase_residual:.2f}', + 'fit residual' + ) + print(line) # Plotting if plot_result: @@ -380,7 +437,7 @@ def quantify_single_pattern( "family": "sans-serif", "fontweight": "normal", "color": "k", - "size": 14, + "size": 12, } if plot_correlation_radius: ax_leg.plot( @@ -422,6 +479,7 @@ def quantify_single_pattern( # )) mvals = ['v','^','<','>','d','s',] + count_leg = 0 for a0 in range(self.num_fits): c = self.crystal_identity[a0,0] m = self.crystal_identity[a0,1] @@ -431,83 +489,84 @@ def quantify_single_pattern( qx_fit = library_peaks[a0].data['qx'] qy_fit = library_peaks[a0].data['qy'] - # if allow_strain: - # m_strain = m_strains[a0] - # # Transformed peak positions - # qx_copy = qx_fit.copy() - # qy_copy = qy_fit.copy() - # qx_fit = qx_copy*m_strain[0,0] + qy_copy*m_strain[1,0] - # qy_fit = qx_copy*m_strain[0,1] + qy_copy*m_strain[1,1] + if allow_strain: + m_strain = m_strains[a0] + # Transformed peak positions + qx_copy = qx_fit.copy() + qy_copy = qy_fit.copy() + qx_fit = qx_copy*m_strain[0,0] + qy_copy*m_strain[1,0] + qy_fit = qx_copy*m_strain[0,1] + qy_copy*m_strain[1,1] int_fit = library_int[a0] matches_fit = library_matches[a0] if plot_only_nonzero_phases is False or phase_weights[a0] > 0: - if np.mod(m,2) == 0: + # if np.mod(m,2) == 0: + ax.scatter( + qy_fit[matches_fit], + qx_fit[matches_fit], + s = scale_markers_calculated * int_fit[matches_fit], + marker = mvals[c], + facecolor = phase_colors[c,:], + ) + if plot_unmatched_peaks: ax.scatter( - qy_fit[matches_fit], - qx_fit[matches_fit], - s = scale_markers_calculated * int_fit[matches_fit], + qy_fit[np.logical_not(matches_fit)], + qx_fit[np.logical_not(matches_fit)], + s = scale_markers_calculated * int_fit[np.logical_not(matches_fit)], marker = mvals[c], facecolor = phase_colors[c,:], ) - if plot_unmatched_peaks: - ax.scatter( - qy_fit[np.logical_not(matches_fit)], - qx_fit[np.logical_not(matches_fit)], - s = scale_markers_calculated * int_fit[np.logical_not(matches_fit)], - marker = mvals[c], - facecolor = phase_colors[c,:], - ) - - # legend - ax_leg.scatter( - 0, - dx_leg*(a0+1), - s = 200, - marker = mvals[c], - facecolor = phase_colors[c,:], - ) - else: - ax.scatter( - qy_fit[matches_fit], - qx_fit[matches_fit], - s = scale_markers_calculated * int_fit[matches_fit], - marker = mvals[c], - edgecolors = uvals[c,:], - facecolors = (uvals[c,0],uvals[c,1],uvals[c,2],0.3), - # facecolors = (1,1,1,0.5), - linewidth = 2, - ) - if plot_unmatched_peaks: - ax.scatter( - qy_fit[np.logical_not(matches_fit)], - qx_fit[np.logical_not(matches_fit)], - s = scale_markers_calculated * int_fit[np.logical_not(matches_fit)], - marker = mvals[c], - edgecolors = uvals[c,:], - facecolors = (1,1,1,0.5), - linewidth = 2, - ) - - # legend - ax_leg.scatter( - 0, - dx_leg*(a0+1), - s = 200, - marker = mvals[c], - edgecolors = uvals[c,:], - facecolors = (uvals[c,0],uvals[c,1],uvals[c,2],0.3), - # facecolors = (1,1,1,0.5), - ) - # legend text - ax_leg.text( - dy_leg, - (a0+1)*dx_leg, - self.crystal_names[c], - **text_params) + # legend + if m == 0: + ax_leg.text( + dy_leg, + (count_leg+1)*dx_leg, + self.crystal_names[c], + **text_params) + ax_leg.scatter( + 0, + (count_leg+1) * dx_leg, + s = 200, + marker = mvals[c], + facecolor = phase_colors[c,:], + ) + count_leg += 1 + # else: + # ax.scatter( + # qy_fit[matches_fit], + # qx_fit[matches_fit], + # s = scale_markers_calculated * int_fit[matches_fit], + # marker = mvals[c], + # edgecolors = uvals[c,:], + # facecolors = (uvals[c,0],uvals[c,1],uvals[c,2],0.3), + # # facecolors = (1,1,1,0.5), + # linewidth = 2, + # ) + # if plot_unmatched_peaks: + # ax.scatter( + # qy_fit[np.logical_not(matches_fit)], + # qx_fit[np.logical_not(matches_fit)], + # s = scale_markers_calculated * int_fit[np.logical_not(matches_fit)], + # marker = mvals[c], + # edgecolors = uvals[c,:], + # facecolors = (1,1,1,0.5), + # linewidth = 2, + # ) + + # # legend + # ax_leg.scatter( + # 0, + # dx_leg*(a0+1), + # s = 200, + # marker = mvals[c], + # edgecolors = uvals[c,:], + # facecolors = (uvals[c,0],uvals[c,1],uvals[c,2],0.3), + # # facecolors = (1,1,1,0.5), + # ) + # appearance @@ -531,7 +590,7 @@ def quantify_phase( allow_strain = True, include_false_positives = True, weight_false_positives = 1.0, - max_number_phases = 2, + max_number_patterns = 2, sigma_excitation_error = 0.02, power_experiment = 0.25, power_calculated = 0.25, @@ -571,7 +630,7 @@ def quantify_phase( # corr_distance_scale = corr_distance_scale, include_false_positives = include_false_positives, weight_false_positives = weight_false_positives, - max_number_phases = max_number_phases, + max_number_patterns = max_number_patterns, sigma_excitation_error = sigma_excitation_error, power_experiment = power_experiment, power_calculated = power_calculated, From 909b84f4140ca6c0f543df9e93fc5be2e0528dfc Mon Sep 17 00:00:00 2001 From: Colin Date: Wed, 10 Apr 2024 08:32:47 -0700 Subject: [PATCH 11/74] updating plotting function --- py4DSTEM/process/diffraction/crystal_phase.py | 130 +++++++++--------- 1 file changed, 68 insertions(+), 62 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_phase.py b/py4DSTEM/process/diffraction/crystal_phase.py index bc5f10721..b997fb9d1 100644 --- a/py4DSTEM/process/diffraction/crystal_phase.py +++ b/py4DSTEM/process/diffraction/crystal_phase.py @@ -263,15 +263,6 @@ def quantify_single_pattern( bragg_peaks_fit.data['qx'] = qx_copy*m_strain[0,0] + qy_copy*m_strain[1,0] bragg_peaks_fit.data['qy'] = qx_copy*m_strain[0,1] + qy_copy*m_strain[1,1] - # qx_copy = qx.copy() - # qy_copy = qy.copy() - # qx = qx_copy*m_strain[0,0] + qy_copy*m_strain[1,0] - # qy = qx_copy*m_strain[0,1] + qy_copy*m_strain[1,1] - - # dist = np.mean(np.sqrt((bragg_peaks_fit.data['qx'][pair_sub] - qx[pair_inds[pair_sub]])**2 + \ - # (bragg_peaks_fit.data['qx'][pair_sub] - qx[pair_inds[pair_sub]])**2)) - # print(np.round(dist,4)) - # Loop over all peaks, pair experiment to library for a1 in range(bragg_peaks_fit.data.shape[0]): dist2 = (bragg_peaks_fit.data['qx'][a1] - qx)**2 \ @@ -356,9 +347,33 @@ def quantify_single_pattern( inds = np.where(inds_solve)[0] inds_solve[inds[np.argmin(phase_weights_cand)]] = False + # Estimate the phase reliability + inds_solve = np.ones(self.num_fits,dtype='bool') + inds_solve[phase_weights > 1e-8] = False + search = True + while search is True: + phase_weights_cand, phase_residual_cand = nnls( + basis[:,inds_solve], + obs, + ) + if np.count_nonzero(phase_weights_cand > 0.0) <= max_number_patterns: + phase_residual_2nd = phase_residual_cand + search = False + else: + inds = np.where(inds_solve)[0] + inds_solve[inds[np.argmin(phase_weights_cand)]] = False + + # print(inds_solve) + # phase_weights_cand, phase_residual_cand = nnls( + # basis[:,inds_solve], + # obs, + # ) + phase_reliability = phase_residual_2nd - phase_residual + except: phase_weights = np.zeros(self.num_fits) phase_residual = np.sqrt(np.sum(intensity**2)) + phase_reliability = 0.0 if verbose: ind_max = np.argmax(phase_weights) @@ -578,22 +593,24 @@ def quantify_single_pattern( ax_leg.set_axis_off() if returnfig: - return phase_weights, phase_residual, int_total, fig, ax + return phase_weights, phase_residual, phase_reliability, int_total, fig, ax else: - return phase_weights, phase_residual, int_total + return phase_weights, phase_residual, phase_reliability, int_total def quantify_phase( self, pointlistarray: PointListArray, corr_kernel_size = 0.04, - # corr_distance_scale = 1.0, - allow_strain = True, - include_false_positives = True, - weight_false_positives = 1.0, - max_number_patterns = 2, sigma_excitation_error = 0.02, power_experiment = 0.25, power_calculated = 0.25, + max_number_patterns = 3, + single_phase = False, + allow_strain = True, + strain_iterations = 5, + strain_max = 0.02, + include_false_positives = True, + weight_false_positives = 1.0, progress_bar = True, ): """ @@ -610,36 +627,44 @@ def quantify_phase( pointlistarray.shape[0], pointlistarray.shape[1], )) + self.phase_reliability = np.zeros(( + pointlistarray.shape[0], + pointlistarray.shape[1], + )) self.int_total = np.zeros(( pointlistarray.shape[0], pointlistarray.shape[1], )) + self.single_phase = single_phase for rx, ry in tqdmnd( *pointlistarray.shape, - desc="Matching Orientations", + desc="Quantifying Phase", unit=" PointList", disable=not progress_bar, ): # calculate phase weights - phase_weights, phase_residual, int_peaks = self.quantify_single_pattern( + phase_weights, phase_residual, phase_reliability, int_peaks = self.quantify_single_pattern( pointlistarray = pointlistarray, xy_position = (rx,ry), corr_kernel_size = corr_kernel_size, - allow_strain = allow_strain, - # corr_distance_scale = corr_distance_scale, - include_false_positives = include_false_positives, - weight_false_positives = weight_false_positives, - max_number_patterns = max_number_patterns, sigma_excitation_error = sigma_excitation_error, power_experiment = power_experiment, power_calculated = power_calculated, + max_number_patterns = max_number_patterns, + single_phase = single_phase, + allow_strain = allow_strain, + strain_iterations = strain_iterations, + strain_max = strain_max, + include_false_positives = include_false_positives, + weight_false_positives = weight_false_positives, plot_result = False, verbose = False, returnfig = False, ) self.phase_weights[rx,ry] = phase_weights self.phase_residuals[rx,ry] = phase_residual + self.phase_reliability[rx,ry] = phase_reliability self.int_total[rx,ry] = int_peaks @@ -877,7 +902,7 @@ def plot_phase_maps( def plot_dominant_phase( self, - rel_range = (0.0,1.0), + reliability_range = (0.0,1.0), sigma = 0.0, phase_colors = None, figsize = (6,6), @@ -918,52 +943,33 @@ def plot_dominant_phase( mode = 'nearest', ) - # find highest correlation score for each crystal and match index for a0 in range(self.num_crystals): sub = phase_sig[a0] > phase_corr phase_map[sub] = a0 phase_corr[sub] = phase_sig[a0][sub] - # find the second correlation score for each crystal and match index - for a0 in range(self.num_crystals): - corr = phase_sig[a0].copy() - corr[phase_map==a0] = 0.0 - sub = corr > phase_corr_2nd - phase_corr_2nd[sub] = corr[sub] - - - - + if self.single_phase: + phase_scale = np.clip( + (self.phase_reliability - reliability_range[0]) / (reliability_range[1] - reliability_range[0]), + 0, + 1) - # for a1 in range(crystals[a0].orientation_map.corr.shape[2]): - # sub = crystals[a0].orientation_map.corr[:,:,a1] > phase_corr - # phase_map[sub] = a0 - # phase_corr[sub] = crystals[a0].orientation_map.corr[:,:,a1][sub] + else: - # # find the second correlation score for each crystal and match index - # for a0 in range(len(crystals)): - # for a1 in range(crystals[a0].orientation_map.corr.shape[2]): - # corr = crystals[a0].orientation_map.corr[:,:,a1].copy() - # corr[phase_map==a0] = 0.0 - # sub = corr > phase_corr_2nd - # phase_corr_2nd[sub] = corr[sub] - - # Estimate the reliability - phase_rel = phase_corr - phase_corr_2nd - - # # Generate a color plotting image - # c_vals = np.array([ - # [0.7,0.5,0.3], - # [1,0,0], - # [0,0.7,0], - # [0,0.7,1], - # [0,0.7,1], - # ]) - phase_scale = np.clip( - (phase_rel - rel_range[0]) / (rel_range[1] - rel_range[0]), - 0, - 1) + # find the second correlation score for each crystal and match index + for a0 in range(self.num_crystals): + corr = phase_sig[a0].copy() + corr[phase_map==a0] = 0.0 + sub = corr > phase_corr_2nd + phase_corr_2nd[sub] = corr[sub] + + # Estimate the reliability + phase_rel = phase_corr - phase_corr_2nd + phase_scale = np.clip( + (phase_rel - reliability_range[0]) / (reliability_range[1] - reliability_range[0]), + 0, + 1) phase_rgb = np.zeros((scan_shape[0],scan_shape[1],3)) From 932fe5689b24ccac60bfbd48fc77a1fcb2c9377d Mon Sep 17 00:00:00 2001 From: Colin Date: Wed, 10 Apr 2024 13:48:34 -0700 Subject: [PATCH 12/74] Adding new weights for cost function --- py4DSTEM/process/diffraction/crystal_phase.py | 73 ++++++++++++------- 1 file changed, 47 insertions(+), 26 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_phase.py b/py4DSTEM/process/diffraction/crystal_phase.py index b997fb9d1..45219f10f 100644 --- a/py4DSTEM/process/diffraction/crystal_phase.py +++ b/py4DSTEM/process/diffraction/crystal_phase.py @@ -84,7 +84,7 @@ def quantify_single_pattern( max_number_patterns = 3, single_phase = False, allow_strain = True, - strain_iterations = 5, + strain_iterations = 3, strain_max = 0.02, include_false_positives = True, weight_false_positives = 1.0, @@ -270,17 +270,33 @@ def quantify_single_pattern( ind_min = np.argmin(dist2) val_min = dist2[ind_min] - if val_min < radius_max_2: - # weight = 1 - np.sqrt(dist2[ind_min]) / corr_kernel_size - # weight = 1 + corr_distance_scale * \ - # np.sqrt(dist2[ind_min]) / corr_kernel_size - # basis[ind_min,a0] = weight * int_fit[a1] - basis[ind_min,a0] = int_fit[a1] - if plot_result: + if include_false_positives: + weight = np.clip(1 - np.sqrt(dist2[ind_min]) / corr_kernel_size,0,1) + basis[ind_min,a0] = int_fit[a1] * weight + unpaired_peaks.append([ + a0, + int_fit[a1] * (1 - weight), + ]) + if weight > 1e-8 and plot_result: matches[a1] = True - elif include_false_positives: - # unpaired_peaks.append([a0,int_fit[a1]*(1 + corr_distance_scale)]) - unpaired_peaks.append([a0,int_fit[a1]]) + else: + if val_min < radius_max_2: + basis[ind_min,a0] = int_fit[a1] + if plot_result: + matches[a1] = True + + + # if val_min < radius_max_2: + # # weight = 1 - np.sqrt(dist2[ind_min]) / corr_kernel_size + # # weight = 1 + corr_distance_scale * \ + # # np.sqrt(dist2[ind_min]) / corr_kernel_size + # # basis[ind_min,a0] = weight * int_fit[a1] + # basis[ind_min,a0] = int_fit[a1] + # if plot_result: + # matches[a1] = True + # elif include_false_positives: + # # unpaired_peaks.append([a0,int_fit[a1]*(1 + corr_distance_scale)]) + # unpaired_peaks.append([a0,int_fit[a1]]) if plot_result: library_peaks.append(bragg_peaks_fit) @@ -347,34 +363,39 @@ def quantify_single_pattern( inds = np.where(inds_solve)[0] inds_solve[inds[np.argmin(phase_weights_cand)]] = False + # Estimate the phase reliability inds_solve = np.ones(self.num_fits,dtype='bool') inds_solve[phase_weights > 1e-8] = False - search = True - while search is True: + + if np.all(inds_solve == False): + phase_reliability = 0.0 + else: + search = True + while search is True: + phase_weights_cand, phase_residual_cand = nnls( + basis[:,inds_solve], + obs, + ) + if np.count_nonzero(phase_weights_cand > 0.0) <= max_number_patterns: + phase_residual_2nd = phase_residual_cand + search = False + else: + inds = np.where(inds_solve)[0] + inds_solve[inds[np.argmin(phase_weights_cand)]] = False + phase_weights_cand, phase_residual_cand = nnls( basis[:,inds_solve], obs, ) - if np.count_nonzero(phase_weights_cand > 0.0) <= max_number_patterns: - phase_residual_2nd = phase_residual_cand - search = False - else: - inds = np.where(inds_solve)[0] - inds_solve[inds[np.argmin(phase_weights_cand)]] = False - - # print(inds_solve) - # phase_weights_cand, phase_residual_cand = nnls( - # basis[:,inds_solve], - # obs, - # ) - phase_reliability = phase_residual_2nd - phase_residual + phase_reliability = phase_residual_2nd - phase_residual except: phase_weights = np.zeros(self.num_fits) phase_residual = np.sqrt(np.sum(intensity**2)) phase_reliability = 0.0 + if verbose: ind_max = np.argmax(phase_weights) # print() From e306c5dcad9632bbb98b2552e4f5a803a641b9d2 Mon Sep 17 00:00:00 2001 From: Colin Date: Thu, 11 Apr 2024 09:48:34 -0700 Subject: [PATCH 13/74] change to strain iteration default --- py4DSTEM/process/diffraction/crystal_phase.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/process/diffraction/crystal_phase.py b/py4DSTEM/process/diffraction/crystal_phase.py index 45219f10f..f6d916a95 100644 --- a/py4DSTEM/process/diffraction/crystal_phase.py +++ b/py4DSTEM/process/diffraction/crystal_phase.py @@ -628,7 +628,7 @@ def quantify_phase( max_number_patterns = 3, single_phase = False, allow_strain = True, - strain_iterations = 5, + strain_iterations = 3, strain_max = 0.02, include_false_positives = True, weight_false_positives = 1.0, From 217d71f60761bac311a28e2ce9bf8e08046b92c0 Mon Sep 17 00:00:00 2001 From: Colin Date: Sun, 28 Apr 2024 14:43:04 -0700 Subject: [PATCH 14/74] precession electron diffraction added --- py4DSTEM/process/diffraction/crystal.py | 153 ++++++++++++++++++------ 1 file changed, 115 insertions(+), 38 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index e0fe59eee..6667e6cf5 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -712,39 +712,65 @@ def generate_diffraction_pattern( foil_normal_cartesian: Optional[Union[list, tuple, np.ndarray]] = None, sigma_excitation_error: float = 0.02, tol_excitation_error_mult: float = 3, - tol_intensity: float = 1e-4, + tol_intensity: float = 1e-5, k_max: Optional[float] = None, + precession_angle_degrees = None, keep_qz=False, return_orientation_matrix=False, ): """ Generate a single diffraction pattern, return all peaks as a pointlist. - Args: - orientation (Orientation): an Orientation class object - ind_orientation If input is an Orientation class object with multiple orientations, - this input can be used to select a specific orientation. - - orientation_matrix (array): (3,3) orientation matrix, where columns represent projection directions. - zone_axis_lattice (array): (3,) projection direction in lattice indices - proj_x_lattice (array): (3,) x-axis direction in lattice indices - zone_axis_cartesian (array): (3,) cartesian projection direction - proj_x_cartesian (array): (3,) cartesian projection direction - - foil_normal: 3 element foil normal - set to None to use zone_axis - proj_x_axis (np float vector): 3 element vector defining image x axis (vertical) - accel_voltage (float): Accelerating voltage in Volts. If not specified, - we check to see if crystal already has voltage specified. - sigma_excitation_error (float): sigma value for envelope applied to s_g (excitation errors) in units of inverse Angstroms - tol_excitation_error_mult (float): tolerance in units of sigma for s_g inclusion - tol_intensity (np float): tolerance in intensity units for inclusion of diffraction spots - k_max (float): Maximum scattering vector - keep_qz (bool): Flag to return out-of-plane diffraction vectors - return_orientation_matrix (bool): Return the orientation matrix + Parameters + ---------- - Returns: - bragg_peaks (PointList): list of all Bragg peaks with fields [qx, qy, intensity, h, k, l] - orientation_matrix (array): 3x3 orientation matrix (optional) + orientation (Orientation) + an Orientation class object + ind_orientation + If input is an Orientation class object with multiple orientations, + this input can be used to select a specific orientation. + + orientation_matrix (3,3) numpy.array + orientation matrix, where columns represent projection directions. + zone_axis_lattice (3,) numpy.array + projection direction in lattice indices + proj_x_lattice (3,) numpy.array + x-axis direction in lattice indices + zone_axis_cartesian (3,) numpy.array + cartesian projection direction + proj_x_cartesian (3,) numpy.array + cartesian projection direction + + foil_normal + 3 element foil normal - set to None to use zone_axis + proj_x_axis (3,) numpy.array + 3 element vector defining image x axis (vertical) + accel_voltage (float) + Accelerating voltage in Volts. If not specified, + we check to see if crystal already has voltage specified. + + sigma_excitation_error (float) + sigma value for envelope applied to s_g (excitation errors) in units of inverse Angstroms + tol_excitation_error_mult (float) + tolerance in units of sigma for s_g inclusion + tol_intensity (numpy float) + tolerance in intensity units for inclusion of diffraction spots + k_max (float) + Maximum scattering vector + precession_angle_degrees (float) + Precession angle for library calculation. Set to None for no precession. + keep_qz (bool) + Flag to return out-of-plane diffraction vectors + return_orientation_matrix (bool) + Return the orientation matrix + + Returns + ---------- + bragg_peaks (PointList) + list of all Bragg peaks with fields [qx, qy, intensity, h, k, l] + orientation_matrix (array) + 3x3 orientation matrix (optional) + """ if not (hasattr(self, "wavelength") and hasattr(self, "accel_voltage")): @@ -779,17 +805,27 @@ def generate_diffraction_pattern( # Calculate excitation errors if foil_normal is None: - sg = self.excitation_errors(g) + sg = self.excitation_errors( + g, + precession_angle_degrees = precession_angle_degrees, + ) else: foil_normal = ( orientation_matrix.T @ (-1 * foil_normal[:, None] / np.linalg.norm(foil_normal)) ).ravel() - sg = self.excitation_errors(g, foil_normal) + sg = self.excitation_errors( + g, + foil_normal = foil_normal, + precession_angle_degrees = precession_angle_degrees, + ) # Threshold for inclusion in diffraction pattern sg_max = sigma_excitation_error * tol_excitation_error_mult - keep = np.abs(sg) <= sg_max + if precession_angle_degrees is None: + keep = np.abs(sg) <= sg_max + else: + keep = np.min(np.abs(sg),axis=1) <= sg_max # Maximum scattering angle cutoff if k_max is not None: @@ -799,9 +835,17 @@ def generate_diffraction_pattern( g_diff = g[:, keep] # Diffracted peak intensities and labels - g_int = self.struct_factors_int[keep] * np.exp( - (sg[keep] ** 2) / (-2 * sigma_excitation_error**2) - ) + if precession_angle_degrees is None: + g_int = self.struct_factors_int[keep] * np.exp( + (sg[keep] ** 2) / (-2 * sigma_excitation_error**2) + ) + else: + g_int = self.struct_factors_int[keep] * np.mean( + np.exp( + (sg[keep] ** 2) / (-2 * sigma_excitation_error**2) + ), + axis = 1, + ) hkl = self.hkl[:, keep] # Intensity tolerance @@ -1323,21 +1367,54 @@ def excitation_errors( self, g, foil_normal=None, + precession_angle_degrees=None, + precession_steps=180, ): """ Calculate the excitation errors, assuming k0 = [0, 0, -1/lambda]. If foil normal is not specified, we assume it is [0,0,-1]. + + Precession is currently implemented using numerical integration. """ - if foil_normal is None: - return (2 * g[2, :] - self.wavelength * np.sum(g * g, axis=0)) / ( - 2 - 2 * self.wavelength * g[2, :] - ) + + if precession_angle_degrees is None: + if foil_normal is None: + return (2 * g[2, :] - self.wavelength * np.sum(g * g, axis=0)) / ( + 2 - 2 * self.wavelength * g[2, :] + ) + else: + return (2 * g[2, :] - self.wavelength * np.sum(g * g, axis=0)) / ( + 2 * self.wavelength * np.sum(g * foil_normal[:, None], axis=0) + - 2 * foil_normal[2] + ) + else: - return (2 * g[2, :] - self.wavelength * np.sum(g * g, axis=0)) / ( - 2 * self.wavelength * np.sum(g * foil_normal[:, None], axis=0) - - 2 * foil_normal[2] + t = np.deg2rad(precession_angle_degrees) + p = np.linspace( + 0, + 2.0*np.pi, + precession_steps, + endpoint = False, + ) + if foil_normal is None: + foil_normal = np.array((0.0,0.0,-1.0)) + + k = np.reshape( + (-1/self.wavelength) * np.vstack(( + np.sin(t)*np.cos(p), + np.sin(t)*np.sin(p), + np.cos(t)*np.ones(p.size), + )), + (3,1,p.size), ) + term1 = np.sum( (g[:,:,None] + k) * foil_normal[:,None,None], axis=0) + term2 = np.sum( (g[:,:,None] + 2*k) * g[:,:,None], axis=0) + sg = np.sqrt(term1**2 - term2) - term1 + + return sg + + def calculate_bragg_peak_histogram( self, bragg_peaks, From 264deba5e5c9f52554962539c6d5108b207b807a Mon Sep 17 00:00:00 2001 From: Colin Date: Sun, 28 Apr 2024 14:55:52 -0700 Subject: [PATCH 15/74] New default values for diffraction sigma and intensity threshold --- py4DSTEM/process/diffraction/crystal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index 6667e6cf5..9694111c9 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -710,7 +710,7 @@ def generate_diffraction_pattern( zone_axis_cartesian: Optional[np.ndarray] = None, proj_x_cartesian: Optional[np.ndarray] = None, foil_normal_cartesian: Optional[Union[list, tuple, np.ndarray]] = None, - sigma_excitation_error: float = 0.02, + sigma_excitation_error: float = 0.01, tol_excitation_error_mult: float = 3, tol_intensity: float = 1e-5, k_max: Optional[float] = None, From ba5e0a215c3c5fb59c01729b4352f72fa71a24e1 Mon Sep 17 00:00:00 2001 From: Colin Date: Sun, 28 Apr 2024 15:07:15 -0700 Subject: [PATCH 16/74] Adding precession to orientation plans, update docstring --- py4DSTEM/process/diffraction/crystal_ACOM.py | 101 ++++++++++++------- 1 file changed, 67 insertions(+), 34 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index fc6e691c3..f81ad53cc 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -29,6 +29,7 @@ def orientation_plan( corr_kernel_size: float = 0.08, radial_power: float = 1.0, intensity_power: float = 0.25, # New default intensity power scaling + precession_angle_degrees = None, calculate_correlation_array=True, tol_peak_delete=None, tol_distance: float = 0.01, @@ -41,39 +42,60 @@ def orientation_plan( """ Calculate the rotation basis arrays for an SO(3) rotation correlogram. - Args: - zone_axis_range (float): Row vectors give the range for zone axis orientations. - If user specifies 2 vectors (2x3 array), we start at [0,0,1] - to make z-x-z rotation work. - If user specifies 3 vectors (3x3 array), plan will span these vectors. - Setting to 'full' as a string will use a hemispherical range. - Setting to 'half' as a string will use a quarter sphere range. - Setting to 'fiber' as a string will make a spherical cap around a given vector. - Setting to 'auto' will use pymatgen to determine the point group symmetry - of the structure and choose an appropriate zone_axis_range - angle_step_zone_axis (float): Approximate angular step size for zone axis search [degrees] - angle_coarse_zone_axis (float): Coarse step size for zone axis search [degrees]. Setting to - None uses the same value as angle_step_zone_axis. - angle_refine_range (float): Range of angles to use for zone axis refinement. Setting to - None uses same value as angle_coarse_zone_axis. - - angle_step_in_plane (float): Approximate angular step size for in-plane rotation [degrees] - accel_voltage (float): Accelerating voltage for electrons [Volts] - corr_kernel_size (float): Correlation kernel size length. The size of the overlap kernel between the measured Bragg peaks and diffraction library Bragg peaks. [1/Angstroms] - radial_power (float): Power for scaling the correlation intensity as a function of the peak radius - intensity_power (float): Power for scaling the correlation intensity as a function of the peak intensity - calculate_correlation_array (bool): Set to false to skip calculating the correlation array. - This is useful when we only want the angular range / rotation matrices. - tol_peak_delete (float): Distance to delete peaks for multiple matches. - Default is kernel_size * 0.5 - tol_distance (float): Distance tolerance for radial shell assignment [1/Angstroms] - fiber_axis (float): (3,) vector specifying the fiber axis - fiber_angles (float): (2,) vector specifying angle range from fiber axis, and in-plane angular range [degrees] - cartesian_directions (bool): When set to true, all zone axes and projection directions - are specified in Cartesian directions. - figsize (float): (2,) vector giving the figure size - CUDA (bool): Use CUDA for the Fourier operations. - progress_bar (bool): If false no progress bar is displayed + Parameters + ---------- + zone_axis_range (float): + Row vectors give the range for zone axis orientations. + If user specifies 2 vectors (2x3 array), we start at [0,0,1] + to make z-x-z rotation work. + If user specifies 3 vectors (3x3 array), plan will span these vectors. + Setting to 'full' as a string will use a hemispherical range. + Setting to 'half' as a string will use a quarter sphere range. + Setting to 'fiber' as a string will make a spherical cap around a given vector. + Setting to 'auto' will use pymatgen to determine the point group symmetry + of the structure and choose an appropriate zone_axis_range + angle_step_zone_axis (float): + Approximate angular step size for zone axis search [degrees] + angle_coarse_zone_axis (float): + Coarse step size for zone axis search [degrees]. Setting to + None uses the same value as angle_step_zone_axis. + angle_refine_range (float): + Range of angles to use for zone axis refinement. Setting to + None uses same value as angle_coarse_zone_axis. + + angle_step_in_plane (float): + Approximate angular step size for in-plane rotation [degrees] + accel_voltage (float): + Accelerating voltage for electrons [Volts] + corr_kernel_size (float): + Correlation kernel size length. The size of the overlap kernel between the measured Bragg peaks and diffraction library Bragg peaks. [1/Angstroms] + radial_power (float): + Power for scaling the correlation intensity as a function of the peak radius + intensity_power (float): + Power for scaling the correlation intensity as a function of the peak intensity + precession_angle_degrees (float) + Tilt angle of illuminaiton cone in degrees for precession electron diffraction (PED). + calculate_correlation_array (bool): + Set to false to skip calculating the correlation array. + This is useful when we only want the angular range / rotation matrices. + tol_peak_delete (float): + Distance to delete peaks for multiple matches. + Default is kernel_size * 0.5 + tol_distance (float): + Distance tolerance for radial shell assignment [1/Angstroms] + fiber_axis (float): + (3,) vector specifying the fiber axis + fiber_angles (float): + (2,) vector specifying angle range from fiber axis, and in-plane angular range [degrees] + cartesian_directions (bool): + When set to true, all zone axes and projection directions + are specified in Cartesian directions. + figsize (float): + (2,) vector giving the figure size + CUDA (bool): + Use CUDA for the Fourier operations. + progress_bar (bool): + If false no progress bar is displayed, """ # Check to make sure user has calculated the structure factors if needed @@ -717,7 +739,18 @@ def orientation_plan( ): # reciprocal lattice spots and excitation errors g = self.orientation_rotation_matrices[a0, :, :].T @ self.g_vec_all - sg = self.excitation_errors(g) + if precession_angle_degrees is None: + sg = self.excitation_errors(g) + else: + sg = np.min( + np.abs( + self.excitation_errors( + g, + precession_angle_degrees = precession_angle_degrees, + ), + ), + axis = 1, + ) # Keep only points that will contribute to this orientation plan slice keep = np.abs(sg) < self.orientation_kernel_size From 3ee1adf8d6fa0b039221e1b1516f9f31ca7525fb Mon Sep 17 00:00:00 2001 From: Colin Date: Mon, 6 May 2024 14:04:35 -0700 Subject: [PATCH 17/74] adding precession to orientation plans --- py4DSTEM/process/diffraction/crystal_ACOM.py | 50 +++++++++++++------- 1 file changed, 33 insertions(+), 17 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index f81ad53cc..f7406b590 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -27,9 +27,10 @@ def orientation_plan( angle_step_in_plane: float = 2.0, accel_voltage: float = 300e3, corr_kernel_size: float = 0.08, + sigma_excitation_error: float = 0.01, + precession_angle_degrees = None, radial_power: float = 1.0, intensity_power: float = 0.25, # New default intensity power scaling - precession_angle_degrees = None, calculate_correlation_array=True, tol_peak_delete=None, tol_distance: float = 0.01, @@ -68,13 +69,17 @@ def orientation_plan( accel_voltage (float): Accelerating voltage for electrons [Volts] corr_kernel_size (float): - Correlation kernel size length. The size of the overlap kernel between the measured Bragg peaks and diffraction library Bragg peaks. [1/Angstroms] + Correlation kernel size length. The size of the overlap kernel between the + measured Bragg peaks and diffraction library Bragg peaks. [1/Angstroms] + sigma_excitation_error (float): + The out of plane excitation error tolerance. [1/Angstroms] + precession_angle_degrees (float) + Tilt angle of illuminaiton cone in degrees for precession electron diffraction (PED). + radial_power (float): Power for scaling the correlation intensity as a function of the peak radius intensity_power (float): Power for scaling the correlation intensity as a function of the peak intensity - precession_angle_degrees (float) - Tilt angle of illuminaiton cone in degrees for precession electron diffraction (PED). calculate_correlation_array (bool): Set to false to skip calculating the correlation array. This is useful when we only want the angular range / rotation matrices. @@ -108,6 +113,7 @@ def orientation_plan( # Store inputs self.accel_voltage = np.asarray(accel_voltage) self.orientation_kernel_size = np.asarray(corr_kernel_size) + self.orientation_precession_angle_degrees = np.asarray(precession_angle_degrees) if tol_peak_delete is None: self.orientation_tol_peak_delete = self.orientation_kernel_size * 0.5 else: @@ -739,25 +745,35 @@ def orientation_plan( ): # reciprocal lattice spots and excitation errors g = self.orientation_rotation_matrices[a0, :, :].T @ self.g_vec_all - if precession_angle_degrees is None: - sg = self.excitation_errors(g) - else: - sg = np.min( - np.abs( - self.excitation_errors( - g, - precession_angle_degrees = precession_angle_degrees, - ), - ), - axis = 1, - ) + # if precession_angle_degrees is None: + sg = self.excitation_errors(g) + # else: + # sg = np.min( + # np.abs( + # self.excitation_errors( + # g, + # precession_angle_degrees = precession_angle_degrees, + # ), + # ), + # axis = 1, + # ) # Keep only points that will contribute to this orientation plan slice - keep = np.abs(sg) < self.orientation_kernel_size + keep = np.logical_and( + np.abs(sg) < self.orientation_kernel_size, + self.orientation_shell_index >= 0, + ) # in-plane rotation angle phi = np.arctan2(g[1, :], g[0, :]) + # calculate intensity of spots + # if precession_angle_degrees is None: + # Ig = np.exp(sg**2/(-2*sigma_excitation_error**2)) + # else: + # pass + + # Loop over all peaks for a1 in np.arange(self.g_vec_all.shape[1]): ind_radial = self.orientation_shell_index[a1] From 19058a5b4cd1664352cf72ad82f4d3d2f18fc6eb Mon Sep 17 00:00:00 2001 From: Colin Date: Mon, 6 May 2024 16:29:02 -0700 Subject: [PATCH 18/74] precession added to orientation plans! --- py4DSTEM/process/diffraction/crystal_ACOM.py | 104 ++++++++++++------- py4DSTEM/process/diffraction/crystal_viz.py | 1 + 2 files changed, 66 insertions(+), 39 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index f7406b590..960461669 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -113,7 +113,11 @@ def orientation_plan( # Store inputs self.accel_voltage = np.asarray(accel_voltage) self.orientation_kernel_size = np.asarray(corr_kernel_size) - self.orientation_precession_angle_degrees = np.asarray(precession_angle_degrees) + if precession_angle_degrees is None: + self.orientation_precession_angle_degrees = None + else: + self.orientation_precession_angle_degrees = np.asarray(precession_angle_degrees) + self.orientation_precession_angle = np.deg2rad(np.asarray(precession_angle_degrees)) if tol_peak_delete is None: self.orientation_tol_peak_delete = self.orientation_kernel_size * 0.5 else: @@ -764,45 +768,67 @@ def orientation_plan( self.orientation_shell_index >= 0, ) - # in-plane rotation angle - phi = np.arctan2(g[1, :], g[0, :]) - # calculate intensity of spots - # if precession_angle_degrees is None: - # Ig = np.exp(sg**2/(-2*sigma_excitation_error**2)) - # else: - # pass - - - # Loop over all peaks - for a1 in np.arange(self.g_vec_all.shape[1]): - ind_radial = self.orientation_shell_index[a1] - - if keep[a1] and ind_radial >= 0: - # 2D orientation plan - self.orientation_ref[a0, ind_radial, :] += ( - np.power(self.orientation_shell_radii[ind_radial], radial_power) - * np.power(self.struct_factors_int[a1], intensity_power) - * np.maximum( - 1 - - np.sqrt( - sg[a1] ** 2 - + ( - ( - np.mod( - self.orientation_gamma - phi[a1] + np.pi, - 2 * np.pi, - ) - - np.pi - ) - * self.orientation_shell_radii[ind_radial] - ) - ** 2 - ) - / self.orientation_kernel_size, - 0, - ) - ) + if precession_angle_degrees is None: + Ig = np.exp(sg[keep]**2/(-2*sigma_excitation_error**2)) + else: + # precession extension + prec = np.cos(np.linspace(0,2*np.pi,90,endpoint=False)) + dsg = np.tan(self.orientation_precession_angle) * np.sum(g[:2,keep]**2,axis=0) + Ig = np.mean(np.exp((sg[keep,None] + dsg[:,None]*prec[None,:])**2 \ + / (-2*sigma_excitation_error**2)), axis = 1) + + # in-plane rotation angle + phi = np.arctan2(g[1, keep], g[0, keep]) + phi_ind = phi / self.orientation_gamma[1] # step size of annular bins + phi_floor = np.floor(phi_ind).astype('int') + dphi = phi_ind - phi_floor + + # write intensities into orientation plan slice + radial_inds = self.orientation_shell_index[keep] + self.orientation_ref[a0, radial_inds, phi_floor] += \ + (1-dphi) * \ + np.power(self.struct_factors_int[keep] * Ig, intensity_power) * \ + np.power(self.orientation_shell_radii[radial_inds], radial_power) + self.orientation_ref[a0, radial_inds, np.mod(phi_floor+1,self.orientation_in_plane_steps)] += \ + dphi * \ + np.power(self.struct_factors_int[keep] * Ig, intensity_power) * \ + np.power(self.orientation_shell_radii[radial_inds], radial_power) + + + # # Loop over all peaks + # for a1 in np.arange(self.g_vec_all.shape[1]): + # if keep[a1]: + + + # for a1 in np.arange(self.g_vec_all.shape[1]): + # ind_radial = self.orientation_shell_index[a1] + + # if keep[a1] and ind_radial >= 0: + # # 2D orientation plan + # self.orientation_ref[a0, ind_radial, :] += ( + # np.power(self.orientation_shell_radii[ind_radial], radial_power) + # * np.power(self.struct_factors_int[a1], intensity_power) + # * np.maximum( + # 1 + # - np.sqrt( + # sg[a1] ** 2 + # + ( + # ( + # np.mod( + # self.orientation_gamma - phi[a1] + np.pi, + # 2 * np.pi, + # ) + # - np.pi + # ) + # * self.orientation_shell_radii[ind_radial] + # ) + # ** 2 + # ) + # / self.orientation_kernel_size, + # 0, + # ) + # ) orientation_ref_norm = np.sqrt(np.sum(self.orientation_ref[a0, :, :] ** 2)) if orientation_ref_norm > 0: diff --git a/py4DSTEM/process/diffraction/crystal_viz.py b/py4DSTEM/process/diffraction/crystal_viz.py index 47df2e6ca..fdb313cbb 100644 --- a/py4DSTEM/process/diffraction/crystal_viz.py +++ b/py4DSTEM/process/diffraction/crystal_viz.py @@ -848,6 +848,7 @@ def plot_orientation_plan( bragg_peaks = self.generate_diffraction_pattern( orientation_matrix=self.orientation_rotation_matrices[index_plot, :], sigma_excitation_error=self.orientation_kernel_size / 3, + precession_angle_degrees = self.orientation_precession_angle_degrees, ) plot_diffraction_pattern( From 3e656b6df981908e96904616b6c77416528dcb4c Mon Sep 17 00:00:00 2001 From: Colin Date: Mon, 6 May 2024 19:20:31 -0700 Subject: [PATCH 19/74] Fiddling with matching routines --- py4DSTEM/process/diffraction/crystal_ACOM.py | 38 ++++++++++++++++---- 1 file changed, 31 insertions(+), 7 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index 960461669..bdff28de2 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -27,7 +27,7 @@ def orientation_plan( angle_step_in_plane: float = 2.0, accel_voltage: float = 300e3, corr_kernel_size: float = 0.08, - sigma_excitation_error: float = 0.01, + sigma_excitation_error: float = 0.02, precession_angle_degrees = None, radial_power: float = 1.0, intensity_power: float = 0.25, # New default intensity power scaling @@ -1034,12 +1034,7 @@ def match_single_pattern( if np.any(sub): im_polar[ind_radial, :] = np.sum( - np.power(radius, self.orientation_radial_power) - * np.power( - np.maximum(intensity[sub, None], 0.0), - self.orientation_intensity_power, - ) - * np.maximum( + np.maximum( 1 - np.sqrt( dqr[sub, None] ** 2 @@ -1062,6 +1057,35 @@ def match_single_pattern( ), axis=0, ) + # im_polar[ind_radial, :] = np.sum( + # np.power(radius, self.orientation_radial_power) + # * np.power( + # np.maximum(intensity[sub, None], 0.0), + # self.orientation_intensity_power, + # ) + # * np.maximum( + # 1 + # - np.sqrt( + # dqr[sub, None] ** 2 + # + ( + # ( + # np.mod( + # self.orientation_gamma[None, :] + # - qphi[sub, None] + # + np.pi, + # 2 * np.pi, + # ) + # - np.pi + # ) + # * radius + # ) + # ** 2 + # ) + # / self.orientation_kernel_size, + # 0, + # ), + # axis=0, + # ) # Determine the RMS signal from im_polar for the first match. # Note that we use scaling slightly below RMS so that following matches From e768a517285007dc3b2195dc471c95ef6e8d5a5f Mon Sep 17 00:00:00 2001 From: Colin Date: Tue, 7 May 2024 05:18:36 -0700 Subject: [PATCH 20/74] adding intensity scaling back into experiment --- py4DSTEM/process/diffraction/crystal_ACOM.py | 40 ++++++++++++-------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index bdff28de2..b1f64aa58 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -29,8 +29,9 @@ def orientation_plan( corr_kernel_size: float = 0.08, sigma_excitation_error: float = 0.02, precession_angle_degrees = None, - radial_power: float = 1.0, - intensity_power: float = 0.25, # New default intensity power scaling + power_radial: float = 1.0, + power_intensity: float = 0.25, + power_intensity_experiment: float = 0.25, calculate_correlation_array=True, tol_peak_delete=None, tol_distance: float = 0.01, @@ -76,10 +77,12 @@ def orientation_plan( precession_angle_degrees (float) Tilt angle of illuminaiton cone in degrees for precession electron diffraction (PED). - radial_power (float): + power_radial (float): Power for scaling the correlation intensity as a function of the peak radius - intensity_power (float): - Power for scaling the correlation intensity as a function of the peak intensity + power_intensity (float): + Power for scaling the correlation intensity as a function of simulated peak intensity + power_intensity_experiment (float): + Power for scaling the correlation intensity as a function of experimental peak intensity calculate_correlation_array (bool): Set to false to skip calculating the correlation array. This is useful when we only want the angular range / rotation matrices. @@ -136,8 +139,9 @@ def orientation_plan( self.wavelength = electron_wavelength_angstrom(self.accel_voltage) # store the radial and intensity scaling to use later for generating test patterns - self.orientation_radial_power = radial_power - self.orientation_intensity_power = intensity_power + self.orientation_power_radial = power_radial + self.orientation_power_intensity = power_intensity + self.orientation_power_intensity_experiment = power_intensity_experiment # Calculate the ratio between coarse and fine refinement if angle_coarse_zone_axis is not None: @@ -788,12 +792,12 @@ def orientation_plan( radial_inds = self.orientation_shell_index[keep] self.orientation_ref[a0, radial_inds, phi_floor] += \ (1-dphi) * \ - np.power(self.struct_factors_int[keep] * Ig, intensity_power) * \ - np.power(self.orientation_shell_radii[radial_inds], radial_power) + np.power(self.struct_factors_int[keep] * Ig, power_intensity) * \ + np.power(self.orientation_shell_radii[radial_inds], power_radial) self.orientation_ref[a0, radial_inds, np.mod(phi_floor+1,self.orientation_in_plane_steps)] += \ dphi * \ - np.power(self.struct_factors_int[keep] * Ig, intensity_power) * \ - np.power(self.orientation_shell_radii[radial_inds], radial_power) + np.power(self.struct_factors_int[keep] * Ig, power_intensity) * \ + np.power(self.orientation_shell_radii[radial_inds], power_radial) # # Loop over all peaks @@ -807,8 +811,8 @@ def orientation_plan( # if keep[a1] and ind_radial >= 0: # # 2D orientation plan # self.orientation_ref[a0, ind_radial, :] += ( - # np.power(self.orientation_shell_radii[ind_radial], radial_power) - # * np.power(self.struct_factors_int[a1], intensity_power) + # np.power(self.orientation_shell_radii[ind_radial], power_radial) + # * np.power(self.struct_factors_int[a1], power_intensity) # * np.maximum( # 1 # - np.sqrt( @@ -1034,7 +1038,11 @@ def match_single_pattern( if np.any(sub): im_polar[ind_radial, :] = np.sum( - np.maximum( + np.power( + np.maximum(intensity[sub, None], 0.0), + self.orientation_power_intensity_experiment, + ) + * np.maximum( 1 - np.sqrt( dqr[sub, None] ** 2 @@ -1058,10 +1066,10 @@ def match_single_pattern( axis=0, ) # im_polar[ind_radial, :] = np.sum( - # np.power(radius, self.orientation_radial_power) + # np.power(radius, self.orientation_power_radial) # * np.power( # np.maximum(intensity[sub, None], 0.0), - # self.orientation_intensity_power, + # self.orientation_power_intensity, # ) # * np.maximum( # 1 From 8792a7249485b2e04e1b08f3ee309799fa79a21e Mon Sep 17 00:00:00 2001 From: Colin Date: Wed, 8 May 2024 11:35:25 -0700 Subject: [PATCH 21/74] adding normalization testing --- py4DSTEM/process/diffraction/crystal_ACOM.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index b1f64aa58..0c8776010 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -834,6 +834,8 @@ def orientation_plan( # ) # ) + # normalization + # self.orientation_ref[a0, :, :] -= np.mean(self.orientation_ref[a0, :, :]) orientation_ref_norm = np.sqrt(np.sum(self.orientation_ref[a0, :, :] ** 2)) if orientation_ref_norm > 0: self.orientation_ref[a0, :, :] /= orientation_ref_norm @@ -1095,6 +1097,10 @@ def match_single_pattern( # axis=0, # ) + + # # normalization + # self.im_polar -= np.mean(im_polar) + # Determine the RMS signal from im_polar for the first match. # Note that we use scaling slightly below RMS so that following matches # don't have higher correlating scores than previous matches. From 04bbbaf1faf2a2313b9db99f774a073896ee51c4 Mon Sep 17 00:00:00 2001 From: Colin Date: Thu, 9 May 2024 09:32:40 -0700 Subject: [PATCH 22/74] Making kernel a Gaussian function instead of cone --- py4DSTEM/process/diffraction/crystal_ACOM.py | 64 ++++++++++++++------ 1 file changed, 46 insertions(+), 18 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index 0c8776010..3acce396c 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -1039,34 +1039,62 @@ def match_single_pattern( sub = dqr < self.orientation_kernel_size if np.any(sub): - im_polar[ind_radial, :] = np.sum( + + dist = im_polar[ind_radial, :] = np.sum( np.power( np.maximum(intensity[sub, None], 0.0), self.orientation_power_intensity_experiment, ) - * np.maximum( - 1 - - np.sqrt( - dqr[sub, None] ** 2 - + ( - ( - np.mod( - self.orientation_gamma[None, :] - - qphi[sub, None] - + np.pi, - 2 * np.pi, - ) - - np.pi + * np.exp( + (dqr[sub, None] ** 2 + + ( + ( + np.mod( + self.orientation_gamma[None, :] + - qphi[sub, None] + + np.pi, + 2 * np.pi, ) - * radius + - np.pi ) - ** 2 + * radius ) - / self.orientation_kernel_size, - 0, + ** 2) + / (-2*self.orientation_kernel_size**2) ), axis=0, ) + + # im_polar[ind_radial, :] = np.sum( + # np.power( + # np.maximum(intensity[sub, None], 0.0), + # self.orientation_power_intensity_experiment, + # ) + # * np.maximum( + # 1 + # - np.sqrt( + # dqr[sub, None] ** 2 + # + ( + # ( + # np.mod( + # self.orientation_gamma[None, :] + # - qphi[sub, None] + # + np.pi, + # 2 * np.pi, + # ) + # - np.pi + # ) + # * radius + # ) + # ** 2 + # ) + # / self.orientation_kernel_size, + # 0, + # ), + # axis=0, + # ) + + # im_polar[ind_radial, :] = np.sum( # np.power(radius, self.orientation_power_radial) # * np.power( From e8a02ae20ae97515dd877dd5c918958eaed712e6 Mon Sep 17 00:00:00 2001 From: Colin Date: Thu, 9 May 2024 14:59:58 -0700 Subject: [PATCH 23/74] tweaking multiple match finder --- py4DSTEM/process/diffraction/crystal_ACOM.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index 3acce396c..be8ae1f50 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -116,6 +116,7 @@ def orientation_plan( # Store inputs self.accel_voltage = np.asarray(accel_voltage) self.orientation_kernel_size = np.asarray(corr_kernel_size) + self.orientation_sigma_excitation_error = sigma_excitation_error if precession_angle_degrees is None: self.orientation_precession_angle_degrees = None else: @@ -837,6 +838,7 @@ def orientation_plan( # normalization # self.orientation_ref[a0, :, :] -= np.mean(self.orientation_ref[a0, :, :]) orientation_ref_norm = np.sqrt(np.sum(self.orientation_ref[a0, :, :] ** 2)) + # orientation_ref_norm = np.sum(self.orientation_ref[a0, :, :]) if orientation_ref_norm > 0: self.orientation_ref[a0, :, :] /= orientation_ref_norm @@ -1327,6 +1329,7 @@ def match_single_pattern( np.clip( np.sum( self.orientation_vecs + # self.orientation_vecs * np.array([1,-1,-1])[None,:] * self.orientation_vecs[inds_previous[a0], :], axis=1, ), @@ -1625,7 +1628,8 @@ def match_single_pattern( bragg_peaks_fit = self.generate_diffraction_pattern( orientation, ind_orientation=match_ind, - sigma_excitation_error=self.orientation_kernel_size, + sigma_excitation_error = self.orientation_sigma_excitation_error, + precession_angle_degrees = self.orientation_precession_angle_degrees, ) remove = np.zeros_like(qx, dtype="bool") From daae2258aaa9ae43f056356318fcba57fb5f2d4e Mon Sep 17 00:00:00 2001 From: Colin Date: Fri, 17 May 2024 10:28:04 -0700 Subject: [PATCH 24/74] fix typo --- py4DSTEM/process/diffraction/crystal_ACOM.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index be8ae1f50..74e842817 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -1042,7 +1042,7 @@ def match_single_pattern( if np.any(sub): - dist = im_polar[ind_radial, :] = np.sum( + im_polar[ind_radial, :] = np.sum( np.power( np.maximum(intensity[sub, None], 0.0), self.orientation_power_intensity_experiment, @@ -1128,8 +1128,8 @@ def match_single_pattern( # ) - # # normalization - # self.im_polar -= np.mean(im_polar) + # normalization + # im_polar -= np.mean(im_polar) # Determine the RMS signal from im_polar for the first match. # Note that we use scaling slightly below RMS so that following matches From 4734fc9286a25090460ca6cebe639171f5c547e7 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Tue, 28 May 2024 10:08:42 -0700 Subject: [PATCH 25/74] updating README, now that tutorials have been updated --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 3fe6cc745..d62845d4c 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ > :warning: **py4DSTEM version 0.14 update** :warning: Warning: this is a major update and we expect some workflows to break. You can still install previous versions of py4DSTEM [as discussed here](#legacyinstall) -> :warning: **Phase retrieval refactor version 0.14.9** :warning: Warning: The phase-retrieval modules in py4DSTEM (DPC, parallax, and ptychography) underwent a major refactor in version 0.14.9 and as such older tutorial notebooks will not work as expected. Notably, class names have been pruned to remove the trailing "Reconstruction" (`DPCReconstruction` -> `DPC` etc.), and regularization functions have dropped the `_iter` suffix (and are instead specified as boolean flags). We are working on updating the tutorial notebooks to reflect these changes. In the meantime, there's some more information in the relevant pull request [here](https://github.com/py4dstem/py4DSTEM/pull/597#issuecomment-1890325568). +> :warning: **Phase retrieval refactor version 0.14.9** :warning: Warning: The phase-retrieval modules in py4DSTEM (DPC, parallax, and ptychography) underwent a major refactor in version 0.14.9 and as such older tutorial notebooks will not work as expected. Notably, class names have been pruned to remove the trailing "Reconstruction" (`DPCReconstruction` -> `DPC` etc.), and regularization functions have dropped the `_iter` suffix (and are instead specified as boolean flags). See the [updated tutorials](https://github.com/py4dstem/py4DSTEM_tutorials) for more information. ![py4DSTEM logo](/images/py4DSTEM_logo.png) From a87ab8591f6c5cebaa8a4e8e6217b2d5e058dc1b Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Tue, 28 May 2024 17:38:12 -0700 Subject: [PATCH 26/74] better position initialization , single-slice --- py4DSTEM/process/phase/phase_base_class.py | 38 ++++++++----------- .../process/phase/ptychographic_methods.py | 20 +++++++--- .../process/phase/singleslice_ptychography.py | 21 +++++++--- 3 files changed, 46 insertions(+), 33 deletions(-) diff --git a/py4DSTEM/process/phase/phase_base_class.py b/py4DSTEM/process/phase/phase_base_class.py index 22fda4ac9..742013f1e 100644 --- a/py4DSTEM/process/phase/phase_base_class.py +++ b/py4DSTEM/process/phase/phase_base_class.py @@ -1864,42 +1864,34 @@ def _calculate_scan_positions_in_pixels( else: raise ValueError() - if transpose: - x = (x - np.ptp(x) / 2) / sampling[1] - y = (y - np.ptp(y) / 2) / sampling[0] - else: - x = (x - np.ptp(x) / 2) / sampling[0] - y = (y - np.ptp(y) / 2) / sampling[1] x, y = np.meshgrid(x, y, indexing="ij") if positions_offset_ang is not None: - if transpose: - x += positions_offset_ang[0] / sampling[1] - y += positions_offset_ang[1] / sampling[0] - else: - x += positions_offset_ang[0] / sampling[0] - y += positions_offset_ang[1] / sampling[1] + x += positions_offset_ang[0] + y += positions_offset_ang[1] if positions_mask is not None: x = x[positions_mask] y = y[positions_mask] - else: - positions -= np.mean(positions, axis=0) - x = positions[:, 0] / sampling[1] - y = positions[:, 1] / sampling[0] + + positions = np.stack((x.ravel(), y.ravel()), axis=-1) if rotation_angle is not None: - x, y = x * np.cos(rotation_angle) + y * np.sin(rotation_angle), -x * np.sin( - rotation_angle - ) + y * np.cos(rotation_angle) + tf = AffineTransform(angle=rotation_angle) + positions = tf(positions, positions.mean(0)) if transpose: - positions = np.array([y.ravel(), x.ravel()]).T - else: - positions = np.array([x.ravel(), y.ravel()]).T + positions = np.flip(positions, 1) + sampling = sampling[::-1] + + # ensure positive + positions -= np.min(positions, axis=0).clip(-np.inf, 0) - positions -= np.min(positions, axis=0) + # finally, switch to pixels + positions[:, 0] /= sampling[0] + positions[:, 1] /= sampling[1] + # top-left padding if object_padding_px is None: float_padding = region_of_interest_shape / 2 object_padding_px = (float_padding, float_padding) diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py index 283ddb1ba..25a7f7785 100644 --- a/py4DSTEM/process/phase/ptychographic_methods.py +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -48,14 +48,24 @@ def _initialize_object( xp = self._xp object_padding_px = self._object_padding_px + object_fov_ang = self._object_fov_ang region_of_interest_shape = self._region_of_interest_shape if initial_object is None: - pad_x = object_padding_px[0][1] - pad_y = object_padding_px[1][1] - p, q = np.round(np.max(positions_px, axis=0)) - p = np.max([np.round(p + pad_x), region_of_interest_shape[0]]).astype("int") - q = np.max([np.round(q + pad_y), region_of_interest_shape[1]]).astype("int") + if object_fov_ang is None: + pad_x = object_padding_px[0][1] + pad_y = object_padding_px[1][1] + p, q = np.round(np.max(positions_px, axis=0)) + p = np.max([np.round(p + pad_x), region_of_interest_shape[0]]).astype( + "int" + ) + q = np.max([np.round(q + pad_y), region_of_interest_shape[1]]).astype( + "int" + ) + else: + p, q = np.ceil( + np.array(object_fov_ang) / np.array(self.sampling) + ).astype("int") if object_type == "potential": _object = xp.zeros((p, q), dtype=xp.float32) elif object_type == "complex": diff --git a/py4DSTEM/process/phase/singleslice_ptychography.py b/py4DSTEM/process/phase/singleslice_ptychography.py index cc7a5865d..b220ba741 100644 --- a/py4DSTEM/process/phase/singleslice_ptychography.py +++ b/py4DSTEM/process/phase/singleslice_ptychography.py @@ -88,6 +88,9 @@ class SingleslicePtychography( initial_scan_positions: np.ndarray, optional Probe positions in Å for each diffraction intensity If None, initialized to a grid scan + object_fov_ang: Tuple[int,int], optional + Fixed object field of view in Å. If None, the fov is initialized using the + probe positions and object_padding_px positions_offset_ang: np.ndarray, optional Offset of positions in A verbose: bool, optional @@ -124,6 +127,7 @@ def __init__( initial_object_guess: np.ndarray = None, initial_probe_guess: np.ndarray = None, initial_scan_positions: np.ndarray = None, + object_fov_ang: Tuple[float, float] = None, positions_offset_ang: np.ndarray = None, object_padding_px: Tuple[int, int] = None, object_type: str = "complex", @@ -177,6 +181,7 @@ def __init__( self._rolloff = rolloff self._object_type = object_type self._object_padding_px = object_padding_px + self._object_fov_ang = object_fov_ang self._positions_mask = positions_mask self._verbose = verbose self._preprocessed = False @@ -206,6 +211,7 @@ def preprocess( force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, crop_patterns: bool = False, + center_positions_in_fov: bool = True, store_initial_arrays: bool = True, device: str = None, clear_fft_cache: bool = None, @@ -266,6 +272,8 @@ def preprocess( If None, probe_overlap intensity is thresholded crop_patterns: bool if True, crop patterns to avoid wrap around of patterns when centering + center_positions_in_fov: bool + If True (default), probe positions are centered in the fov. store_initial_arrays: bool If True, preprocesed object and probe arrays are stored allowing reset=True in reconstruct. device: str, optional @@ -443,15 +451,18 @@ def preprocess( self._object_shape = self._object.shape - # center probe positions self._positions_px = xp_storage.asarray( self._positions_px, dtype=xp_storage.float32 ) self._positions_px_initial_com = self._positions_px.mean(0) - self._positions_px -= ( - self._positions_px_initial_com - xp_storage.array(self._object_shape) / 2 - ) - self._positions_px_initial_com = self._positions_px.mean(0) + + # center probe positions + if center_positions_in_fov: + self._positions_px -= ( + self._positions_px_initial_com + - xp_storage.array(self._object_shape) / 2 + ) + self._positions_px_initial_com = self._positions_px.mean(0) self._positions_px_initial = self._positions_px.copy() self._positions_initial = self._positions_px_initial.copy() From bb8bdba03b978160910e698aadbb17b1e547a4b4 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Tue, 28 May 2024 18:04:13 -0700 Subject: [PATCH 27/74] add mixed, multislice better positions --- .../mixedstate_multislice_ptychography.py | 21 +++++++--- .../process/phase/mixedstate_ptychography.py | 21 +++++++--- .../process/phase/multislice_ptychography.py | 21 +++++++--- .../process/phase/ptychographic_methods.py | 40 ++++++++++++++----- 4 files changed, 78 insertions(+), 25 deletions(-) diff --git a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py index b22d0a0bb..d82a37eb4 100644 --- a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py @@ -107,6 +107,9 @@ class MixedstateMultislicePtychography( initial_scan_positions: np.ndarray, optional Probe positions in Å for each diffraction intensity If None, initialized to a grid scan + object_fov_ang: Tuple[int,int], optional + Fixed object field of view in Å. If None, the fov is initialized using the + probe positions and object_padding_px positions_offset_ang: np.ndarray, optional Offset of positions in A theta_x: float @@ -159,6 +162,7 @@ def __init__( initial_object_guess: np.ndarray = None, initial_probe_guess: np.ndarray = None, initial_scan_positions: np.ndarray = None, + object_fov_ang: Tuple[float, float] = None, positions_offset_ang: np.ndarray = None, theta_x: float = 0, theta_y: float = 0, @@ -245,6 +249,7 @@ def __init__( self._object_type = object_type self._positions_mask = positions_mask self._object_padding_px = object_padding_px + self._object_fov_ang = object_fov_ang self._verbose = verbose self._preprocessed = False @@ -278,6 +283,7 @@ def preprocess( force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, crop_patterns: bool = False, + center_positions_in_fov: bool = True, store_initial_arrays: bool = True, device: str = None, clear_fft_cache: bool = None, @@ -348,6 +354,8 @@ def preprocess( If None, probe_overlap intensity is thresholded crop_patterns: bool if True, crop patterns to avoid wrap around of patterns when centering + center_positions_in_fov: bool + If True (default), probe positions are centered in the fov. store_initial_arrays: bool If True, preprocesed object and probe arrays are stored allowing reset=True in reconstruct. device: str, optional @@ -524,15 +532,18 @@ def preprocess( self._object_type_initial = self._object_type self._object_shape = self._object.shape[-2:] - # center probe positions self._positions_px = xp_storage.asarray( self._positions_px, dtype=xp_storage.float32 ) self._positions_px_initial_com = self._positions_px.mean(0) - self._positions_px -= ( - self._positions_px_initial_com - xp_storage.array(self._object_shape) / 2 - ) - self._positions_px_initial_com = self._positions_px.mean(0) + + # center probe positions + if center_positions_in_fov: + self._positions_px -= ( + self._positions_px_initial_com + - xp_storage.array(self._object_shape) / 2 + ) + self._positions_px_initial_com = self._positions_px.mean(0) self._positions_px_initial = self._positions_px.copy() self._positions_initial = self._positions_px_initial.copy() diff --git a/py4DSTEM/process/phase/mixedstate_ptychography.py b/py4DSTEM/process/phase/mixedstate_ptychography.py index 436338555..7bbadf114 100644 --- a/py4DSTEM/process/phase/mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_ptychography.py @@ -96,6 +96,9 @@ class MixedstatePtychography( initial_scan_positions: np.ndarray, optional Probe positions in Å for each diffraction intensity If None, initialized to a grid scan + object_fov_ang: Tuple[int,int], optional + Fixed object field of view in Å. If None, the fov is initialized using the + probe positions and object_padding_px positions_offset_ang: np.ndarray, optional Offset of positions in A positions_mask: np.ndarray, optional @@ -127,6 +130,7 @@ def __init__( initial_object_guess: np.ndarray = None, initial_probe_guess: np.ndarray = None, initial_scan_positions: np.ndarray = None, + object_fov_ang: Tuple[float, float] = None, positions_offset_ang: np.ndarray = None, object_type: str = "complex", positions_mask: np.ndarray = None, @@ -194,6 +198,7 @@ def __init__( self._rolloff = rolloff self._object_type = object_type self._object_padding_px = object_padding_px + self._object_fov_ang = object_fov_ang self._positions_mask = positions_mask self._verbose = verbose self._preprocessed = False @@ -224,6 +229,7 @@ def preprocess( force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, crop_patterns: bool = False, + center_positions_in_fov: bool = True, store_initial_arrays: bool = True, device: str = None, clear_fft_cache: bool = None, @@ -294,6 +300,8 @@ def preprocess( If None, probe_overlap intensity is thresholded crop_patterns: bool if True, crop patterns to avoid wrap around of patterns when centering + center_positions_in_fov: bool + If True (default), probe positions are centered in the fov. store_initial_arrays: bool If True, preprocesed object and probe arrays are stored allowing reset=True in reconstruct. device: str, optional @@ -469,15 +477,18 @@ def preprocess( self._object_type_initial = self._object_type self._object_shape = self._object.shape - # center probe positions self._positions_px = xp_storage.asarray( self._positions_px, dtype=xp_storage.float32 ) self._positions_px_initial_com = self._positions_px.mean(0) - self._positions_px -= ( - self._positions_px_initial_com - xp_storage.array(self._object_shape) / 2 - ) - self._positions_px_initial_com = self._positions_px.mean(0) + + # center probe positions + if center_positions_in_fov: + self._positions_px -= ( + self._positions_px_initial_com + - xp_storage.array(self._object_shape) / 2 + ) + self._positions_px_initial_com = self._positions_px.mean(0) self._positions_px_initial = self._positions_px.copy() self._positions_initial = self._positions_px_initial.copy() diff --git a/py4DSTEM/process/phase/multislice_ptychography.py b/py4DSTEM/process/phase/multislice_ptychography.py index 6de8e7970..87e3c1fe4 100644 --- a/py4DSTEM/process/phase/multislice_ptychography.py +++ b/py4DSTEM/process/phase/multislice_ptychography.py @@ -99,6 +99,9 @@ class MultislicePtychography( initial_scan_positions: np.ndarray, optional Probe positions in Å for each diffraction intensity If None, initialized to a grid scan + object_fov_ang: Tuple[int,int], optional + Fixed object field of view in Å. If None, the fov is initialized using the + probe positions and object_padding_px positions_offset_ang: np.ndarray, optional Offset of positions in A theta_x: float @@ -149,6 +152,7 @@ def __init__( initial_object_guess: np.ndarray = None, initial_probe_guess: np.ndarray = None, initial_scan_positions: np.ndarray = None, + object_fov_ang: Tuple[float, float] = None, positions_offset_ang: np.ndarray = None, theta_x: float = None, theta_y: float = None, @@ -220,6 +224,7 @@ def __init__( self._object_type = object_type self._positions_mask = positions_mask self._object_padding_px = object_padding_px + self._object_fov_ang = object_fov_ang self._verbose = verbose self._preprocessed = False @@ -252,6 +257,7 @@ def preprocess( force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, crop_patterns: bool = False, + center_positions_in_fov: bool = True, store_initial_arrays: bool = True, device: str = None, clear_fft_cache: bool = None, @@ -322,6 +328,8 @@ def preprocess( If None, probe_overlap intensity is thresholded crop_patterns: bool if True, crop patterns to avoid wrap around of patterns when centering + center_positions_in_fov: bool + If True (default), probe positions are centered in the fov. store_initial_arrays: bool If True, preprocesed object and probe arrays are stored allowing reset=True in reconstruct. device: str, optional @@ -498,15 +506,18 @@ def preprocess( self._object_type_initial = self._object_type self._object_shape = self._object.shape[-2:] - # center probe positions self._positions_px = xp_storage.asarray( self._positions_px, dtype=xp_storage.float32 ) self._positions_px_initial_com = self._positions_px.mean(0) - self._positions_px -= ( - self._positions_px_initial_com - xp_storage.array(self._object_shape) / 2 - ) - self._positions_px_initial_com = self._positions_px.mean(0) + + # center probe positions + if center_positions_in_fov: + self._positions_px -= ( + self._positions_px_initial_com + - xp_storage.array(self._object_shape) / 2 + ) + self._positions_px_initial_com = self._positions_px.mean(0) self._positions_px_initial = self._positions_px.copy() self._positions_initial = self._positions_px_initial.copy() diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py index 25a7f7785..b9eae9385 100644 --- a/py4DSTEM/process/phase/ptychographic_methods.py +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -412,14 +412,24 @@ def _initialize_object( xp = self._xp object_padding_px = self._object_padding_px + object_fov_ang = self._object_fov_ang region_of_interest_shape = self._region_of_interest_shape if initial_object is None: - pad_x = object_padding_px[0][1] - pad_y = object_padding_px[1][1] - p, q = np.round(np.max(positions_px, axis=0)) - p = np.max([np.round(p + pad_x), region_of_interest_shape[0]]).astype("int") - q = np.max([np.round(q + pad_y), region_of_interest_shape[1]]).astype("int") + if object_fov_ang is None: + pad_x = object_padding_px[0][1] + pad_y = object_padding_px[1][1] + p, q = np.round(np.max(positions_px, axis=0)) + p = np.max([np.round(p + pad_x), region_of_interest_shape[0]]).astype( + "int" + ) + q = np.max([np.round(q + pad_y), region_of_interest_shape[1]]).astype( + "int" + ) + else: + p, q = np.ceil( + np.array(object_fov_ang) / np.array(self.sampling) + ).astype("int") if object_type == "potential": _object = xp.zeros((num_slices, p, q), dtype=xp.float32) elif object_type == "complex": @@ -868,14 +878,24 @@ def _initialize_object( # explicit read-only self attributes up-front xp = self._xp object_padding_px = self._object_padding_px + object_fov_ang = self._object_fov_ang region_of_interest_shape = self._region_of_interest_shape if initial_object is None: - pad_x = object_padding_px[0][1] - pad_y = object_padding_px[1][1] - p, q = np.round(np.max(positions_px, axis=0)) - p = np.max([np.round(p + pad_x), region_of_interest_shape[0]]).astype("int") - q = np.max([np.round(q + pad_y), region_of_interest_shape[1]]).astype("int") + if object_fov_ang is None: + pad_x = object_padding_px[0][1] + pad_y = object_padding_px[1][1] + p, q = np.round(np.max(positions_px, axis=0)) + p = np.max([np.round(p + pad_x), region_of_interest_shape[0]]).astype( + "int" + ) + q = np.max([np.round(q + pad_y), region_of_interest_shape[1]]).astype( + "int" + ) + else: + p, q = np.ceil( + np.array(object_fov_ang) / np.array(self.sampling) + ).astype("int") if main_tilt_axis == "vertical": _object = xp.zeros((q, p, q), dtype=xp.float32) From 92f35e1675b4ea033cf71a7662009ddaf9f565a2 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Tue, 28 May 2024 21:11:10 -0700 Subject: [PATCH 28/74] magnetic positions, absolute --- .../process/phase/magnetic_ptychography.py | 25 +++++++++++++------ 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/py4DSTEM/process/phase/magnetic_ptychography.py b/py4DSTEM/process/phase/magnetic_ptychography.py index 19f306188..58c826b87 100644 --- a/py4DSTEM/process/phase/magnetic_ptychography.py +++ b/py4DSTEM/process/phase/magnetic_ptychography.py @@ -101,6 +101,9 @@ class MagneticPtychography( initial_scan_positions: np.ndarray, optional Probe positions in Å for each diffraction intensity If None, initialized to a grid scan + object_fov_ang: Tuple[int,int], optional + Fixed object field of view in Å. If None, the fov is initialized using the + probe positions and object_padding_px positions_offset_ang: np.ndarray, optional Offset of positions in A verbose: bool, optional @@ -138,6 +141,7 @@ def __init__( initial_object_guess: np.ndarray = None, initial_probe_guess: np.ndarray = None, initial_scan_positions: np.ndarray = None, + object_fov_ang: Tuple[float, float] = None, positions_offset_ang: np.ndarray = None, object_type: str = "complex", verbose: bool = True, @@ -189,6 +193,7 @@ def __init__( self._rolloff = rolloff self._object_type = object_type self._object_padding_px = object_padding_px + self._object_fov_ang = object_fov_ang self._positions_mask = positions_mask self._verbose = verbose self._preprocessed = False @@ -219,6 +224,7 @@ def preprocess( progress_bar: bool = True, object_fov_mask: np.ndarray = True, crop_patterns: bool = False, + center_positions_in_fov: bool = True, store_initial_arrays: bool = True, device: str = None, clear_fft_cache: bool = None, @@ -286,6 +292,8 @@ def preprocess( If None, probe_overlap intensity is thresholded crop_patterns: bool if True, crop patterns to avoid wrap around of patterns when centering + center_positions_in_fov: bool + If True (default), probe positions are centered in the fov. store_initial_arrays: bool If True, preprocesed object and probe arrays are stored allowing reset=True in reconstruct. device: str, optional @@ -610,14 +618,17 @@ def preprocess( self._positions_px_all, dtype=xp_storage.float32 ) - for index in range(self._num_measurements): - idx_start = self._cum_probes_per_measurement[index] - idx_end = self._cum_probes_per_measurement[index + 1] + if center_positions_in_fov: + for index in range(self._num_measurements): + idx_start = self._cum_probes_per_measurement[index] + idx_end = self._cum_probes_per_measurement[index + 1] - positions_px = self._positions_px_all[idx_start:idx_end] - positions_px_com = positions_px.mean(0) - positions_px -= positions_px_com - xp_storage.array(self._object_shape) / 2 - self._positions_px_all[idx_start:idx_end] = positions_px.copy() + positions_px = self._positions_px_all[idx_start:idx_end] + positions_px_com = positions_px.mean(0) + positions_px -= ( + positions_px_com - xp_storage.array(self._object_shape) / 2 + ) + self._positions_px_all[idx_start:idx_end] = positions_px.copy() self._positions_px_initial_all = self._positions_px_all.copy() self._positions_initial_all = self._positions_px_initial_all.copy() From c1557bb50724e7c84fabb2f258376ec2b47ed6f5 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Wed, 29 May 2024 17:17:45 -0700 Subject: [PATCH 29/74] add xmcd class --- py4DSTEM/process/phase/__init__.py | 1 + .../phase/xray_magnetic_ptychography.py | 1849 +++++++++++++++++ 2 files changed, 1850 insertions(+) create mode 100644 py4DSTEM/process/phase/xray_magnetic_ptychography.py diff --git a/py4DSTEM/process/phase/__init__.py b/py4DSTEM/process/phase/__init__.py index ecfeaa1d2..500a7cfea 100644 --- a/py4DSTEM/process/phase/__init__.py +++ b/py4DSTEM/process/phase/__init__.py @@ -12,5 +12,6 @@ from py4DSTEM.process.phase.ptychographic_tomography import PtychographicTomography from py4DSTEM.process.phase.singleslice_ptychography import SingleslicePtychography from py4DSTEM.process.phase.parameter_optimize import OptimizationParameter, PtychographyOptimizer +from py4DSTEM.process.phase.xray_magnetic_ptychography import XRayMagneticPtychography # fmt: on diff --git a/py4DSTEM/process/phase/xray_magnetic_ptychography.py b/py4DSTEM/process/phase/xray_magnetic_ptychography.py new file mode 100644 index 000000000..77e1cb835 --- /dev/null +++ b/py4DSTEM/process/phase/xray_magnetic_ptychography.py @@ -0,0 +1,1849 @@ +""" +Module for reconstructing phase objects from 4DSTEM datasets using iterative methods, +namely x-ray magnetic ptychography. +""" + +import warnings +from typing import Mapping, Sequence, Tuple + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.gridspec import GridSpec +from mpl_toolkits.axes_grid1 import make_axes_locatable +from py4DSTEM.visualize.vis_special import ( + Complex2RGB, + add_colorbar_arg, + return_scaled_histogram_ordering, +) + +try: + import cupy as cp +except (ImportError, ModuleNotFoundError): + cp = np + +from emdfile import Custom, tqdmnd +from py4DSTEM import DataCube +from py4DSTEM.process.phase.phase_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.ptychographic_constraints import ( + ObjectNDConstraintsMixin, + PositionsConstraintsMixin, + ProbeConstraintsMixin, +) +from py4DSTEM.process.phase.ptychographic_methods import ( + MultipleMeasurementsMethodsMixin, + ObjectNDMethodsMixin, + ObjectNDProbeMethodsMixin, + ProbeMethodsMixin, +) +from py4DSTEM.process.phase.ptychographic_visualizations import VisualizationsMixin +from py4DSTEM.process.phase.utils import ( + ComplexProbe, + copy_to_device, + fft_shift, + generate_batches, + polar_aliases, + polar_symbols, +) + + +class XRayMagneticPtychography( + VisualizationsMixin, + PositionsConstraintsMixin, + ProbeConstraintsMixin, + ObjectNDConstraintsMixin, + MultipleMeasurementsMethodsMixin, + ObjectNDProbeMethodsMixin, + ProbeMethodsMixin, + ObjectNDMethodsMixin, + PtychographicReconstruction, +): + """ + Iterative X-Ray Magnetic Ptychographic Reconstruction Class. + + Diffraction intensities dimensions : (Rx,Ry,Qx,Qy) (for each measurement) + Reconstructed probe dimensions : (Sx,Sy) + Reconstructed electrostatic dimensions : (Px,Py) + Reconstructed magnetic dimensions : (Px,Py) + + such that (Sx,Sy) is the region-of-interest (ROI) size of our probe + and (Px,Py) is the padded-object size we position our ROI around in. + + Parameters + ---------- + datacube: Sequence[DataCube] + Tuple of input 4D diffraction pattern intensities + energy: float + The electron energy of the wave functions in eV + magnetic_contribution_sign: str, optional + One of '-+', '-0+', '0+' + semiangle_cutoff: float, optional + Semiangle cutoff for the initial probe guess in mrad + semiangle_cutoff_pixels: float, optional + Semiangle cutoff for the initial probe guess in pixels + rolloff: float, optional + Semiangle rolloff for the initial probe guess + vacuum_probe_intensity: np.ndarray, optional + Vacuum probe to use as intensity aperture for initial probe guess + polar_parameters: dict, optional + Mapping from aberration symbols to their corresponding values. All aberration + magnitudes should be given in Å and angles should be given in radians. + object_padding_px: Tuple[int,int], optional + Pixel dimensions to pad objects with + If None, the padding is set to half the probe ROI dimensions + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction + initial_object_guess: np.ndarray, optional + Initial guess for complex-valued object of dimensions (2,Px,Py) + If None, initialized to 1.0j for complex objects and 0.0 for potential objects + initial_probe_guess: np.ndarray, optional + Initial guess for complex-valued probe of dimensions (Sx,Sy). If None, + initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations + initial_scan_positions: np.ndarray, optional + Probe positions in Å for each diffraction intensity + If None, initialized to a grid scan + object_fov_ang: Tuple[int,int], optional + Fixed object field of view in Å. If None, the fov is initialized using the + probe positions and object_padding_px + positions_offset_ang: np.ndarray, optional + Offset of positions in A + verbose: bool, optional + If True, class methods will inherit this and print additional information + device: str, optional + Calculation device will be perfomed on. Must be 'cpu' or 'gpu' + storage: str, optional + Device non-frequent arrays will be stored on. Must be 'cpu' or 'gpu' + clear_fft_cache: bool, optional + If True, and device = 'gpu', clears the cached fft plan at the end of function calls + object_type: str, optional + The object can be reconstructed as a real potential ('potential') or a complex + object ('complex') + name: str, optional + Class name + kwargs: + Provide the aberration coefficients as keyword arguments. + """ + + # Class-specific Metadata + _class_specific_metadata = ("_magnetic_contribution_sign",) + + def __init__( + self, + energy: float, + datacube: Sequence[DataCube] = None, + magnetic_contribution_sign: str = "-+", + semiangle_cutoff: float = None, + semiangle_cutoff_pixels: float = None, + rolloff: float = 2.0, + vacuum_probe_intensity: np.ndarray = None, + polar_parameters: Mapping[str, float] = None, + object_padding_px: Tuple[int, int] = None, + positions_mask: np.ndarray = None, + initial_object_guess: np.ndarray = None, + initial_probe_guess: np.ndarray = None, + initial_scan_positions: np.ndarray = None, + object_fov_ang: Tuple[float, float] = None, + positions_offset_ang: np.ndarray = None, + object_type: str = "complex", + verbose: bool = True, + device: str = "cpu", + storage: str = None, + clear_fft_cache: bool = True, + name: str = "xray_magnetic_ptychographic_reconstruction", + **kwargs, + ): + Custom.__init__(self, name=name) + + if storage is None: + storage = device + + self.set_device(device, clear_fft_cache) + self.set_storage(storage) + + for key in kwargs.keys(): + if (key not in polar_symbols) and (key not in polar_aliases.keys()): + raise ValueError("{} not a recognized parameter".format(key)) + + self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) + + if polar_parameters is None: + polar_parameters = {} + + polar_parameters.update(kwargs) + self._set_polar_parameters(polar_parameters) + + if object_type != "complex": + raise NotImplementedError() + + self.set_save_defaults() + + # Data + self._datacube = datacube + self._object = initial_object_guess + self._probe_init = initial_probe_guess + + # Common Metadata + self._vacuum_probe_intensity = vacuum_probe_intensity + self._scan_positions = initial_scan_positions + self._positions_offset_ang = positions_offset_ang + self._energy = energy + self._semiangle_cutoff = semiangle_cutoff + self._semiangle_cutoff_pixels = semiangle_cutoff_pixels + self._rolloff = rolloff + self._object_type = object_type + self._object_padding_px = object_padding_px + self._object_fov_ang = object_fov_ang + self._positions_mask = positions_mask + self._verbose = verbose + self._preprocessed = False + + # Class-specific Metadata + self._magnetic_contribution_sign = magnetic_contribution_sign + + def preprocess( + self, + diffraction_intensities_shape: Tuple[int, int] = None, + reshaping_method: str = "bilinear", + padded_diffraction_intensities_shape: Tuple[int, int] = None, + region_of_interest_shape: Tuple[int, int] = None, + dp_mask: np.ndarray = None, + fit_function: str = "plane", + plot_rotation: bool = True, + maximize_divergence: bool = False, + rotation_angles_deg: np.ndarray = None, + plot_probe_overlaps: bool = True, + force_com_rotation: float = None, + force_com_transpose: float = None, + force_com_shifts: Sequence[np.ndarray] = None, + force_com_measured: Sequence[np.ndarray] = None, + vectorized_com_calculation: bool = True, + force_scan_sampling: float = None, + force_angular_sampling: float = None, + force_reciprocal_sampling: float = None, + progress_bar: bool = True, + object_fov_mask: np.ndarray = True, + crop_patterns: bool = False, + center_positions_in_fov: bool = True, + store_initial_arrays: bool = True, + device: str = None, + clear_fft_cache: bool = None, + max_batch_size: int = None, + **kwargs, + ): + """ + Ptychographic preprocessing step. + Calls the base class methods: + + _extract_intensities_and_calibrations_from_datacube, + _compute_center_of_mass(), + _solve_CoM_rotation(), + _normalize_diffraction_intensities() + _calculate_scan_positions_in_px() + + Additionally, it initializes an (Px,Py) array of 1.0j + and a complex probe using the specified polar parameters. + + Parameters + ---------- + diffraction_intensities_shape: Tuple[int,int], optional + Pixel dimensions (Qx',Qy') of the resampled diffraction intensities + If None, no resampling of diffraction intenstities is performed + reshaping_method: str, optional + Method to use for reshaping, either 'bin, 'bilinear', or 'fourier' (default) + padded_diffraction_intensities_shape: (int,int), optional + Padded diffraction intensities shape. + If None, no padding is performed + region_of_interest_shape: (int,int), optional + If not None, explicitly sets region_of_interest_shape and resamples exit_waves + at the diffraction plane to allow comparison with experimental data + dp_mask: ndarray, optional + Mask for datacube intensities (Qx,Qy) + fit_function: str, optional + 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' + plot_rotation: bool, optional + If True, the CoM curl minimization search result will be displayed + maximize_divergence: bool, optional + If True, the divergence of the CoM gradient vector field is maximized + rotation_angles_deg: np.darray, optional + Array of angles in degrees to perform curl minimization over + plot_probe_overlaps: bool, optional + If True, initial probe overlaps scanned over the object will be displayed + force_com_rotation: float (degrees), optional + Force relative rotation angle between real and reciprocal space + force_com_transpose: bool, optional + Force whether diffraction intensities need to be transposed. + force_com_shifts: sequence of tuples of ndarrays (CoMx, CoMy) + Amplitudes come from diffraction patterns shifted with + the CoM in the upper left corner for each probe unless + shift is overwritten. + force_com_measured: tuple of ndarrays (CoMx measured, CoMy measured) + Force CoM measured shifts + vectorized_com_calculation: bool, optional + If True (default), the memory-intensive CoM calculation is vectorized + force_scan_sampling: float, optional + Override DataCube real space scan pixel size calibrations, in Angstrom + force_angular_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in mrad + force_reciprocal_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in A^-1 + object_fov_mask: np.ndarray (boolean) + Boolean mask of FOV. Used to calculate additional shrinkage of object + If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering + center_positions_in_fov: bool + If True (default), probe positions are centered in the fov. + store_initial_arrays: bool + If True, preprocesed object and probe arrays are stored allowing reset=True in reconstruct. + device: str, optional + if not none, overwrites self._device to set device preprocess will be perfomed on. + clear_fft_cache: bool, optional + if true, and device = 'gpu', clears the cached fft plan at the end of function calls + max_batch_size: int, optional + Max number of probes to use at once in computing probe overlaps + + Returns + -------- + self: PtychographicReconstruction + Self to accommodate chaining + """ + # handle device/storage + self.set_device(device, clear_fft_cache) + + xp = self._xp + device = self._device + xp_storage = self._xp_storage + storage = self._storage + asnumpy = self._asnumpy + + # set additional metadata + self._diffraction_intensities_shape = diffraction_intensities_shape + self._reshaping_method = reshaping_method + self._padded_diffraction_intensities_shape = ( + padded_diffraction_intensities_shape + ) + self._dp_mask = dp_mask + + if self._datacube is None: + raise ValueError( + ( + "The preprocess() method requires a DataCube. " + "Please run ptycho.attach_datacube(DataCube) first." + ) + ) + + if self._magnetic_contribution_sign == "-+": + self._recon_mode = 0 + self._num_measurements = 2 + magnetic_contribution_msg = ( + "Magnetic contribution sign in first meaurement assumed to be negative.\n" + "Magnetic contribution sign in second meaurement assumed to be positive." + ) + + elif self._magnetic_contribution_sign == "-0+": + self._recon_mode = 1 + self._num_measurements = 3 + magnetic_contribution_msg = ( + "Magnetic contribution sign in first meaurement assumed to be negative.\n" + "Magnetic contribution assumed to be zero in second meaurement.\n" + "Magnetic contribution sign in third meaurement assumed to be positive." + ) + + elif self._magnetic_contribution_sign == "0+": + self._recon_mode = 2 + self._num_measurements = 2 + magnetic_contribution_msg = ( + "Magnetic contribution assumed to be zero in first meaurement.\n" + "Magnetic contribution sign in second meaurement assumed to be positive." + ) + else: + raise ValueError( + f"magnetic_contribution_sign must be either '-+', '-0+', or '0+', not {self._magnetic_contribution_sign}" + ) + + if self._verbose: + warnings.warn( + magnetic_contribution_msg, + UserWarning, + ) + + if len(self._datacube) != self._num_measurements: + raise ValueError( + f"datacube must be the same length as magnetic_contribution_sign, not length {len(self._datacube)}." + ) + + dc_shapes = [dc.shape for dc in self._datacube] + if dc_shapes.count(dc_shapes[0]) != self._num_measurements: + raise ValueError("datacube intensities must be the same size.") + + if self._positions_mask is not None: + self._positions_mask = np.asarray(self._positions_mask, dtype="bool") + + if self._positions_mask.ndim == 2: + warnings.warn( + "2D `positions_mask` assumed the same for all measurements.", + UserWarning, + ) + self._positions_mask = np.tile( + self._positions_mask, (self._num_measurements, 1, 1) + ) + + num_probes_per_measurement = np.insert( + self._positions_mask.sum(axis=(-2, -1)), 0, 0 + ) + + else: + self._positions_mask = [None] * self._num_measurements + num_probes_per_measurement = [0] + [dc.R_N for dc in self._datacube] + num_probes_per_measurement = np.array(num_probes_per_measurement) + + # prepopulate relevant arrays + self._mean_diffraction_intensity = [] + self._num_diffraction_patterns = num_probes_per_measurement.sum() + self._cum_probes_per_measurement = np.cumsum(num_probes_per_measurement) + self._positions_px_all = np.empty((self._num_diffraction_patterns, 2)) + + # calculate roi_shape + roi_shape = self._datacube[0].Qshape + if diffraction_intensities_shape is not None: + roi_shape = diffraction_intensities_shape + if padded_diffraction_intensities_shape is not None: + roi_shape = tuple( + max(q, s) + for q, s in zip(roi_shape, padded_diffraction_intensities_shape) + ) + + self._amplitudes = xp_storage.empty( + (self._num_diffraction_patterns,) + roi_shape + ) + + self._amplitudes_shape = np.array(self._amplitudes.shape[-2:]) + if region_of_interest_shape is not None: + self._resample_exit_waves = True + self._region_of_interest_shape = np.array(region_of_interest_shape) + else: + self._resample_exit_waves = False + self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) + + # TO-DO: generalize this + if force_com_shifts is None: + force_com_shifts = [None] * self._num_measurements + + if force_com_measured is None: + force_com_measured = [None] * self._num_measurements + + if self._scan_positions is None: + self._scan_positions = [None] * self._num_measurements + + if self._positions_offset_ang is None: + self._positions_offset_ang = [None] * self._num_measurements + + # Ensure plot_center_of_mass is not in kwargs + kwargs.pop("plot_center_of_mass", None) + + if progress_bar: + # turn off verbosity to play nice with tqdm + verbose = self._verbose + self._verbose = False + + # loop over DPs for preprocessing + for index in tqdmnd( + self._num_measurements, + desc="Preprocessing data", + unit="measurement", + disable=not progress_bar, + ): + # preprocess datacube, vacuum and masks only for first measurement + if index == 0: + ( + self._datacube[index], + self._vacuum_probe_intensity, + self._dp_mask, + force_com_shifts[index], + force_com_measured[index], + ) = self._preprocess_datacube_and_vacuum_probe( + self._datacube[index], + diffraction_intensities_shape=self._diffraction_intensities_shape, + reshaping_method=self._reshaping_method, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, + vacuum_probe_intensity=self._vacuum_probe_intensity, + dp_mask=self._dp_mask, + com_shifts=force_com_shifts[index], + com_measured=force_com_measured[index], + ) + + else: + ( + self._datacube[index], + _, + _, + force_com_shifts[index], + force_com_measured[index], + ) = self._preprocess_datacube_and_vacuum_probe( + self._datacube[index], + diffraction_intensities_shape=self._diffraction_intensities_shape, + reshaping_method=self._reshaping_method, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, + vacuum_probe_intensity=None, + dp_mask=None, + com_shifts=force_com_shifts[index], + com_measured=force_com_measured[index], + ) + + # calibrations + intensities = self._extract_intensities_and_calibrations_from_datacube( + self._datacube[index], + require_calibrations=True, + force_scan_sampling=force_scan_sampling, + force_angular_sampling=force_angular_sampling, + force_reciprocal_sampling=force_reciprocal_sampling, + ) + + # calculate CoM + ( + com_measured_x, + com_measured_y, + com_fitted_x, + com_fitted_y, + com_normalized_x, + com_normalized_y, + ) = self._calculate_intensities_center_of_mass( + intensities, + dp_mask=self._dp_mask, + fit_function=fit_function, + com_shifts=force_com_shifts[index], + vectorized_calculation=vectorized_com_calculation, + com_measured=force_com_measured[index], + ) + + # estimate rotation / transpose using first measurement + if index == 0: + # silence warnings to play nice with progress bar + verbose = self._verbose + self._verbose = False + + ( + self._rotation_best_rad, + self._rotation_best_transpose, + _com_x, + _com_y, + ) = self._solve_for_center_of_mass_relative_rotation( + com_measured_x, + com_measured_y, + com_normalized_x, + com_normalized_y, + rotation_angles_deg=rotation_angles_deg, + plot_rotation=plot_rotation, + plot_center_of_mass=False, + maximize_divergence=maximize_divergence, + force_com_rotation=force_com_rotation, + force_com_transpose=force_com_transpose, + **kwargs, + ) + self._verbose = verbose + + # corner-center amplitudes + idx_start = self._cum_probes_per_measurement[index] + idx_end = self._cum_probes_per_measurement[index + 1] + + ( + amplitudes, + mean_diffraction_intensity_temp, + self._crop_mask, + ) = self._normalize_diffraction_intensities( + intensities, + com_fitted_x, + com_fitted_y, + self._positions_mask[index], + crop_patterns, + ) + + self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp) + + # explicitly transfer arrays to storage + self._amplitudes[idx_start:idx_end] = copy_to_device(amplitudes, storage) + + del ( + intensities, + amplitudes, + com_measured_x, + com_measured_y, + com_fitted_x, + com_fitted_y, + com_normalized_x, + com_normalized_y, + ) + + # initialize probe positions + ( + self._positions_px_all[idx_start:idx_end], + self._object_padding_px, + ) = self._calculate_scan_positions_in_pixels( + self._scan_positions[index], + self._positions_mask[index], + self._object_padding_px, + self._positions_offset_ang[index], + ) + + if progress_bar: + # reset verbosity + self._verbose = verbose + + # handle semiangle specified in pixels + if self._semiangle_cutoff_pixels: + self._semiangle_cutoff = ( + self._semiangle_cutoff_pixels * self._angular_sampling[0] + ) + + # Object Initialization + obj = self._initialize_object( + self._object, + self._positions_px_all, + self._object_type, + ) + + if self._object is None: + self._object = xp.full((2,) + obj.shape, obj) + else: + self._object = obj + + if store_initial_arrays: + self._object_initial = self._object.copy() + self._object_type_initial = self._object_type + self._object_shape = self._object.shape[-2:] + + # center probe positions + self._positions_px_all = xp_storage.asarray( + self._positions_px_all, dtype=xp_storage.float32 + ) + + if center_positions_in_fov: + for index in range(self._num_measurements): + idx_start = self._cum_probes_per_measurement[index] + idx_end = self._cum_probes_per_measurement[index + 1] + + positions_px = self._positions_px_all[idx_start:idx_end] + positions_px_com = positions_px.mean(0) + positions_px -= ( + positions_px_com - xp_storage.array(self._object_shape) / 2 + ) + self._positions_px_all[idx_start:idx_end] = positions_px.copy() + + self._positions_px_initial_all = self._positions_px_all.copy() + self._positions_initial_all = self._positions_px_initial_all.copy() + self._positions_initial_all[:, 0] *= self.sampling[0] + self._positions_initial_all[:, 1] *= self.sampling[1] + + self._positions_initial = self._return_average_positions() + if self._positions_initial is not None: + self._positions_initial[:, 0] *= self.sampling[0] + self._positions_initial[:, 1] *= self.sampling[1] + + # initialize probe + self._probes_all = [] + list_Q = isinstance(self._probe_init, (list, tuple)) + + if store_initial_arrays: + self._probes_all_initial = [] + self._probes_all_initial_aperture = [] + else: + self._probes_all_initial_aperture = [None] * self._num_measurements + + for index in range(self._num_measurements): + _probe, self._semiangle_cutoff = self._initialize_probe( + self._probe_init[index] if list_Q else self._probe_init, + self._vacuum_probe_intensity, + self._mean_diffraction_intensity[index], + self._semiangle_cutoff, + crop_patterns, + ) + + self._probes_all.append(_probe) + if store_initial_arrays: + self._probes_all_initial.append(_probe.copy()) + self._probes_all_initial_aperture.append(xp.abs(xp.fft.fft2(_probe))) + + del self._probe_init + + # initialize aberrations + self._known_aberrations_array = ComplexProbe( + energy=self._energy, + gpts=self._region_of_interest_shape, + sampling=self.sampling, + parameters=self._polar_parameters, + device=self._device, + )._evaluate_ctf() + + if object_fov_mask is None or plot_probe_overlaps: + # overlaps + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + + probe_overlap = xp.zeros(self._object_shape, dtype=xp.float32) + + for start, end in generate_batches( + self._cum_probes_per_measurement[1], max_batch=max_batch_size + ): + # batch indices + positions_px = self._positions_px_all[start:end] + positions_px_fractional = positions_px - xp_storage.round(positions_px) + + shifted_probes = fft_shift( + self._probes_all[0], positions_px_fractional, xp + ) + probe_overlap += self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes) ** 2, positions_px + ) + + del shifted_probes + + # initialize object_fov_mask + if object_fov_mask is None: + gaussian_filter = self._scipy.ndimage.gaussian_filter + probe_overlap_blurred = gaussian_filter(probe_overlap, 1.0) + self._object_fov_mask = asnumpy( + probe_overlap_blurred > 0.25 * probe_overlap_blurred.max() + ) + del probe_overlap_blurred + elif object_fov_mask is True: + self._object_fov_mask = np.full(self._object_shape, True) + else: + self._object_fov_mask = np.asarray(object_fov_mask) + self._object_fov_mask_inverse = np.invert(self._object_fov_mask) + + # plot probe overlaps + if plot_probe_overlaps: + probe_overlap = asnumpy(probe_overlap) + figsize = kwargs.pop("figsize", (9, 4)) + chroma_boost = kwargs.pop("chroma_boost", 1) + power = kwargs.pop("power", 2) + + # initial probe + complex_probe_rgb = Complex2RGB( + self.probe_centered[0], + power=power, + chroma_boost=chroma_boost, + ) + + extent = [ + 0, + self.sampling[1] * self._object_shape[1], + self.sampling[0] * self._object_shape[0], + 0, + ] + + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize) + + ax1.imshow( + complex_probe_rgb, + extent=probe_extent, + ) + + divider = make_axes_locatable(ax1) + cax1 = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(cax1, chroma_boost=chroma_boost) + ax1.set_ylabel("x [A]") + ax1.set_xlabel("y [A]") + ax1.set_title("Initial probe intensity") + + ax2.imshow( + probe_overlap, + extent=extent, + cmap="gray", + ) + ax2.scatter( + self.positions[0, :, 1], + self.positions[0, :, 0], + s=2.5, + color=(1, 0, 0, 1), + ) + ax2.set_ylabel("x [A]") + ax2.set_xlabel("y [A]") + ax2.set_xlim((extent[0], extent[1])) + ax2.set_ylim((extent[2], extent[3])) + ax2.set_title("Object field of view") + + fig.tight_layout() + + self._preprocessed = True + self.clear_device_mem(self._device, self._clear_fft_cache) + + return self + + def _overlap_projection( + self, + current_object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + ): + """ + Ptychographic overlap projection method. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + + Returns + -------- + shifted_probes:np.ndarray + fractionally-shifted probes + object_patches: np.ndarray + Patched object view + overlap: np.ndarray + shifted_probes * object_patches + """ + + xp = self._xp + + object_patches = xp.empty( + (self._num_measurements,) + shifted_probes.shape, dtype=current_object.dtype + ) + object_patches[0] = current_object[ + 0, vectorized_patch_indices_row, vectorized_patch_indices_col + ] + object_patches[1] = current_object[ + 1, vectorized_patch_indices_row, vectorized_patch_indices_col + ] + + overlap_base = shifted_probes * object_patches[0] + + match (self._recon_mode, self._active_measurement_index): + case (0, 0) | (1, 0): # reverse + overlap = overlap_base * xp.exp(-1j * object_patches[1]) + case (0, 1) | (1, 2) | (2, 1): # forward + overlap = overlap_base * xp.exp(1j * object_patches[1]) + case (1, 1) | (2, 0): # neutral + overlap = overlap_base + case _: + raise ValueError() + + return shifted_probes, object_patches, overlap + + def _gradient_descent_adjoint( + self, + current_object, + current_probe, + object_patches, + shifted_probes, + positions_px, + exit_waves, + step_size, + normalization_min, + fix_probe, + ): + """ + Ptychographic adjoint operator for GD method. + Computes object and probe update steps. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + shifted_probes:np.ndarray + fractionally-shifted probes + exit_waves:np.ndarray + Updated exit_waves + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + xp = self._xp + + probe_conj = xp.conj(shifted_probes) # P* + electrostatic_conj = xp.conj(object_patches[0]) # C* = exp(-i c) + + probe_electrostatic_abs = xp.abs(shifted_probes * object_patches[0]) + probe_electrostatic_normalization = self._sum_overlapping_patches_bincounts( + probe_electrostatic_abs**2, + positions_px, + ) + probe_electrostatic_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_electrostatic_normalization) ** 2 + + (normalization_min * xp.max(probe_electrostatic_normalization)) ** 2 + ) + + probe_magnetic_abs = xp.abs(shifted_probes * object_patches[1]) + probe_magnetic_normalization = self._sum_overlapping_patches_bincounts( + probe_magnetic_abs**2, + positions_px, + ) + probe_magnetic_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_magnetic_normalization) ** 2 + + (normalization_min * xp.max(probe_magnetic_normalization)) ** 2 + ) + + if not fix_probe: + electrostatic_magnetic_abs = xp.abs( + object_patches[0] * xp.exp(1j * object_patches[1]) + ) + electrostatic_magnetic_normalization = xp.sum( + electrostatic_magnetic_abs**2, + axis=0, + ) + electrostatic_magnetic_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * electrostatic_magnetic_normalization) ** 2 + + (normalization_min * xp.max(electrostatic_magnetic_normalization)) + ** 2 + ) + + if self._recon_mode > 0: + electrostatic_abs = xp.abs(object_patches[0]) + electrostatic_normalization = xp.sum( + electrostatic_abs**2, + axis=0, + ) + electrostatic_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * electrostatic_normalization) ** 2 + + (normalization_min * xp.max(electrostatic_normalization)) ** 2 + ) + + match (self._recon_mode, self._active_measurement_index): + case (0, 0) | (1, 0): # reverse + + magnetic_exp = xp.exp(1j * xp.conj(object_patches[1])) + + # P* exp(i M*) + electrostatic_update = self._sum_overlapping_patches_bincounts( + probe_conj * magnetic_exp * exit_waves, + positions_px, + ) + + # i exp(i M*) C* P* + magnetic_update = self._sum_overlapping_patches_bincounts( + 1j * magnetic_exp * probe_conj * electrostatic_conj * exit_waves, + positions_px, + ) + + current_object[0] += ( + step_size * electrostatic_update * probe_magnetic_normalization + ) + current_object[1] += ( + step_size * magnetic_update * probe_electrostatic_normalization + ) + + if not fix_probe: + # exp(i M*) C* + current_probe += step_size * ( + xp.sum( + electrostatic_conj * magnetic_exp * exit_waves, + axis=0, + ) + * electrostatic_magnetic_normalization + ) + + case (0, 1) | (1, 2) | (2, 1): # forward + + magnetic_exp = xp.exp(-1j * xp.conj(object_patches[1])) + + # P* exp(-i M*) + electrostatic_update = self._sum_overlapping_patches_bincounts( + probe_conj * magnetic_exp * exit_waves, + positions_px, + ) + + # -i exp (-i M*) P* + magnetic_update = self._sum_overlapping_patches_bincounts( + -1j * magnetic_exp * probe_conj * electrostatic_conj * exit_waves, + positions_px, + ) + + current_object[0] += ( + step_size * electrostatic_update * probe_magnetic_normalization + ) + current_object[1] += ( + step_size * magnetic_update * probe_electrostatic_normalization + ) + + if not fix_probe: + # exp(-i M*) C* + current_probe += step_size * ( + xp.sum( + electrostatic_conj * magnetic_exp * exit_waves, + axis=0, + ) + * electrostatic_magnetic_normalization + ) + + case (1, 1) | (2, 0): # neutral + probe_abs = xp.abs(shifted_probes) + probe_normalization = self._sum_overlapping_patches_bincounts( + probe_abs**2, + positions_px, + ) + probe_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_normalization) ** 2 + + (normalization_min * xp.max(probe_normalization)) ** 2 + ) + + # P* + electrostatic_update = self._sum_overlapping_patches_bincounts( + probe_conj * exit_waves, + positions_px, + ) + + current_object[0] += ( + step_size * electrostatic_update * probe_normalization + ) + + if not fix_probe: + # V* + current_probe += step_size * ( + xp.sum( + electrostatic_conj * exit_waves, + axis=0, + ) + * electrostatic_normalization + ) + + case _: + raise ValueError() + + return current_object, current_probe + + def _object_constraints( + self, + current_object, + gaussian_filter, + gaussian_filter_sigma_e, + gaussian_filter_sigma_m, + butterworth_filter, + butterworth_order, + q_lowpass_e, + q_lowpass_m, + q_highpass_e, + q_highpass_m, + tv_denoise, + tv_denoise_weight, + tv_denoise_inner_iter, + **kwargs, + ): + """MagneticObjectNDConstraints wrapper function""" + + # smoothness + if gaussian_filter: + current_object[0] = self._object_gaussian_constraint( + current_object[0], gaussian_filter_sigma_e, False + ) + current_object[1] = self._object_gaussian_constraint( + current_object[1], gaussian_filter_sigma_m, False + ) + if butterworth_filter: + current_object[0] = self._object_butterworth_constraint( + current_object[0], + q_lowpass_e, + q_highpass_e, + butterworth_order, + ) + current_object[1] = self._object_butterworth_constraint( + current_object[1], + q_lowpass_m, + q_highpass_m, + butterworth_order, + ) + if tv_denoise: + current_object[0] = self._object_denoise_tv_pylops( + current_object[0], tv_denoise_weight, tv_denoise_inner_iter + ) + + # amplitude threshold + # current_object[0] = self._object_threshold_constraint( + # current_object[0], False + # ) + + return current_object + + def reconstruct( + self, + num_iter: int = 8, + reconstruction_method: str = "gradient-descent", + reconstruction_parameter: float = 1.0, + reconstruction_parameter_a: float = None, + reconstruction_parameter_b: float = None, + reconstruction_parameter_c: float = None, + max_batch_size: int = None, + seed_random: int = None, + step_size: float = 0.5, + normalization_min: float = 1, + positions_step_size: float = 0.9, + fix_probe_com: bool = True, + fix_probe: bool = False, + fix_probe_aperture: bool = False, + constrain_probe_amplitude: bool = False, + constrain_probe_amplitude_relative_radius: float = 0.5, + constrain_probe_amplitude_relative_width: float = 0.05, + constrain_probe_fourier_amplitude: bool = False, + constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, + constrain_probe_fourier_amplitude_constant_intensity: bool = False, + fix_positions: bool = True, + fix_positions_com: bool = True, + max_position_update_distance: float = None, + max_position_total_distance: float = None, + global_affine_transformation: bool = False, + gaussian_filter_sigma_e: float = None, + gaussian_filter_sigma_m: float = None, + gaussian_filter: bool = True, + fit_probe_aberrations: bool = False, + fit_probe_aberrations_max_angular_order: int = 4, + fit_probe_aberrations_max_radial_order: int = 4, + fit_probe_aberrations_remove_initial: bool = False, + fit_probe_aberrations_using_scikit_image: bool = True, + butterworth_filter: bool = True, + q_lowpass_e: float = None, + q_lowpass_m: float = None, + q_highpass_e: float = None, + q_highpass_m: float = None, + butterworth_order: float = 2, + tv_denoise: bool = True, + tv_denoise_weight: float = None, + tv_denoise_inner_iter: float = 40, + detector_fourier_mask: np.ndarray = None, + store_iterations: bool = False, + collective_measurement_updates: bool = True, + progress_bar: bool = True, + reset: bool = None, + device: str = None, + clear_fft_cache: bool = None, + object_type: str = None, + ): + """ + Ptychographic reconstruction main method. + + Parameters + -------- + num_iter: int, optional + Number of iterations to run + reconstruction_method: str, optional + Specifies which reconstruction algorithm to use, one of: + "generalized-projections", + "DM_AP" (or "difference-map_alternating-projections"), + "RAAR" (or "relaxed-averaged-alternating-reflections"), + "RRR" (or "relax-reflect-reflect"), + "SUPERFLIP" (or "charge-flipping"), or + "GD" (or "gradient_descent") + reconstruction_parameter: float, optional + Reconstruction parameter for various reconstruction methods above. + reconstruction_parameter_a: float, optional + Reconstruction parameter a for reconstruction_method='generalized-projections'. + reconstruction_parameter_b: float, optional + Reconstruction parameter b for reconstruction_method='generalized-projections'. + reconstruction_parameter_c: float, optional + Reconstruction parameter c for reconstruction_method='generalized-projections'. + max_batch_size: int, optional + Max number of probes to update at once + seed_random: int, optional + Seeds the random number generator, only applicable when max_batch_size is not None + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + positions_step_size: float, optional + Positions update step size + pure_phase_object: bool, optional + If True, object amplitude is set to unity + fix_probe_com: bool, optional + If True, fixes center of mass of probe + fix_probe: bool, optional + If True, probe is fixed + fix_probe_aperture: bool, optional + If True, vaccum probe is used to fix Fourier amplitude + constrain_probe_amplitude: bool, optional + If True, real-space probe is constrained with a top-hat support. + constrain_probe_amplitude_relative_radius: float + Relative location of top-hat inflection point, between 0 and 0.5 + constrain_probe_amplitude_relative_width: float + Relative width of top-hat sigmoid, between 0 and 0.5 + constrain_probe_fourier_amplitude: bool, optional + If True, Fourier-probe is constrained by fitting a sigmoid for each angular frequency + constrain_probe_fourier_amplitude_max_width_pixels: float + Maximum pixel width of fitted sigmoid functions. + constrain_probe_fourier_amplitude_constant_intensity: bool + If True, the probe aperture is additionally constrained to a constant intensity. + fix_positions: bool, optional + If True, probe-positions are fixed + fix_positions_com: bool, optional + If True, fixes the positions CoM to the middle of the fov + max_position_update_distance: float, optional + Maximum allowed distance for update in A + max_position_total_distance: float, optional + Maximum allowed distance from initial positions + global_affine_transformation: bool, optional + If True, positions are assumed to be a global affine transform from initial scan + gaussian_filter_sigma_e: float + Standard deviation of gaussian kernel for electrostatic object in A + gaussian_filter_sigma_m: float + Standard deviation of gaussian kernel for magnetic object in A + gaussian_filter: bool, optional + If True and gaussian_filter_sigma is not None, object is smoothed using gaussian filtering + fit_probe_aberrations: bool, optional + If True, probe aberrations are fitted to a low-order expansion + fit_probe_aberrations_max_angular_order: bool + Max angular order of probe aberrations basis functions + fit_probe_aberrations_max_radial_order: bool + Max radial order of probe aberrations basis functions + fit_probe_aberrations_remove_initial: bool + If true, initial probe aberrations are removed before fitting + fit_probe_aberrations_using_scikit_image: bool + If true, the necessary phase unwrapping is performed using scikit-image. This is more stable, but occasionally leads + to a documented bug where the kernel hangs.. + If false, a poisson-based solver is used for phase unwrapping. This won't hang, but tends to underestimate aberrations. + butterworth_filter: bool, optional + If True and q_lowpass or q_highpass is not None, object is smoothed using butterworth filtering + q_lowpass_e: float + Cut-off frequency in A^-1 for low-pass filtering electrostatic object + q_lowpass_m: float + Cut-off frequency in A^-1 for low-pass filtering magnetic object + q_highpass_e: float + Cut-off frequency in A^-1 for high-pass filtering electrostatic object + q_highpass_m: float + Cut-off frequency in A^-1 for high-pass filtering magnetic object + butterworth_order: float + Butterworth filter order. Smaller gives a smoother filter + tv_denoise: bool, optional + If True and tv_denoise_weight is not None, object is smoothed using TV denoising + tv_denoise_weight: float + Denoising weight. The greater `weight`, the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising + object_positivity: bool, optional + If True, forces object to be positive + shrinkage_rad: float + Phase shift in radians to be subtracted from the potential at each iteration + fix_potential_baseline: bool + If true, the potential mean outside the FOV is forced to zero at each iteration + detector_fourier_mask: np.ndarray + Corner-centered mask to multiply the detector-plane gradients with (a value of zero supresses those pixels). + Useful when detector has artifacts such as dead-pixels. Usually binary. + store_iterations: bool, optional + If True, reconstructed objects and probes are stored at each iteration + collective_measurement_updates: bool + if True perform collective updates for all measurements + progress_bar: bool, optional + If True, reconstruction progress is displayed + reset: bool, optional + If True, previous reconstructions are ignored + device: str, optional + if not none, overwrites self._device to set device preprocess will be perfomed on. + clear_fft_cache: bool, optional + if true, and device = 'gpu', clears the cached fft plan at the end of function calls + object_type: str, optional + Overwrites self._object_type + + Returns + -------- + self: PtychographicReconstruction + Self to accommodate chaining + """ + # handle device/storage + self.set_device(device, clear_fft_cache) + + if device is not None: + attrs = [ + "_known_aberrations_array", + "_object", + "_object_initial", + "_probes_all", + "_probes_all_initial", + "_probes_all_initial_aperture", + ] + self.copy_attributes_to_device(attrs, device) + + xp = self._xp + xp_storage = self._xp_storage + device = self._device + asnumpy = self._asnumpy + + if not collective_measurement_updates and self._verbose: + warnings.warn( + "Magnetic ptychography is much more robust with `collective_measurement_updates=True`.", + UserWarning, + ) + + # set and report reconstruction method + ( + use_projection_scheme, + projection_a, + projection_b, + projection_c, + reconstruction_parameter, + step_size, + ) = self._set_reconstruction_method_parameters( + reconstruction_method, + reconstruction_parameter, + reconstruction_parameter_a, + reconstruction_parameter_b, + reconstruction_parameter_c, + step_size, + ) + + if use_projection_scheme: + raise NotImplementedError( + "Magnetic ptychography is currently only implemented for gradient descent." + ) + + # initialization + self._reset_reconstruction(store_iterations, reset, use_projection_scheme) + + if object_type is not None: + self._switch_object_type(object_type) + + if self._verbose: + self._report_reconstruction_summary( + num_iter, + use_projection_scheme, + reconstruction_method, + reconstruction_parameter, + projection_a, + projection_b, + projection_c, + normalization_min, + step_size, + max_batch_size, + ) + + if max_batch_size is not None: + np.random.seed(seed_random) + else: + max_batch_size = self._num_diffraction_patterns + + if detector_fourier_mask is not None: + detector_fourier_mask = xp.asarray(detector_fourier_mask) + + if gaussian_filter_sigma_m is None: + gaussian_filter_sigma_m = gaussian_filter_sigma_e + + if q_lowpass_m is None: + q_lowpass_m = q_lowpass_e + + # main loop + for a0 in tqdmnd( + num_iter, + desc="Reconstructing object and probe", + unit=" iter", + disable=not progress_bar, + ): + error = 0.0 + + if collective_measurement_updates: + collective_object = xp.zeros_like(self._object) + + # randomize + measurement_indices = np.arange(self._num_measurements) + np.random.shuffle(measurement_indices) + + for measurement_index in measurement_indices: + self._active_measurement_index = measurement_index + + measurement_error = 0.0 + + _probe = self._probes_all[self._active_measurement_index] + _probe_initial_aperture = self._probes_all_initial_aperture[ + self._active_measurement_index + ] + + start_idx = self._cum_probes_per_measurement[ + self._active_measurement_index + ] + end_idx = self._cum_probes_per_measurement[ + self._active_measurement_index + 1 + ] + + num_diffraction_patterns = end_idx - start_idx + shuffled_indices = np.arange(start_idx, end_idx) + + # randomize + if not use_projection_scheme: + np.random.shuffle(shuffled_indices) + + for start, end in generate_batches( + num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + batch_indices = shuffled_indices[start:end] + positions_px = self._positions_px_all[batch_indices] + positions_px_initial = self._positions_px_initial_all[batch_indices] + positions_px_fractional = positions_px - xp_storage.round( + positions_px + ) + + ( + vectorized_patch_indices_row, + vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices(positions_px) + + amplitudes_device = copy_to_device( + self._amplitudes[batch_indices], device + ) + + # forward operator + ( + shifted_probes, + object_patches, + overlap, + self._exit_waves, + batch_error, + ) = self._forward( + self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + _probe, + positions_px_fractional, + amplitudes_device, + self._exit_waves, + detector_fourier_mask, + use_projection_scheme=use_projection_scheme, + projection_a=projection_a, + projection_b=projection_b, + projection_c=projection_c, + ) + + # adjoint operator + object_update, _probe = self._adjoint( + self._object.copy(), + _probe, + object_patches, + shifted_probes, + positions_px, + self._exit_waves, + use_projection_scheme=use_projection_scheme, + step_size=step_size, + normalization_min=normalization_min, + fix_probe=fix_probe, + ) + + object_update -= self._object + + # position correction + if not fix_positions and a0 > 0: + self._positions_px_all[batch_indices] = ( + self._position_correction( + self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + overlap, + amplitudes_device, + positions_px, + positions_px_initial, + positions_step_size, + max_position_update_distance, + max_position_total_distance, + ) + ) + + measurement_error += batch_error + + if collective_measurement_updates: + collective_object += object_update + else: + self._object += object_update + + # Normalize Error + measurement_error /= ( + self._mean_diffraction_intensity[self._active_measurement_index] + * num_diffraction_patterns + ) + error += measurement_error + + # constraints + + if collective_measurement_updates: + # probe and positions + _probe = self._probe_constraints( + _probe, + fix_probe_com=fix_probe_com and not fix_probe, + constrain_probe_amplitude=constrain_probe_amplitude + and not fix_probe, + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=constrain_probe_fourier_amplitude + and not fix_probe, + constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=fit_probe_aberrations and not fix_probe, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, + fix_probe_aperture=fix_probe_aperture and not fix_probe, + initial_probe_aperture=_probe_initial_aperture, + ) + + self._positions_px_all[batch_indices] = self._positions_constraints( + self._positions_px_all[batch_indices], + self._positions_px_initial_all[batch_indices], + fix_positions=fix_positions, + fix_positions_com=fix_positions_com and not fix_positions, + global_affine_transformation=global_affine_transformation, + ) + + else: + # object, probe, and positions + ( + self._object, + _probe, + self._positions_px_all[batch_indices], + ) = self._constraints( + self._object, + _probe, + self._positions_px_all[batch_indices], + self._positions_px_initial_all[batch_indices], + fix_probe_com=fix_probe_com and not fix_probe, + constrain_probe_amplitude=constrain_probe_amplitude + and not fix_probe, + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=constrain_probe_fourier_amplitude + and not fix_probe, + constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=fit_probe_aberrations and not fix_probe, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, + fix_probe_aperture=fix_probe_aperture and not fix_probe, + initial_probe_aperture=_probe_initial_aperture, + fix_positions=fix_positions, + fix_positions_com=fix_positions_com and not fix_positions, + global_affine_transformation=global_affine_transformation, + gaussian_filter=gaussian_filter + and gaussian_filter_sigma_m is not None, + gaussian_filter_sigma_e=gaussian_filter_sigma_e, + gaussian_filter_sigma_m=gaussian_filter_sigma_m, + butterworth_filter=butterworth_filter + and (q_lowpass_m is not None or q_highpass_m is not None), + q_lowpass_e=q_lowpass_e, + q_lowpass_m=q_lowpass_m, + q_highpass_e=q_highpass_e, + q_highpass_m=q_highpass_m, + butterworth_order=butterworth_order, + tv_denoise=tv_denoise and tv_denoise_weight is not None, + tv_denoise_weight=tv_denoise_weight, + tv_denoise_inner_iter=tv_denoise_inner_iter, + ) + + # Normalize Error Over Tilts + error /= self._num_measurements + + if collective_measurement_updates: + self._object += collective_object / self._num_measurements + + # object only + self._object = self._object_constraints( + self._object, + gaussian_filter=gaussian_filter + and gaussian_filter_sigma_m is not None, + gaussian_filter_sigma_e=gaussian_filter_sigma_e, + gaussian_filter_sigma_m=gaussian_filter_sigma_m, + butterworth_filter=butterworth_filter + and (q_lowpass_m is not None or q_highpass_m is not None), + q_lowpass_e=q_lowpass_e, + q_lowpass_m=q_lowpass_m, + q_highpass_e=q_highpass_e, + q_highpass_m=q_highpass_m, + butterworth_order=butterworth_order, + tv_denoise=tv_denoise and tv_denoise_weight is not None, + tv_denoise_weight=tv_denoise_weight, + tv_denoise_inner_iter=tv_denoise_inner_iter, + ) + + self.error_iterations.append(error.item()) + + if store_iterations: + self.object_iterations.append(asnumpy(self._object.copy())) + self.probe_iterations.append(self.probe_centered) + + # store result + self.object = asnumpy(self._object) + self.probe = self.probe_centered + self.error = error.item() + + # remove _exit_waves attr from self for GD + if not use_projection_scheme: + self._exit_waves = None + + self.clear_device_mem(self._device, self._clear_fft_cache) + + return self + + def _visualize_all_iterations(self, **kwargs): + raise NotImplementedError() + + def _visualize_last_iteration( + self, + fig, + cbar: bool, + plot_convergence: bool, + plot_probe: bool, + plot_fourier_probe: bool, + remove_initial_probe_aberrations: bool, + **kwargs, + ): + """ + Displays last reconstructed object and probe iterations. + + Parameters + -------- + fig: Figure + Matplotlib figure to place Gridspec in + plot_convergence: bool, optional + If true, the normalized mean squared error (NMSE) plot is displayed + cbar: bool, optional + If true, displays a colorbar + plot_probe: bool, optional + If true, the reconstructed complex probe is displayed + plot_fourier_probe: bool, optional + If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes + """ + + asnumpy = self._asnumpy + + figsize = kwargs.pop("figsize", (12, 5)) + cmap_real = kwargs.pop("cmap_real", "PiYG") + cmap_imag = kwargs.pop("cmap_imag", "PuOr") + chroma_boost = kwargs.pop("chroma_boost", 1) + + # get scaled arrays + probe = self._return_single_probe() + obj = self.object_cropped + + _, _, _vmax_real = return_scaled_histogram_ordering(obj[1].real) + vmin_real = kwargs.pop("vmin_real", -_vmax_real) + vmax_real = kwargs.pop("vmax_real", _vmax_real) + + _, _, _vmax_imag = return_scaled_histogram_ordering(obj[1].imag) + vmin_imag = kwargs.pop("vmin_imag", -_vmax_imag) + vmax_imag = kwargs.pop("vmax_imag", _vmax_imag) + + extent = [ + 0, + self.sampling[1] * obj.shape[2], + self.sampling[0] * obj.shape[1], + 0, + ] + + if plot_fourier_probe: + probe_extent = [ + -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + ] + + elif plot_probe: + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + if plot_convergence: + if plot_probe or plot_fourier_probe: + spec = GridSpec( + ncols=3, + nrows=2, + height_ratios=[4, 1], + hspace=0.15, + width_ratios=[ + (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), + (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), + 1, + ], + wspace=0.35, + ) + + else: + spec = GridSpec(ncols=2, nrows=2, height_ratios=[4, 1], hspace=0.15) + + else: + if plot_probe or plot_fourier_probe: + spec = GridSpec( + ncols=3, + nrows=1, + width_ratios=[ + (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), + (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), + 1, + ], + wspace=0.35, + ) + + else: + spec = GridSpec(ncols=2, nrows=1) + + if fig is None: + fig = plt.figure(figsize=figsize) + + if plot_probe or plot_fourier_probe: + # Object_real + ax = fig.add_subplot(spec[0, 0]) + im = ax.imshow( + obj[1].real, + extent=extent, + cmap=cmap_real, + vmin=vmin_real, + vmax=vmax_real, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + ax.set_title("Real magnetic refractive index") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + # Object_imag + ax = fig.add_subplot(spec[0, 1]) + im = ax.imshow( + obj[1].imag, + extent=extent, + cmap=cmap_imag, + vmin=vmin_imag, + vmax=vmax_imag, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + ax.set_title("Imaginary magnetic refractive index") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + # Probe + ax = fig.add_subplot(spec[0, 2]) + if plot_fourier_probe: + probe = asnumpy( + self._return_fourier_probe( + probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) + ) + + probe_array = Complex2RGB( + probe, + chroma_boost=chroma_boost, + ) + + ax.set_title("Reconstructed Fourier probe") + ax.set_ylabel("kx [mrad]") + ax.set_xlabel("ky [mrad]") + else: + probe_array = Complex2RGB( + asnumpy(self._return_centered_probe(probe)), + power=2, + chroma_boost=chroma_boost, + ) + ax.set_title("Reconstructed probe intensity") + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + im = ax.imshow( + probe_array, + extent=probe_extent, + ) + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) + + else: + # Object_real + ax = fig.add_subplot(spec[0, 0]) + im = ax.imshow( + obj[1].real, + extent=extent, + cmap=cmap_real, + vmin=vmin_real, + vmax=vmax_real, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + ax.set_title("Real magnetic refractive index") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + # Object_imag + ax = fig.add_subplot(spec[0, 1]) + im = ax.imshow( + obj[1].imag, + extent=extent, + cmap=cmap_imag, + vmin=vmin_imag, + vmax=vmax_imag, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + ax.set_title("Imaginary magnetic refractive index") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + if plot_convergence and hasattr(self, "error_iterations"): + errors = np.array(self.error_iterations) + + ax = fig.add_subplot(spec[1, :]) + ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) + ax.set_ylabel("NMSE") + ax.set_xlabel("Iteration number") + ax.yaxis.tick_right() + + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") + spec.tight_layout(fig) + + @property + def object_cropped(self): + """Cropped and rotated object""" + avg_pos = self._return_average_positions() + cropped_e = self._crop_rotate_object_fov(self._object[0], positions_px=avg_pos) + cropped_m = self._crop_rotate_object_fov(self._object[1], positions_px=avg_pos) + + return np.array([cropped_e, cropped_m]) From 7ed3d2e680b71969ab5ff27025b7bb4e8f5c1283 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Wed, 29 May 2024 21:06:12 -0700 Subject: [PATCH 30/74] xmcd init, adjoint typos --- .../phase/xray_magnetic_ptychography.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/py4DSTEM/process/phase/xray_magnetic_ptychography.py b/py4DSTEM/process/phase/xray_magnetic_ptychography.py index 77e1cb835..a40dd0f8a 100644 --- a/py4DSTEM/process/phase/xray_magnetic_ptychography.py +++ b/py4DSTEM/process/phase/xray_magnetic_ptychography.py @@ -602,7 +602,7 @@ def preprocess( ) if self._object is None: - self._object = xp.full((2,) + obj.shape, obj) + self._object = xp.stack((obj, xp.zeros_like(obj)), 0) else: self._object = obj @@ -819,9 +819,9 @@ def _overlap_projection( match (self._recon_mode, self._active_measurement_index): case (0, 0) | (1, 0): # reverse - overlap = overlap_base * xp.exp(-1j * object_patches[1]) + overlap = overlap_base * xp.exp(-1.0j * object_patches[1]) case (0, 1) | (1, 2) | (2, 1): # forward - overlap = overlap_base * xp.exp(1j * object_patches[1]) + overlap = overlap_base * xp.exp(1.0j * object_patches[1]) case (1, 1) | (2, 0): # neutral overlap = overlap_base case _: @@ -874,7 +874,7 @@ def _gradient_descent_adjoint( xp = self._xp probe_conj = xp.conj(shifted_probes) # P* - electrostatic_conj = xp.conj(object_patches[0]) # C* = exp(-i c) + electrostatic_conj = xp.conj(object_patches[0]) # C* probe_electrostatic_abs = xp.abs(shifted_probes * object_patches[0]) probe_electrostatic_normalization = self._sum_overlapping_patches_bincounts( @@ -887,7 +887,7 @@ def _gradient_descent_adjoint( + (normalization_min * xp.max(probe_electrostatic_normalization)) ** 2 ) - probe_magnetic_abs = xp.abs(shifted_probes * object_patches[1]) + probe_magnetic_abs = xp.abs(shifted_probes * xp.exp(1.0j * object_patches[1])) probe_magnetic_normalization = self._sum_overlapping_patches_bincounts( probe_magnetic_abs**2, positions_px, @@ -900,7 +900,7 @@ def _gradient_descent_adjoint( if not fix_probe: electrostatic_magnetic_abs = xp.abs( - object_patches[0] * xp.exp(1j * object_patches[1]) + object_patches[0] * xp.exp(1.0j * object_patches[1]) ) electrostatic_magnetic_normalization = xp.sum( electrostatic_magnetic_abs**2, @@ -928,7 +928,7 @@ def _gradient_descent_adjoint( match (self._recon_mode, self._active_measurement_index): case (0, 0) | (1, 0): # reverse - magnetic_exp = xp.exp(1j * xp.conj(object_patches[1])) + magnetic_exp = xp.exp(1.0j * xp.conj(object_patches[1])) # P* exp(i M*) electrostatic_update = self._sum_overlapping_patches_bincounts( @@ -938,7 +938,7 @@ def _gradient_descent_adjoint( # i exp(i M*) C* P* magnetic_update = self._sum_overlapping_patches_bincounts( - 1j * magnetic_exp * probe_conj * electrostatic_conj * exit_waves, + 1.0j * magnetic_exp * probe_conj * electrostatic_conj * exit_waves, positions_px, ) @@ -961,7 +961,7 @@ def _gradient_descent_adjoint( case (0, 1) | (1, 2) | (2, 1): # forward - magnetic_exp = xp.exp(-1j * xp.conj(object_patches[1])) + magnetic_exp = xp.exp(-1.0j * xp.conj(object_patches[1])) # P* exp(-i M*) electrostatic_update = self._sum_overlapping_patches_bincounts( @@ -971,7 +971,7 @@ def _gradient_descent_adjoint( # -i exp (-i M*) P* magnetic_update = self._sum_overlapping_patches_bincounts( - -1j * magnetic_exp * probe_conj * electrostatic_conj * exit_waves, + -1.0j * magnetic_exp * probe_conj * electrostatic_conj * exit_waves, positions_px, ) From 936df0e1ff9f04a65631fdae92c236ed7950c67e Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Wed, 29 May 2024 22:56:06 -0700 Subject: [PATCH 31/74] more careful normalizations --- .../phase/xray_magnetic_ptychography.py | 114 ++++++++++++------ 1 file changed, 77 insertions(+), 37 deletions(-) diff --git a/py4DSTEM/process/phase/xray_magnetic_ptychography.py b/py4DSTEM/process/phase/xray_magnetic_ptychography.py index a40dd0f8a..48ec32a49 100644 --- a/py4DSTEM/process/phase/xray_magnetic_ptychography.py +++ b/py4DSTEM/process/phase/xray_magnetic_ptychography.py @@ -887,49 +887,22 @@ def _gradient_descent_adjoint( + (normalization_min * xp.max(probe_electrostatic_normalization)) ** 2 ) - probe_magnetic_abs = xp.abs(shifted_probes * xp.exp(1.0j * object_patches[1])) - probe_magnetic_normalization = self._sum_overlapping_patches_bincounts( - probe_magnetic_abs**2, - positions_px, - ) - probe_magnetic_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_magnetic_normalization) ** 2 - + (normalization_min * xp.max(probe_magnetic_normalization)) ** 2 - ) + match (self._recon_mode, self._active_measurement_index): + case (0, 0) | (1, 0): # reverse - if not fix_probe: - electrostatic_magnetic_abs = xp.abs( - object_patches[0] * xp.exp(1.0j * object_patches[1]) - ) - electrostatic_magnetic_normalization = xp.sum( - electrostatic_magnetic_abs**2, - axis=0, - ) - electrostatic_magnetic_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * electrostatic_magnetic_normalization) ** 2 - + (normalization_min * xp.max(electrostatic_magnetic_normalization)) - ** 2 - ) + magnetic_exp = xp.exp(1.0j * xp.conj(object_patches[1])) - if self._recon_mode > 0: - electrostatic_abs = xp.abs(object_patches[0]) - electrostatic_normalization = xp.sum( - electrostatic_abs**2, - axis=0, + probe_magnetic_abs = xp.abs(shifted_probes * magnetic_exp) + probe_magnetic_normalization = self._sum_overlapping_patches_bincounts( + probe_magnetic_abs**2, + positions_px, ) - electrostatic_normalization = 1 / xp.sqrt( + probe_magnetic_normalization = 1 / xp.sqrt( 1e-16 - + ((1 - normalization_min) * electrostatic_normalization) ** 2 - + (normalization_min * xp.max(electrostatic_normalization)) ** 2 + + ((1 - normalization_min) * probe_magnetic_normalization) ** 2 + + (normalization_min * xp.max(probe_magnetic_normalization)) ** 2 ) - match (self._recon_mode, self._active_measurement_index): - case (0, 0) | (1, 0): # reverse - - magnetic_exp = xp.exp(1.0j * xp.conj(object_patches[1])) - # P* exp(i M*) electrostatic_update = self._sum_overlapping_patches_bincounts( probe_conj * magnetic_exp * exit_waves, @@ -950,6 +923,28 @@ def _gradient_descent_adjoint( ) if not fix_probe: + + electrostatic_magnetic_abs = xp.abs( + object_patches[0] * magnetic_exp + ) + electrostatic_magnetic_normalization = xp.sum( + electrostatic_magnetic_abs**2, + axis=0, + ) + electrostatic_magnetic_normalization = 1 / xp.sqrt( + 1e-16 + + ( + (1 - normalization_min) + * electrostatic_magnetic_normalization + ) + ** 2 + + ( + normalization_min + * xp.max(electrostatic_magnetic_normalization) + ) + ** 2 + ) + # exp(i M*) C* current_probe += step_size * ( xp.sum( @@ -963,6 +958,17 @@ def _gradient_descent_adjoint( magnetic_exp = xp.exp(-1.0j * xp.conj(object_patches[1])) + probe_magnetic_abs = xp.abs(shifted_probes * magnetic_exp) + probe_magnetic_normalization = self._sum_overlapping_patches_bincounts( + probe_magnetic_abs**2, + positions_px, + ) + probe_magnetic_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_magnetic_normalization) ** 2 + + (normalization_min * xp.max(probe_magnetic_normalization)) ** 2 + ) + # P* exp(-i M*) electrostatic_update = self._sum_overlapping_patches_bincounts( probe_conj * magnetic_exp * exit_waves, @@ -983,6 +989,28 @@ def _gradient_descent_adjoint( ) if not fix_probe: + + electrostatic_magnetic_abs = xp.abs( + object_patches[0] * magnetic_exp + ) + electrostatic_magnetic_normalization = xp.sum( + electrostatic_magnetic_abs**2, + axis=0, + ) + electrostatic_magnetic_normalization = 1 / xp.sqrt( + 1e-16 + + ( + (1 - normalization_min) + * electrostatic_magnetic_normalization + ) + ** 2 + + ( + normalization_min + * xp.max(electrostatic_magnetic_normalization) + ) + ** 2 + ) + # exp(-i M*) C* current_probe += step_size * ( xp.sum( @@ -1015,6 +1043,18 @@ def _gradient_descent_adjoint( ) if not fix_probe: + + electrostatic_abs = xp.abs(object_patches[0]) + electrostatic_normalization = xp.sum( + electrostatic_abs**2, + axis=0, + ) + electrostatic_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * electrostatic_normalization) ** 2 + + (normalization_min * xp.max(electrostatic_normalization)) ** 2 + ) + # V* current_probe += step_size * ( xp.sum( From 8095fa06368840f79c614f996d8b9a5411068fdd Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Thu, 30 May 2024 17:42:05 -0700 Subject: [PATCH 32/74] switched to real-imag electronic contribution --- .../phase/xray_magnetic_ptychography.py | 77 +++++++++---------- 1 file changed, 37 insertions(+), 40 deletions(-) diff --git a/py4DSTEM/process/phase/xray_magnetic_ptychography.py b/py4DSTEM/process/phase/xray_magnetic_ptychography.py index 48ec32a49..1912c93b1 100644 --- a/py4DSTEM/process/phase/xray_magnetic_ptychography.py +++ b/py4DSTEM/process/phase/xray_magnetic_ptychography.py @@ -602,7 +602,8 @@ def preprocess( ) if self._object is None: - self._object = xp.stack((obj, xp.zeros_like(obj)), 0) + # complex zeros instead of ones, since we store pre-exponential terms + self._object = xp.zeros((2,) + obj.shape, dtype=obj.dtype) else: self._object = obj @@ -815,7 +816,7 @@ def _overlap_projection( 1, vectorized_patch_indices_row, vectorized_patch_indices_col ] - overlap_base = shifted_probes * object_patches[0] + overlap_base = shifted_probes * xp.exp(1.0j * object_patches[0]) match (self._recon_mode, self._active_measurement_index): case (0, 0) | (1, 0): # reverse @@ -874,9 +875,9 @@ def _gradient_descent_adjoint( xp = self._xp probe_conj = xp.conj(shifted_probes) # P* - electrostatic_conj = xp.conj(object_patches[0]) # C* + electrostatic_conj = xp.exp(-1.0j * xp.conj(object_patches[0])) # exp[-i c] - probe_electrostatic_abs = xp.abs(shifted_probes * object_patches[0]) + probe_electrostatic_abs = xp.abs(probe_conj * electrostatic_conj) probe_electrostatic_normalization = self._sum_overlapping_patches_bincounts( probe_electrostatic_abs**2, positions_px, @@ -890,9 +891,9 @@ def _gradient_descent_adjoint( match (self._recon_mode, self._active_measurement_index): case (0, 0) | (1, 0): # reverse - magnetic_exp = xp.exp(1.0j * xp.conj(object_patches[1])) + magnetic_conj = xp.exp(1.0j * xp.conj(object_patches[1])) - probe_magnetic_abs = xp.abs(shifted_probes * magnetic_exp) + probe_magnetic_abs = xp.abs(shifted_probes * magnetic_conj) probe_magnetic_normalization = self._sum_overlapping_patches_bincounts( probe_magnetic_abs**2, positions_px, @@ -903,15 +904,19 @@ def _gradient_descent_adjoint( + (normalization_min * xp.max(probe_magnetic_normalization)) ** 2 ) - # P* exp(i M*) + # - i * exp(i m*) * exp(-i c*) * P electrostatic_update = self._sum_overlapping_patches_bincounts( - probe_conj * magnetic_exp * exit_waves, + -1.0j + * magnetic_conj + * electrostatic_conj + * probe_conj + * exit_waves, positions_px, ) - # i exp(i M*) C* P* + # i * exp(i m*) * exp(-i c*) * P magnetic_update = self._sum_overlapping_patches_bincounts( - 1.0j * magnetic_exp * probe_conj * electrostatic_conj * exit_waves, + 1.0j * magnetic_conj * electrostatic_conj * probe_conj * exit_waves, positions_px, ) @@ -925,7 +930,7 @@ def _gradient_descent_adjoint( if not fix_probe: electrostatic_magnetic_abs = xp.abs( - object_patches[0] * magnetic_exp + electrostatic_conj * magnetic_conj ) electrostatic_magnetic_normalization = xp.sum( electrostatic_magnetic_abs**2, @@ -945,10 +950,10 @@ def _gradient_descent_adjoint( ** 2 ) - # exp(i M*) C* + # exp(i m*) * exp(-i c*) current_probe += step_size * ( xp.sum( - electrostatic_conj * magnetic_exp * exit_waves, + magnetic_conj * electrostatic_conj * exit_waves, axis=0, ) * electrostatic_magnetic_normalization @@ -956,9 +961,9 @@ def _gradient_descent_adjoint( case (0, 1) | (1, 2) | (2, 1): # forward - magnetic_exp = xp.exp(-1.0j * xp.conj(object_patches[1])) + magnetic_conj = xp.exp(-1.0j * xp.conj(object_patches[1])) - probe_magnetic_abs = xp.abs(shifted_probes * magnetic_exp) + probe_magnetic_abs = xp.abs(shifted_probes * magnetic_conj) probe_magnetic_normalization = self._sum_overlapping_patches_bincounts( probe_magnetic_abs**2, positions_px, @@ -969,29 +974,25 @@ def _gradient_descent_adjoint( + (normalization_min * xp.max(probe_magnetic_normalization)) ** 2 ) - # P* exp(-i M*) - electrostatic_update = self._sum_overlapping_patches_bincounts( - probe_conj * magnetic_exp * exit_waves, - positions_px, - ) - - # -i exp (-i M*) P* - magnetic_update = self._sum_overlapping_patches_bincounts( - -1.0j * magnetic_exp * probe_conj * electrostatic_conj * exit_waves, + # - i * exp(-i m*) * exp(-i c*) * P + update = self._sum_overlapping_patches_bincounts( + -1.0j + * magnetic_conj + * electrostatic_conj + * probe_conj + * exit_waves, positions_px, ) - current_object[0] += ( - step_size * electrostatic_update * probe_magnetic_normalization - ) + current_object[0] += step_size * update * probe_magnetic_normalization current_object[1] += ( - step_size * magnetic_update * probe_electrostatic_normalization + step_size * update * probe_electrostatic_normalization ) if not fix_probe: electrostatic_magnetic_abs = xp.abs( - object_patches[0] * magnetic_exp + electrostatic_conj * magnetic_conj ) electrostatic_magnetic_normalization = xp.sum( electrostatic_magnetic_abs**2, @@ -1011,16 +1012,17 @@ def _gradient_descent_adjoint( ** 2 ) - # exp(-i M*) C* + # exp(i m*) * exp(-i c*) current_probe += step_size * ( xp.sum( - electrostatic_conj * magnetic_exp * exit_waves, + magnetic_conj * electrostatic_conj * exit_waves, axis=0, ) * electrostatic_magnetic_normalization ) case (1, 1) | (2, 0): # neutral + probe_abs = xp.abs(shifted_probes) probe_normalization = self._sum_overlapping_patches_bincounts( probe_abs**2, @@ -1032,9 +1034,9 @@ def _gradient_descent_adjoint( + (normalization_min * xp.max(probe_normalization)) ** 2 ) - # P* + # -i exp(-i c*) * P* electrostatic_update = self._sum_overlapping_patches_bincounts( - probe_conj * exit_waves, + -1.0j * electrostatic_conj * probe_conj * exit_waves, positions_px, ) @@ -1044,7 +1046,7 @@ def _gradient_descent_adjoint( if not fix_probe: - electrostatic_abs = xp.abs(object_patches[0]) + electrostatic_abs = xp.abs(electrostatic_conj) electrostatic_normalization = xp.sum( electrostatic_abs**2, axis=0, @@ -1055,7 +1057,7 @@ def _gradient_descent_adjoint( + (normalization_min * xp.max(electrostatic_normalization)) ** 2 ) - # V* + # exp(-i c*) current_probe += step_size * ( xp.sum( electrostatic_conj * exit_waves, @@ -1114,11 +1116,6 @@ def _object_constraints( current_object[0], tv_denoise_weight, tv_denoise_inner_iter ) - # amplitude threshold - # current_object[0] = self._object_threshold_constraint( - # current_object[0], False - # ) - return current_object def reconstruct( From e6c0dba48d4cdc0605695a405dfb1746e7db09ef Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Thu, 30 May 2024 19:17:39 -0700 Subject: [PATCH 33/74] better xray visualizations --- .../phase/xray_magnetic_ptychography.py | 283 +++++++++++------- 1 file changed, 182 insertions(+), 101 deletions(-) diff --git a/py4DSTEM/process/phase/xray_magnetic_ptychography.py b/py4DSTEM/process/phase/xray_magnetic_ptychography.py index 1912c93b1..91c0bbfaa 100644 --- a/py4DSTEM/process/phase/xray_magnetic_ptychography.py +++ b/py4DSTEM/process/phase/xray_magnetic_ptychography.py @@ -1658,8 +1658,6 @@ def _visualize_last_iteration( If true, displays a colorbar plot_probe: bool, optional If true, the reconstructed complex probe is displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed remove_initial_probe_aberrations: bool, optional If true, when plotting fourier probe, removes initial probe to visualize changes @@ -1667,22 +1665,34 @@ def _visualize_last_iteration( asnumpy = self._asnumpy - figsize = kwargs.pop("figsize", (12, 5)) - cmap_real = kwargs.pop("cmap_real", "PiYG") - cmap_imag = kwargs.pop("cmap_imag", "PuOr") + figsize = kwargs.pop("figsize", (12, 8)) + cmap_e_real = kwargs.pop("cmap_e_real", "cividis") + cmap_e_imag = kwargs.pop("cmap_e_imag", "magma") + cmap_m_real = kwargs.pop("cmap_m_real", "PuOr") + cmap_m_imag = kwargs.pop("cmap_m_imag", "PiYG") chroma_boost = kwargs.pop("chroma_boost", 1) # get scaled arrays - probe = self._return_single_probe() obj = self.object_cropped - _, _, _vmax_real = return_scaled_histogram_ordering(obj[1].real) - vmin_real = kwargs.pop("vmin_real", -_vmax_real) - vmax_real = kwargs.pop("vmax_real", _vmax_real) + vmin_e_real = kwargs.pop("vmin_e_real", None) + vmax_e_real = kwargs.pop("vmax_e_real", None) + vmin_e_imag = kwargs.pop("vmin_e_imag", None) + vmax_e_imag = kwargs.pop("vmax_e_imag", None) + _, vmin_e_real, vmax_e_real = return_scaled_histogram_ordering( + obj[0].real, vmin_e_real, vmax_e_real + ) + _, vmin_e_imag, vmax_e_iamg = return_scaled_histogram_ordering( + obj[0].imag, vmin_e_imag, vmax_e_imag + ) + + _, _, _vmax_m_real = return_scaled_histogram_ordering(obj[1].real) + vmin_m_real = kwargs.pop("vmin_m_real", -_vmax_m_real) + vmax_m_real = kwargs.pop("vmax_m_real", _vmax_m_real) - _, _, _vmax_imag = return_scaled_histogram_ordering(obj[1].imag) - vmin_imag = kwargs.pop("vmin_imag", -_vmax_imag) - vmax_imag = kwargs.pop("vmax_imag", _vmax_imag) + _, _, _vmax_m_imag = return_scaled_histogram_ordering(obj[1].imag) + vmin_m_imag = kwargs.pop("vmin_m_imag", -_vmax_m_imag) + vmax_m_imag = kwargs.pop("vmax_m_imag", _vmax_m_imag) extent = [ 0, @@ -1698,7 +1708,6 @@ def _visualize_last_iteration( self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, ] - elif plot_probe: probe_extent = [ 0, @@ -1708,11 +1717,11 @@ def _visualize_last_iteration( ] if plot_convergence: - if plot_probe or plot_fourier_probe: + if plot_probe: spec = GridSpec( ncols=3, - nrows=2, - height_ratios=[4, 1], + nrows=3, + height_ratios=[4, 4, 1], hspace=0.15, width_ratios=[ (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), @@ -1723,13 +1732,13 @@ def _visualize_last_iteration( ) else: - spec = GridSpec(ncols=2, nrows=2, height_ratios=[4, 1], hspace=0.15) + spec = GridSpec(ncols=2, nrows=3, height_ratios=[4, 4, 1], hspace=0.15) else: - if plot_probe or plot_fourier_probe: + if plot_probe: spec = GridSpec( ncols=3, - nrows=1, + nrows=2, width_ratios=[ (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), @@ -1739,79 +1748,143 @@ def _visualize_last_iteration( ) else: - spec = GridSpec(ncols=2, nrows=1) + spec = GridSpec(ncols=2, nrows=2, wspace=0.35) if fig is None: fig = plt.figure(figsize=figsize) - if plot_probe or plot_fourier_probe: - # Object_real - ax = fig.add_subplot(spec[0, 0]) - im = ax.imshow( - obj[1].real, - extent=extent, - cmap=cmap_real, - vmin=vmin_real, - vmax=vmax_real, - **kwargs, + # Electronic real + ax = fig.add_subplot(spec[0, 0]) + im = ax.imshow( + obj[0].real, + extent=extent, + cmap=cmap_e_real, + vmin=vmin_e_real, + vmax=vmax_e_real, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + ax.set_title("Real elec. optical index") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + # Electronic imag + ax = fig.add_subplot(spec[0, 1]) + im = ax.imshow( + obj[0].imag, + extent=extent, + cmap=cmap_e_imag, + vmin=vmin_e_imag, + vmax=vmax_e_imag, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + ax.set_title("Imag elec. optical index") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + # Magnetic real + ax = fig.add_subplot(spec[1, 0]) + im = ax.imshow( + obj[1].real, + extent=extent, + cmap=cmap_m_real, + vmin=vmin_m_real, + vmax=vmax_m_real, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + ax.set_title("Real mag. optical index") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + # Magnetic imag + ax = fig.add_subplot(spec[1, 1]) + im = ax.imshow( + obj[1].imag, + extent=extent, + cmap=cmap_m_imag, + vmin=vmin_m_imag, + vmax=vmax_m_imag, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + ax.set_title("Imag mag. optical index") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + if plot_fourier_probe: + # Fourier probe + intensities = self._return_probe_intensities(None) + titles = [ + f"{sign}ve Fourier probe: {ratio*100:.1f}%" + for sign, ratio in zip(self._magnetic_contribution_sign, intensities) + ] + ax = fig.add_subplot(spec[0, 2]) + + probe_fourier = asnumpy( + self._return_fourier_probe( + self._probes_all[0], + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - ax.set_title("Real magnetic refractive index") - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) + probe_array = Complex2RGB( + probe_fourier, + chroma_boost=chroma_boost, + ) + + ax.set_title(titles[0]) + ax.set_ylabel("kx [mrad]") + ax.set_xlabel("ky [mrad]") - # Object_imag - ax = fig.add_subplot(spec[0, 1]) im = ax.imshow( - obj[1].imag, - extent=extent, - cmap=cmap_imag, - vmin=vmin_imag, - vmax=vmax_imag, - **kwargs, + probe_array, + extent=probe_extent, ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - ax.set_title("Imaginary magnetic refractive index") if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) - # Probe - ax = fig.add_subplot(spec[0, 2]) - if plot_fourier_probe: - probe = asnumpy( - self._return_fourier_probe( - probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - ) - ) + ax = fig.add_subplot(spec[1, 2]) - probe_array = Complex2RGB( - probe, - chroma_boost=chroma_boost, + probe_fourier = asnumpy( + self._return_fourier_probe( + self._probes_all[-1], + remove_initial_probe_aberrations=remove_initial_probe_aberrations, ) + ) - ax.set_title("Reconstructed Fourier probe") - ax.set_ylabel("kx [mrad]") - ax.set_xlabel("ky [mrad]") - else: - probe_array = Complex2RGB( - asnumpy(self._return_centered_probe(probe)), - power=2, - chroma_boost=chroma_boost, - ) - ax.set_title("Reconstructed probe intensity") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") + probe_array = Complex2RGB( + probe_fourier, + chroma_boost=chroma_boost, + ) + + ax.set_title(titles[-1]) + ax.set_ylabel("kx [mrad]") + ax.set_xlabel("ky [mrad]") im = ax.imshow( probe_array, @@ -1823,51 +1896,59 @@ def _visualize_last_iteration( ax_cb = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) - else: - # Object_real - ax = fig.add_subplot(spec[0, 0]) - im = ax.imshow( - obj[1].real, - extent=extent, - cmap=cmap_real, - vmin=vmin_real, - vmax=vmax_real, - **kwargs, + elif plot_probe: + # Real probe + intensities = self._return_probe_intensities(None) + titles = [ + f"{sign}ve probe intensity: {ratio*100:.1f}%" + for sign, ratio in zip(self._magnetic_contribution_sign, intensities) + ] + ax = fig.add_subplot(spec[0, 2]) + + probe_array = Complex2RGB( + asnumpy(self._return_centered_probe(self._probes_all[0])), + power=2, + chroma_boost=chroma_boost, ) + ax.set_title(titles[0]) ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") - ax.set_title("Real magnetic refractive index") + + im = ax.imshow( + probe_array, + extent=probe_extent, + ) if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) - # Object_imag - ax = fig.add_subplot(spec[0, 1]) - im = ax.imshow( - obj[1].imag, - extent=extent, - cmap=cmap_imag, - vmin=vmin_imag, - vmax=vmax_imag, - **kwargs, + ax = fig.add_subplot(spec[1, 2]) + + probe_array = Complex2RGB( + asnumpy(self._return_centered_probe(self._probes_all[-1])), + power=2, + chroma_boost=chroma_boost, ) + ax.set_title(titles[-1]) ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") - ax.set_title("Imaginary magnetic refractive index") + + im = ax.imshow( + probe_array, + extent=probe_extent, + ) if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) if plot_convergence and hasattr(self, "error_iterations"): errors = np.array(self.error_iterations) - ax = fig.add_subplot(spec[1, :]) + ax = fig.add_subplot(spec[2, :]) ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax.set_ylabel("NMSE") ax.set_xlabel("Iteration number") From c964ab7773fe8e98c253f8ef2105fd956633efb4 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Thu, 30 May 2024 19:17:53 -0700 Subject: [PATCH 34/74] cleanup magnetic viz --- .../process/phase/magnetic_ptychography.py | 139 ++++++------------ 1 file changed, 45 insertions(+), 94 deletions(-) diff --git a/py4DSTEM/process/phase/magnetic_ptychography.py b/py4DSTEM/process/phase/magnetic_ptychography.py index 58c826b87..aa89d09b3 100644 --- a/py4DSTEM/process/phase/magnetic_ptychography.py +++ b/py4DSTEM/process/phase/magnetic_ptychography.py @@ -1796,55 +1796,55 @@ def _visualize_last_iteration( if fig is None: fig = plt.figure(figsize=figsize) - if plot_probe or plot_fourier_probe: - # Object_e - ax = fig.add_subplot(spec[0, 0]) - im = ax.imshow( - obj[0], - extent=extent, - cmap=cmap_e, - vmin=vmin_e, - vmax=vmax_e, - **kwargs, - ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - - if self._object_type == "potential": - ax.set_title("Electrostatic potential") - elif self._object_type == "complex": - ax.set_title("Electrostatic phase") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) + # Object_e + ax = fig.add_subplot(spec[0, 0]) + im = ax.imshow( + obj[0], + extent=extent, + cmap=cmap_e, + vmin=vmin_e, + vmax=vmax_e, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") - # Object_m - ax = fig.add_subplot(spec[0, 1]) - im = ax.imshow( - obj[1], - extent=extent, - cmap=cmap_m, - vmin=vmin_m, - vmax=vmax_m, - **kwargs, - ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") + if self._object_type == "potential": + ax.set_title("Electrostatic potential") + elif self._object_type == "complex": + ax.set_title("Electrostatic phase") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + # Object_m + ax = fig.add_subplot(spec[0, 1]) + im = ax.imshow( + obj[1], + extent=extent, + cmap=cmap_m, + vmin=vmin_m, + vmax=vmax_m, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") - if self._object_type == "potential": - ax.set_title("Magnetic potential") - elif self._object_type == "complex": - ax.set_title("Magnetic phase") + if self._object_type == "potential": + ax.set_title("Magnetic potential") + elif self._object_type == "complex": + ax.set_title("Magnetic phase") - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + if plot_probe or plot_fourier_probe: # Probe ax = fig.add_subplot(spec[0, 2]) if plot_fourier_probe: @@ -1883,55 +1883,6 @@ def _visualize_last_iteration( ax_cb = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) - else: - # Object_e - ax = fig.add_subplot(spec[0, 0]) - im = ax.imshow( - obj[0], - extent=extent, - cmap=cmap_e, - vmin=vmin_e, - vmax=vmax_e, - **kwargs, - ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - - if self._object_type == "potential": - ax.set_title("Electrostatic potential") - elif self._object_type == "complex": - ax.set_title("Electrostatic phase") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - # Object_e - ax = fig.add_subplot(spec[0, 1]) - im = ax.imshow( - obj[1], - extent=extent, - cmap=cmap_m, - vmin=vmin_m, - vmax=vmax_m, - **kwargs, - ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - - if self._object_type == "potential": - ax.set_title("Magnetic potential") - elif self._object_type == "complex": - ax.set_title("Magnetic phase") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - if plot_convergence and hasattr(self, "error_iterations"): errors = np.array(self.error_iterations) From 6a48ab0cda125e400fd4386ad66cb1f3a91dd6c2 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Fri, 31 May 2024 11:14:21 -0400 Subject: [PATCH 35/74] small bug in dp_mask for gpu --- py4DSTEM/process/phase/phase_base_class.py | 1 + 1 file changed, 1 insertion(+) diff --git a/py4DSTEM/process/phase/phase_base_class.py b/py4DSTEM/process/phase/phase_base_class.py index 742013f1e..27476cb43 100644 --- a/py4DSTEM/process/phase/phase_base_class.py +++ b/py4DSTEM/process/phase/phase_base_class.py @@ -686,6 +686,7 @@ def _calculate_intensities_center_of_mass( # calculate CoM if dp_mask is not None: + dp_mask = copy_to_device(dp_mask, device) intensities_mask = intensities * dp_mask else: intensities_mask = intensities From d99138010b7bca543e9e60acbbab469e2336a93c Mon Sep 17 00:00:00 2001 From: Yael_Tsarfati Date: Wed, 5 Jun 2024 14:43:18 -0700 Subject: [PATCH 36/74] plot radial peaks v lines added --- py4DSTEM/process/polar/polar_peaks.py | 28 +++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/py4DSTEM/process/polar/polar_peaks.py b/py4DSTEM/process/polar/polar_peaks.py index 535ae7143..73232fa8e 100644 --- a/py4DSTEM/process/polar/polar_peaks.py +++ b/py4DSTEM/process/polar/polar_peaks.py @@ -660,11 +660,28 @@ def plot_radial_peaks( qstep=None, label_y_axis=False, figsize=(8, 4), + v_lines=None, returnfig=False, ): """ Calculate and plot the total peak signal as a function of the radial coordinate. + q_pixel_units + If True, plot in reciprocal units instead of pixels. + qmin + The minimum q for plotting. + qmax + The maximum q for plotting. + qstep + The bin width. + label_y_axis + If True, label y axis. + figsize + Plot size. + v_lines: tuple + x coordinates for plotting vertical lines. + returnfig + If True, returns figure. """ # Get all peak data @@ -741,6 +758,17 @@ def plot_radial_peaks( if not label_y_axis: ax.tick_params(left=False, labelleft=False) + plt.tight_layout() + + if v_lines is not None: + y_min, y_max = ax.get_ylim() + + if np.isscalar(v_lines): + ax.vlines(v_lines, y_min, y_max, color="g") + else: + for a0 in range(len(v_lines)): + ax.vlines(v_lines[a0], y_min, y_max, color="g") + if returnfig: return fig, ax From 9973e208b2e3957f25166b1af1635329eb5b6330 Mon Sep 17 00:00:00 2001 From: Yael_Tsarfati Date: Wed, 5 Jun 2024 14:44:52 -0700 Subject: [PATCH 37/74] Typo --- py4DSTEM/process/polar/polar_peaks.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/py4DSTEM/process/polar/polar_peaks.py b/py4DSTEM/process/polar/polar_peaks.py index 73232fa8e..cdf3cbca6 100644 --- a/py4DSTEM/process/polar/polar_peaks.py +++ b/py4DSTEM/process/polar/polar_peaks.py @@ -758,8 +758,6 @@ def plot_radial_peaks( if not label_y_axis: ax.tick_params(left=False, labelleft=False) - plt.tight_layout() - if v_lines is not None: y_min, y_max = ax.get_ylim() From 4a183dfb2c7bd868e2b060f73e71178d21d7309c Mon Sep 17 00:00:00 2001 From: Colin Date: Fri, 21 Jun 2024 13:11:15 -0700 Subject: [PATCH 38/74] Changing default settings, fiddling with plots --- py4DSTEM/process/diffraction/crystal.py | 2 +- py4DSTEM/process/diffraction/crystal_phase.py | 219 ++++++++++++++---- py4DSTEM/process/diffraction/crystal_viz.py | 44 +++- 3 files changed, 213 insertions(+), 52 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index 9694111c9..ed4f9ca6b 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -1368,7 +1368,7 @@ def excitation_errors( g, foil_normal=None, precession_angle_degrees=None, - precession_steps=180, + precession_steps=72, ): """ Calculate the excitation errors, assuming k0 = [0, 0, -1/lambda]. diff --git a/py4DSTEM/process/diffraction/crystal_phase.py b/py4DSTEM/process/diffraction/crystal_phase.py index f6d916a95..c18865071 100644 --- a/py4DSTEM/process/diffraction/crystal_phase.py +++ b/py4DSTEM/process/diffraction/crystal_phase.py @@ -78,16 +78,19 @@ def quantify_single_pattern( pointlistarray: PointListArray, xy_position = (0,0), corr_kernel_size = 0.04, - sigma_excitation_error = 0.02, - power_experiment = 0.25, - power_calculated = 0.25, - max_number_patterns = 3, + sigma_excitation_error: float = 0.02, + precession_angle_degrees = None, + power_intensity: float = 0.25, + power_intensity_experiment: float = 0.25, + k_max = None, + max_number_patterns = 2, single_phase = False, - allow_strain = True, + allow_strain = False, strain_iterations = 3, strain_max = 0.02, include_false_positives = True, weight_false_positives = 1.0, + weight_unmatched_peaks = 1.0, plot_result = True, plot_only_nonzero_phases = True, plot_unmatched_peaks = False, @@ -102,10 +105,32 @@ def quantify_single_pattern( ): """ Quantify the phase for a single diffraction pattern. + + Parameters + ---------- + + corr_kernel_size: (float) + Correlation kernel size length. The size of the overlap kernel between the + measured Bragg peaks and diffraction library Bragg peaks. [1/Angstroms] + sigma_excitation_error: (float) + The out of plane excitation error tolerance. [1/Angstroms] + precession_angle_degrees: (float) + Tilt angle of illuminaiton cone in degrees for precession electron diffraction (PED). + + power_radial: (float) + Power for scaling the correlation intensity as a function of the peak radius + power_intensity: (float) + Power for scaling the correlation intensity as a function of simulated peak intensity + power_intensity_experiment: (float): + Power for scaling the correlation intensity as a function of experimental peak intensity + k_max: (float) + Max k values included in fits, for both x and y directions. + """ + # tolerance - tol2 = 4e-4 + tol2 = 1e-6 # calibrations center = pointlistarray.calstate['center'] @@ -137,7 +162,16 @@ def quantify_single_pattern( pixel = pixel, rotate = rotate) # bragg_peaks = pointlistarray.get_pointlist(xy_position[0],xy_position[1]).copy() - keep = bragg_peaks.data["qx"]**2 + bragg_peaks.data["qy"]**2 > tol2 + if k_max is None: + keep = bragg_peaks.data["qx"]**2 + bragg_peaks.data["qy"]**2 > tol2 + else: + keep = np.logical_and.reduce(( + bragg_peaks.data["qx"]**2 + bragg_peaks.data["qy"]**2 > tol2, + np.abs(bragg_peaks.data["qx"]) < k_max, + np.abs(bragg_peaks.data["qy"]) < k_max, + )) + + # ind_center_beam = np.argmin( # bragg_peaks.data["qx"]**2 + bragg_peaks.data["qy"]**2) # mask = np.ones_like(bragg_peaks.data["qx"], dtype='bool') @@ -147,12 +181,12 @@ def quantify_single_pattern( qy = bragg_peaks.data["qy"][keep] qx0 = bragg_peaks.data["qx"][np.logical_not(keep)] qy0 = bragg_peaks.data["qy"][np.logical_not(keep)] - if power_experiment == 0: + if power_intensity_experiment == 0: intensity = np.ones_like(qx) intensity0 = np.ones_like(qx0) else: - intensity = bragg_peaks.data["intensity"][keep]**power_experiment - intensity0 = bragg_peaks.data["intensity"][np.logical_not(keep)]**power_experiment + intensity = bragg_peaks.data["intensity"][keep]**power_intensity_experiment + intensity0 = bragg_peaks.data["intensity"][np.logical_not(keep)]**power_intensity_experiment int_total = np.sum(intensity) # init basis array @@ -190,16 +224,25 @@ def quantify_single_pattern( ), ind_orientation = m, sigma_excitation_error = sigma_excitation_error, + precession_angle_degrees = precession_angle_degrees, ) - del_peak = bragg_peaks_fit.data["qx"]**2 \ - + bragg_peaks_fit.data["qy"]**2 < tol2 + if k_max is None: + del_peak = bragg_peaks_fit.data["qx"]**2 \ + + bragg_peaks_fit.data["qy"]**2 < tol2 + else: + del_peak = np.logical_or.reduce(( + bragg_peaks_fit.data["qx"]**2 \ + + bragg_peaks_fit.data["qy"]**2 < tol2, + np.abs(bragg_peaks_fit.data["qx"]) > k_max, + np.abs(bragg_peaks_fit.data["qy"]) > k_max, + )) bragg_peaks_fit.remove(del_peak) # peak intensities - if power_calculated == 0: + if power_intensity == 0: int_fit = np.ones_like(bragg_peaks_fit.data["qx"]) else: - int_fit = bragg_peaks_fit.data['intensity']**power_calculated + int_fit = bragg_peaks_fit.data['intensity']**power_intensity # Pair peaks to experiment if plot_result: @@ -321,20 +364,38 @@ def quantify_single_pattern( if single_phase: # loop through each crystal structure and determine the best fit structure, - # which can contain multiple orienations. + # which can contain multiple orientations up to max_number_patterns crystal_res = np.zeros(self.num_crystals) for a0 in range(self.num_crystals): - sub = self.crystal_identity[:,0] == a0 + inds_solve = self.crystal_identity[:,0] == a0 + search = True - phase_weights_cand, phase_residual_cand = nnls( - basis[:,sub], - obs, - ) - phase_weights[sub] = phase_weights_cand - crystal_res[a0] = phase_residual_cand + while search is True: + + basis_solve = basis[:,inds_solve] + obs_solve = obs.copy() + + if weight_unmatched_peaks > 1.0: + sub_unmatched = np.sum(basis_solve,axis=1)<1e-8 + obs_solve[sub_unmatched] *= weight_unmatched_peaks + + phase_weights_cand, phase_residual_cand = nnls( + basis_solve, + obs_solve, + ) + + if np.count_nonzero(phase_weights_cand > 0.0) <= max_number_patterns: + phase_weights[inds_solve] = phase_weights_cand + crystal_res[a0] = phase_residual_cand + search = False + else: + inds = np.where(inds_solve)[0] + inds_solve[inds[np.argmin(phase_weights_cand)]] = False ind_best_fit = np.argmin(crystal_res) + # ind_best_fit = np.argmax(phase_weights) + phase_residual = crystal_res[ind_best_fit] sub = np.logical_not( self.crystal_identity[:,0] == ind_best_fit @@ -452,19 +513,22 @@ def quantify_single_pattern( ax.scatter( qy0, qx0, - s = scale_markers_experiment * intensity0, + # s = scale_markers_experiment * intensity0, + s = scale_markers_experiment * bragg_peaks.data["intensity"][np.logical_not(keep)], marker = "o", facecolor = [0.7, 0.7, 0.7], ) ax.scatter( qy, qx, - s = scale_markers_experiment * intensity, + # s = scale_markers_experiment * intensity, + s = scale_markers_experiment * bragg_peaks.data["intensity"][keep], marker = "o", facecolor = [0.7, 0.7, 0.7], ) # legend - k_max = np.max(self.k_max) + if k_max is None: + k_max = np.max(self.k_max) dx_leg = -0.05*k_max dy_leg = 0.04*k_max text_params = { @@ -607,7 +671,7 @@ def quantify_single_pattern( # appearance ax.set_xlim((-k_max, k_max)) - ax.set_ylim((-k_max, k_max)) + ax.set_ylim((k_max, -k_max)) ax_leg.set_xlim((-0.1*k_max, 0.4*k_max)) ax_leg.set_ylim((-0.5*k_max, 0.5*k_max)) @@ -623,9 +687,14 @@ def quantify_phase( pointlistarray: PointListArray, corr_kernel_size = 0.04, sigma_excitation_error = 0.02, - power_experiment = 0.25, - power_calculated = 0.25, - max_number_patterns = 3, + + precession_angle_degrees = None, + power_intensity: float = 0.25, + power_intensity_experiment: float = 0.25, + k_max = None, + # power_experiment = 0.25, + # power_calculated = 0.25, + max_number_patterns = 2, single_phase = False, allow_strain = True, strain_iterations = 3, @@ -670,8 +739,12 @@ def quantify_phase( xy_position = (rx,ry), corr_kernel_size = corr_kernel_size, sigma_excitation_error = sigma_excitation_error, - power_experiment = power_experiment, - power_calculated = power_calculated, + + precession_angle_degrees = precession_angle_degrees, + power_intensity = power_intensity, + power_intensity_experiment = power_intensity_experiment, + k_max = k_max, + max_number_patterns = max_number_patterns, single_phase = single_phase, allow_strain = allow_strain, @@ -923,10 +996,15 @@ def plot_phase_maps( def plot_dominant_phase( self, + use_correlation_scores = False, reliability_range = (0.0,1.0), sigma = 0.0, phase_colors = None, + ticks = True, figsize = (6,6), + legend_add = True, + legend_fraction = 0.2, + print_fractions = False, ): """ Plot a combined figure showing the primary phase at each probe position. @@ -950,10 +1028,18 @@ def plot_dominant_phase( phase_corr_2nd = np.zeros(scan_shape) phase_sig = np.zeros((self.num_crystals,scan_shape[0],scan_shape[1])) - # sum up phase weights by crystal type - for a0 in range(self.num_fits): - ind = self.crystal_identity[a0,0] - phase_sig[ind] += self.phase_weights[:,:,a0] + if use_correlation_scores: + # Calculate scores from highest correlation match + for a0 in range(self.num_crystals): + phase_sig[a0] = np.maximum( + phase_sig[a0], + np.max(self.crystals[a0].orientation_map.corr,axis=2), + ) + else: + # sum up phase weights by crystal type + for a0 in range(self.num_fits): + ind = self.crystal_identity[a0,0] + phase_sig[ind] += self.phase_weights[:,:,a0] # smoothing of the outputs if sigma > 0.0: @@ -992,25 +1078,68 @@ def plot_dominant_phase( 0, 1) + # Print the total area of fraction of each phase + if print_fractions: + phase_mask = phase_scale >= 0.5 + phase_total = np.sum(phase_mask) + + print('Phase Fractions') + print('---------------') + for a0 in range(self.num_crystals): + phase_frac = np.sum((phase_map == a0) * phase_mask) / phase_total - phase_rgb = np.zeros((scan_shape[0],scan_shape[1],3)) + print(self.crystal_names[a0] + ' - ' + f'{phase_frac*100:.4f}' + '%') + + + self.phase_rgb = np.zeros((scan_shape[0],scan_shape[1],3)) for a0 in range(self.num_crystals): sub = phase_map==a0 for a1 in range(3): - phase_rgb[:,:,a1][sub] = phase_colors[a0,a1] * phase_scale[sub] + self.phase_rgb[:,:,a1][sub] = phase_colors[a0,a1] * phase_scale[sub] # normalize - # phase_rgb = np.clip( - # (phase_rgb - rel_range[0]) / (rel_range[1] - rel_range[0]), + # self.phase_rgb = np.clip( + # (self.phase_rgb - rel_range[0]) / (rel_range[1] - rel_range[0]), # 0,1) - fig,ax = plt.subplots(figsize=figsize) + fig = plt.figure(figsize=figsize) + if legend_add: + width = 1 + + ax = fig.add_axes((0,legend_fraction,1,1-legend_fraction)) + ax_leg = fig.add_axes((0,0,1,legend_fraction)) + + for a0 in range(self.num_crystals): + ax_leg.scatter( + a0*width, + 0, + s = 200, + marker = 's', + edgecolor = (0,0,0,1), + facecolor = phase_colors[a0], + ) + ax_leg.text( + a0*width+0.1, + 0, + self.crystal_names[a0], + fontsize = 16, + verticalalignment = 'center', + ) + ax_leg.axis('off') + ax_leg.set_xlim(( + width * -0.5, + width * (self.num_crystals+0.5), + )) + + else: + ax = fig.add_axes((0,0,1,1)) + ax.imshow( - phase_rgb, + self.phase_rgb, # vmin = 0, # vmax = 5, - # phase_rgb, + # self.phase_rgb, # phase_corr - phase_corr_2nd, # cmap = 'turbo', # vmin = 0, @@ -1018,6 +1147,12 @@ def plot_dominant_phase( # cmap = 'gray', ) + if ticks is False: + ax.set_xticks([]) + ax.set_yticks([]) + + + return fig,ax diff --git a/py4DSTEM/process/diffraction/crystal_viz.py b/py4DSTEM/process/diffraction/crystal_viz.py index fdb313cbb..649f26faa 100644 --- a/py4DSTEM/process/diffraction/crystal_viz.py +++ b/py4DSTEM/process/diffraction/crystal_viz.py @@ -393,7 +393,7 @@ def plot_scattering_intensity( k_step=0.001, k_broadening=0.0, k_power_scale=0.0, - int_power_scale=0.5, + int_power_scale=1.0, int_scale=1.0, remove_origin=True, bragg_peaks=None, @@ -948,12 +948,27 @@ def plot_diffraction_pattern( scale_markers: float = 500, scale_markers_compare: Optional[float] = None, power_markers: float = 1, + power_markers_compare: float = 1, + + color = (0.0, 0.0, 0.0), + color_compare = None, + facecolor = None, + facecolor_compare = (0.0, 0.7, 1.0), + edgecolor = None, + edgecolor_compare = None, + linewidth = 1, + linewidth_compare = 1, + + marker = '+', + marker_compare = 'o', + plot_range_kx_ky: Optional[Union[list, tuple, np.ndarray]] = None, add_labels: bool = True, shift_labels: float = 0.08, shift_marker: float = 0.005, min_marker_size: float = 1e-6, max_marker_size: float = 1000, + show_axes: bool = True, figsize: Union[list, tuple, np.ndarray] = (12, 6), returnfig: bool = False, input_fig_handle=None, @@ -990,7 +1005,7 @@ def plot_diffraction_pattern( marker_size = scale_markers * bragg_peaks.data["intensity"] else: marker_size = scale_markers * ( - bragg_peaks.data["intensity"] ** (power_markers / 2) + bragg_peaks.data["intensity"] ** power_markers ) # Apply marker size limits to primary plot @@ -1013,7 +1028,7 @@ def plot_diffraction_pattern( else: marker_size_compare = np.clip( scale_markers_compare - * (bragg_peaks_compare.data["intensity"] ** (power_markers / 2)), + * (bragg_peaks_compare.data["intensity"] ** power_markers), min_marker_size, max_marker_size, ) @@ -1022,19 +1037,30 @@ def plot_diffraction_pattern( bragg_peaks_compare.data["qy"], bragg_peaks_compare.data["qx"], s=marker_size_compare, - marker="o", - facecolor=[0.0, 0.7, 1.0], + marker=marker_compare, + facecolor=facecolor_compare, + edgecolor=edgecolor_compare, + color=color_compare, + linewidth=linewidth_compare, ) ax.scatter( bragg_peaks.data["qy"], bragg_peaks.data["qx"], s=marker_size, - marker="+", - facecolor="k", + marker=marker, + facecolor=facecolor, + edgecolor=edgecolor, + color=color, + linewidth=linewidth, ) - ax.set_xlabel("$q_y$ [Å$^{-1}$]") - ax.set_ylabel("$q_x$ [Å$^{-1}$]") + if show_axes: + ax.set_xlabel("$q_y$ [Å$^{-1}$]") + ax.set_ylabel("$q_x$ [Å$^{-1}$]") + else: + ax.get_xaxis().set_visible(False) + ax.get_yaxis().set_visible(False) + if plot_range_kx_ky is not None: plot_range_kx_ky = np.array(plot_range_kx_ky) From 7f3eca729c618322b24ee1f07821f061e3202487 Mon Sep 17 00:00:00 2001 From: Colin Date: Mon, 24 Jun 2024 09:01:21 -0700 Subject: [PATCH 39/74] adding robust fitting to polar --- py4DSTEM/process/polar/polar_fits.py | 43 ++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/py4DSTEM/process/polar/polar_fits.py b/py4DSTEM/process/polar/polar_fits.py index 7760726a7..ea83cc866 100644 --- a/py4DSTEM/process/polar/polar_fits.py +++ b/py4DSTEM/process/polar/polar_fits.py @@ -16,6 +16,9 @@ def fit_amorphous_ring( show_fit_mask=False, fit_all_images=False, maxfev=None, + robust=False, + robust_steps = 3, + robust_thresh = 1.0, verbose=False, plot_result=True, plot_log_scale=False, @@ -50,6 +53,13 @@ def fit_amorphous_ring( Fit the elliptic parameters to all images maxfev: int Max number of fitting evaluations for curve_fit. + robust: bool + Set to True to use robust fitting. + robust_steps: int + Number of robust fitting steps. + robust_thresh: float + Threshold for relative errors for outlier detection. Setting to 1.0 means all points beyond + one standard deviation of the median error will be excluded from the next fit. verbose: bool Print fit results plot_result: bool @@ -206,8 +216,41 @@ def fit_amorphous_ring( maxfev=maxfev, )[0] coefs[4] = np.mod(coefs[4], 2 * np.pi) + + if robust: + for a0 in range(robust_steps): + # find outliers + int_fit = amorphous_model(basis, *coefs) + int_diff = vals / int_mean - int_fit + int_diff /= np.median(np.abs(int_diff)) + sub_fit = int_diff**2 < robust_thresh**2 + + # redo fits excluding the outliers + if maxfev is None: + coefs = curve_fit( + amorphous_model, + basis[:,sub_fit], + vals[sub_fit] / int_mean, + p0=coefs, + xtol=1e-8, + bounds=(lb, ub), + )[0] + else: + coefs = curve_fit( + amorphous_model, + basis[:,sub_fit], + vals[sub_fit] / int_mean, + p0=coefs, + xtol=1e-8, + bounds=(lb, ub), + maxfev=maxfev, + )[0] + coefs[4] = np.mod(coefs[4], 2 * np.pi) + + # Scale intensity coefficients coefs[5:8] *= int_mean + # Perform the fit on each individual diffration pattern if fit_all_images: coefs_all = np.zeros((datacube.shape[0], datacube.shape[1], coefs.size)) From b921a47b12ff004aedf0dda00fc68ee362887c59 Mon Sep 17 00:00:00 2001 From: Colin Date: Thu, 27 Jun 2024 09:08:20 -0700 Subject: [PATCH 40/74] Adding docstrings, cleaning up code --- py4DSTEM/process/diffraction/crystal.py | 12 ++-- py4DSTEM/process/diffraction/crystal_phase.py | 72 +++++++++++++++++-- 2 files changed, 73 insertions(+), 11 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index ed4f9ca6b..9aa436cbc 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -719,7 +719,11 @@ def generate_diffraction_pattern( return_orientation_matrix=False, ): """ - Generate a single diffraction pattern, return all peaks as a pointlist. + Generate a single diffraction pattern, return all peaks as a pointlist. This function performs a + kinematical calculation, with optional precession of the beam. + + TODO - switch from numerical precession to analytic (requires geometry projection). + TODO - verify projection geometry for 2D material diffraction. Parameters ---------- @@ -740,7 +744,6 @@ def generate_diffraction_pattern( cartesian projection direction proj_x_cartesian (3,) numpy.array cartesian projection direction - foil_normal 3 element foil normal - set to None to use zone_axis proj_x_axis (3,) numpy.array @@ -748,7 +751,6 @@ def generate_diffraction_pattern( accel_voltage (float) Accelerating voltage in Volts. If not specified, we check to see if crystal already has voltage specified. - sigma_excitation_error (float) sigma value for envelope applied to s_g (excitation errors) in units of inverse Angstroms tol_excitation_error_mult (float) @@ -768,8 +770,8 @@ def generate_diffraction_pattern( ---------- bragg_peaks (PointList) list of all Bragg peaks with fields [qx, qy, intensity, h, k, l] - orientation_matrix (array) - 3x3 orientation matrix (optional) + orientation_matrix (array, optional) + 3x3 orientation matrix """ diff --git a/py4DSTEM/process/diffraction/crystal_phase.py b/py4DSTEM/process/diffraction/crystal_phase.py index c18865071..cb99a99a9 100644 --- a/py4DSTEM/process/diffraction/crystal_phase.py +++ b/py4DSTEM/process/diffraction/crystal_phase.py @@ -106,9 +106,15 @@ def quantify_single_pattern( """ Quantify the phase for a single diffraction pattern. + TODO - determine the difference between false positive peaks and unmatched peaks (if any). + Parameters ---------- + pointlistarray: (PointListArray) + Full array of all calibrated experimental bragg peaks, with shape = (num_x,num_y) + xy_position: (int,int) + The (x,y) or (row,column) position to be quantified. corr_kernel_size: (float) Correlation kernel size length. The size of the overlap kernel between the measured Bragg peaks and diffraction library Bragg peaks. [1/Angstroms] @@ -116,19 +122,73 @@ def quantify_single_pattern( The out of plane excitation error tolerance. [1/Angstroms] precession_angle_degrees: (float) Tilt angle of illuminaiton cone in degrees for precession electron diffraction (PED). - - power_radial: (float) - Power for scaling the correlation intensity as a function of the peak radius power_intensity: (float) - Power for scaling the correlation intensity as a function of simulated peak intensity + Power for scaling the correlation intensity as a function of simulated peak intensity. power_intensity_experiment: (float): - Power for scaling the correlation intensity as a function of experimental peak intensity + Power for scaling the correlation intensity as a function of experimental peak intensity. k_max: (float) Max k values included in fits, for both x and y directions. + max_number_patterns: int + Max number of orientations which can be included in a match. + single_phase: bool + Set to true to force result to output only the best-fit phase (minimum intensity residual). + allow_strain: bool, + Allow the simulated diffraction patterns to be distorted to improve the matches. + strain_iterations: int + Number of pattern position refinement iterations. + strain_max: float + Maximum strain fraction allowed - this value should be low, typically a few percent (~0.02). + include_false_positives: bool + Penalize patterns which generate false positive peaks. + weight_false_positives: float + Weight strength of false positive peaks. + weight_unmatched_peaks: float + Penalize unmatched peaks. + plot_result: bool + Plot the resulting fit. + plot_only_nonzero_phases: bool + Only plot phases with phase weights > 0. + plot_unmatched_peaks: bool + Plot the false postive peaks. + plot_correlation_radius: bool + In the visualization, draw the correlation radius. + scale_markers_experiment: float + Size of experimental diffraction peak markers. + scale_markers_calculated: float + Size of the calculate diffraction peak markers. + crystal_inds_plot: tuple of ints + Which crystal index / indices to plot. + phase_colors: np.array + Color of each phase, should have shape = (num_phases, 3) + figsize: (float,float) + Size of the output figure. + verbose: bool + Print the resulting fit weights to console. + returnfig: bool + Return the figure and axis handles for the plot. + + + Returns + ------- + phase_weights: (np.array) + Estimated relative fraction of each phase for all probe positions. + shape = (num_x, num_y, num_orientations) + where num_orientations is the total number of all orientations for all phases. + phase_residual: (np.array) + Residual intensity not represented by the best fit phase weighting for all probe positions. + shape = (num_x, num_y) + phase_reliability: (np.array) + Estimated reliability of match(es) for all probe positions. + Typically calculated as the best fit score minus the second best fit. + shape = (num_x, num_y) + int_total: (np.array) + Sum of experimental peak intensities for all probe positions. + shape = (num_x, num_y) + fig,ax: (optional) + matplotlib figure and axis handles """ - # tolerance tol2 = 1e-6 From cd48bea268ac02ee29f70372ed2e19e731966690 Mon Sep 17 00:00:00 2001 From: Colin Date: Thu, 27 Jun 2024 15:53:17 -0700 Subject: [PATCH 41/74] more docstrings added --- py4DSTEM/process/diffraction/crystal_phase.py | 141 +++++++++++++++++- 1 file changed, 136 insertions(+), 5 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_phase.py b/py4DSTEM/process/diffraction/crystal_phase.py index cb99a99a9..f05d304c1 100644 --- a/py4DSTEM/process/diffraction/crystal_phase.py +++ b/py4DSTEM/process/diffraction/crystal_phase.py @@ -747,13 +747,10 @@ def quantify_phase( pointlistarray: PointListArray, corr_kernel_size = 0.04, sigma_excitation_error = 0.02, - precession_angle_degrees = None, power_intensity: float = 0.25, power_intensity_experiment: float = 0.25, k_max = None, - # power_experiment = 0.25, - # power_calculated = 0.25, max_number_patterns = 2, single_phase = False, allow_strain = True, @@ -765,6 +762,44 @@ def quantify_phase( ): """ Quantify phase of all diffraction patterns. + + Parameters + ---------- + pointlistarray: (PointListArray) + Full array of all calibrated experimental bragg peaks, with shape = (num_x,num_y) + corr_kernel_size: (float) + Correlation kernel size length. The size of the overlap kernel between the + measured Bragg peaks and diffraction library Bragg peaks. [1/Angstroms] + sigma_excitation_error: (float) + The out of plane excitation error tolerance. [1/Angstroms] + precession_angle_degrees: (float) + Tilt angle of illuminaiton cone in degrees for precession electron diffraction (PED). + power_intensity: (float) + Power for scaling the correlation intensity as a function of simulated peak intensity. + power_intensity_experiment: (float): + Power for scaling the correlation intensity as a function of experimental peak intensity. + k_max: (float) + Max k values included in fits, for both x and y directions. + max_number_patterns: int + Max number of orientations which can be included in a match. + single_phase: bool + Set to true to force result to output only the best-fit phase (minimum intensity residual). + allow_strain: bool, + Allow the simulated diffraction patterns to be distorted to improve the matches. + strain_iterations: int + Number of pattern position refinement iterations. + strain_max: float + Maximum strain fraction allowed - this value should be low, typically a few percent (~0.02). + include_false_positives: bool + Penalize patterns which generate false positive peaks. + weight_false_positives: float + Weight strength of false positive peaks. + progressbar: bool + Display progress. + + Returns + ----------- + """ # init results arrays @@ -836,6 +871,33 @@ def plot_phase_weights( ): """ Plot the individual phase weight maps and residuals. + + Parameters + ---------- + weight_range: (float, float) + Plotting weight range. + weight_normalize: bool + Normalize weights before plotting. + total_intensity_normalize: bool + Normalize the total intensity. + cmap: matplotlib.cm.cmap + Colormap to use for plots. + show_ticks: bool + Show ticks on plots. + show_axes: bool + Show axes. + layout: int + Layout type for figures. + figsize: (float,float) + Size of figure panel. + returnfig: bool + Return the figure and axes handles. + + Returns + ---------- + fig,ax: (optional) + Figure and axes handles. + """ # Normalization if required to total DF peak intensity @@ -931,6 +993,45 @@ def plot_phase_maps( ): """ Plot the individual phase weight maps and residuals. + + Parameters + ---------- + weight_threshold: float + Threshold for showing each phase. + weight_normalize: bool + Normalize weights before plotting. + total_intensity_normalize: bool + Normalize the total intensity. + plot_combine: bool + Combine all figures into a single plot. + crystal_inds_plot: (tuple of ints) + Which crystals to plot phase maps for. + phase_colors: np.array + (Nx3) shaped array giving the colors for each phase + show_ticks: bool + Show ticks on plots. + show_axes: bool + Show axes. + layout: int + Layout type for figures. + figsize: (float,float) + Size of figure panel. + return_phase_estimate: bool + Return the phase estimate array. + return_rgb_images: bool + Return the rgb images. + returnfig: bool + Return the figure and axes handles. + + Returns + ---------- + im_all: (np.array, optional) + images showing phase maps. + im_rgb, im_rgb_all: (np.array, optional) + rgb colored output images, possibly combined + fig,ax: (optional) + Figure and axes handles. + """ if phase_colors is None: @@ -1065,10 +1166,40 @@ def plot_dominant_phase( legend_add = True, legend_fraction = 0.2, print_fractions = False, + returnfig = True, ): """ Plot a combined figure showing the primary phase at each probe position. Mask by the reliability index (best match minus 2nd best match). + + Parameters + ---------- + use_correlation_scores: bool + Set to True to use correlation scores instead of reliabiltiy from intensity residuals. + reliability_range: (float, float) + Plotting intensity range + sigma: float + Smoothing in units of probe position. + phase_colors: np.array + (N,3) shaped array giving colors of all phases + ticks: bool + Show ticks on plots. + figsize: (float,float) + Size of output figure + legend_add: bool + Add legend to plot + legend_fraction: float + Fractional size of legend in plot. + print_fractions: bool + Print the estimated fraction of all phases. + returnfig: bool + Return the figure and axes handles. + + Returns + ---------- + fig,ax: (optional) + Figure and axes handles. + """ if phase_colors is None: @@ -1212,8 +1343,8 @@ def plot_dominant_phase( ax.set_yticks([]) - - return fig,ax + if returnfig: + return fig,ax # def plot_all_phase_maps(self, map_scale_values=None, index=0): From 316bded092e46e73f59186d67d4a1e6927f3cb68 Mon Sep 17 00:00:00 2001 From: Colin Date: Thu, 27 Jun 2024 15:53:45 -0700 Subject: [PATCH 42/74] Cleaning up, removing residual comments --- py4DSTEM/process/diffraction/crystal_phase.py | 293 ------------------ 1 file changed, 293 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_phase.py b/py4DSTEM/process/diffraction/crystal_phase.py index f05d304c1..750ea4670 100644 --- a/py4DSTEM/process/diffraction/crystal_phase.py +++ b/py4DSTEM/process/diffraction/crystal_phase.py @@ -1346,296 +1346,3 @@ def plot_dominant_phase( if returnfig: return fig,ax - - # def plot_all_phase_maps(self, map_scale_values=None, index=0): - # """ - # Visualize phase maps of dataset. - - # Args: - # map_scale_values (float): Value to scale correlations by - # """ - # phase_maps = [] - # if map_scale_values is None: - # map_scale_values = [1] * len(self.orientation_maps) - # corr_sum = np.sum( - # [ - # (self.orientation_maps[m].corr[:, :, index] * map_scale_values[m]) - # for m in range(len(self.orientation_maps)) - # ] - # ) - # for m in range(len(self.orientation_maps)): - # phase_maps.append(self.orientation_maps[m].corr[:, :, index] / corr_sum) - # show_image_grid(lambda i: phase_maps[i], 1, len(phase_maps), cmap="inferno") - # return - - # def plot_phase_map(self, index=0, cmap=None): - # corr_array = np.dstack( - # [maps.corr[:, :, index] for maps in self.orientation_maps] - # ) - # best_corr_score = np.max(corr_array, axis=2) - # best_match_phase = [ - # np.where(corr_array[:, :, p] == best_corr_score, True, False) - # for p in range(len(self.orientation_maps)) - # ] - - # if cmap is None: - # cm = plt.get_cmap("rainbow") - # cmap = [ - # cm(1.0 * i / len(self.orientation_maps)) - # for i in range(len(self.orientation_maps)) - # ] - - # fig, (ax) = plt.subplots(figsize=(6, 6)) - # ax.matshow( - # np.zeros((self.orientation_maps[0].num_x, self.orientation_maps[0].num_y)), - # cmap="gray", - # ) - # ax.axis("off") - - # for m in range(len(self.orientation_maps)): - # c0, c1 = (cmap[m][0] * 0.35, cmap[m][1] * 0.35, cmap[m][2] * 0.35, 1), cmap[ - # m - # ] - # cm = mpl.colors.LinearSegmentedColormap.from_list("cmap", [c0, c1], N=10) - # ax.matshow( - # np.ma.array( - # self.orientation_maps[m].corr[:, :, index], mask=best_match_phase[m] - # ), - # cmap=cm, - # ) - # plt.show() - - # return - - # Potentially introduce a way to check best match out of all orientations in phase plan and plug into model - # to quantify phase - - # def phase_plan( - # self, - # method, - # zone_axis_range: np.ndarray = np.array([[0, 1, 1], [1, 1, 1]]), - # angle_step_zone_axis: float = 2.0, - # angle_coarse_zone_axis: float = None, - # angle_refine_range: float = None, - # angle_step_in_plane: float = 2.0, - # accel_voltage: float = 300e3, - # intensity_power: float = 0.25, - # tol_peak_delete=None, - # tol_distance: float = 0.01, - # fiber_axis = None, - # fiber_angles = None, - # ): - # return - - # def quantify_phase( - # self, - # pointlistarray, - # tolerance_distance=0.08, - # method="nnls", - # intensity_power=0, - # mask_peaks=None, - # ): - # """ - # Quantification of the phase of a crystal based on the crystal instances and the pointlistarray. - - # Args: - # pointlisarray (pointlistarray): Pointlistarray to quantify phase of - # tolerance_distance (float): Distance allowed between a peak and match - # method (str): Numerical method used to quantify phase - # intensity_power (float): ... - # mask_peaks (list, optional): A pointer of which positions to mask peaks from - - # Details: - # """ - # if isinstance(pointlistarray, PointListArray): - # phase_weights = np.zeros( - # ( - # pointlistarray.shape[0], - # pointlistarray.shape[1], - # np.sum([map.num_matches for map in self.orientation_maps]), - # ) - # ) - # phase_residuals = np.zeros(pointlistarray.shape) - # for Rx, Ry in tqdmnd(pointlistarray.shape[0], pointlistarray.shape[1]): - # ( - # _, - # phase_weight, - # phase_residual, - # crystal_identity, - # ) = self.quantify_phase_pointlist( - # pointlistarray, - # position=[Rx, Ry], - # tolerance_distance=tolerance_distance, - # method=method, - # intensity_power=intensity_power, - # mask_peaks=mask_peaks, - # ) - # phase_weights[Rx, Ry, :] = phase_weight - # phase_residuals[Rx, Ry] = phase_residual - # self.phase_weights = phase_weights - # self.phase_residuals = phase_residuals - # self.crystal_identity = crystal_identity - # return - # else: - # return TypeError("pointlistarray must be of type pointlistarray.") - # return - - # def quantify_phase_pointlist( - # self, - # pointlistarray, - # position, - # method="nnls", - # tolerance_distance=0.08, - # intensity_power=0, - # mask_peaks=None, - # ): - # """ - # Args: - # pointlisarray (pointlistarray): Pointlistarray to quantify phase of - # position (tuple/list): Position of pointlist in pointlistarray - # tolerance_distance (float): Distance allowed between a peak and match - # method (str): Numerical method used to quantify phase - # intensity_power (float): ... - # mask_peaks (list, optional): A pointer of which positions to mask peaks from - - # Returns: - # pointlist_peak_intensity_matches (np.ndarray): Peak matches in the rows of array and the crystals in the columns - # phase_weights (np.ndarray): Weights of each phase - # phase_residuals (np.ndarray): Residuals - # crystal_identity (list): List of lists, where the each entry represents the position in the - # crystal and orientation match that is associated with the phase - # weights. for example, if the output was [[0,0], [0,1], [1,0], [0,1]], - # the first entry [0,0] in phase weights is associated with the first crystal - # the first match within that crystal. [0,1] is the first crystal and the - # second match within that crystal. - # """ - # # Things to add: - # # 1. Better cost for distance from peaks in pointlists - # # 2. Iterate through multiple tolerance_distance values to find best value. Cost function residuals, or something else? - - # pointlist = pointlistarray.get_pointlist(position[0], position[1]) - # pl_mask = np.where((pointlist["qx"] == 0) & (pointlist["qy"] == 0), 1, 0) - # pointlist.remove(pl_mask) - # # False Negatives (exp peak with no match in crystal instances) will appear here, already coded in - - # if intensity_power == 0: - # pl_intensities = np.ones(pointlist["intensity"].shape) - # else: - # pl_intensities = pointlist["intensity"] ** intensity_power - # # Prepare matches for modeling - # pointlist_peak_matches = [] - # crystal_identity = [] - - # for c in range(len(self.crystals)): - # for m in range(self.orientation_maps[c].num_matches): - # crystal_identity.append([c, m]) - # phase_peak_match_intensities = np.zeros((pointlist["intensity"].shape)) - # bragg_peaks_fit = self.crystals[c].generate_diffraction_pattern( - # self.orientation_maps[c].get_orientation(position[0], position[1]), - # ind_orientation=m, - # ) - # # Find the best match peak within tolerance_distance and add value in the right position - # for d in range(pointlist["qx"].shape[0]): - # distances = [] - # for p in range(bragg_peaks_fit["qx"].shape[0]): - # distances.append( - # np.sqrt( - # (pointlist["qx"][d] - bragg_peaks_fit["qx"][p]) ** 2 - # + (pointlist["qy"][d] - bragg_peaks_fit["qy"][p]) ** 2 - # ) - # ) - # ind = np.where(distances == np.min(distances))[0][0] - - # # Potentially for-loop over multiple values for 'tolerance_distance' to find best tolerance_distance value - # if distances[ind] <= tolerance_distance: - # ## Somewhere in this if statement is probably where better distances from the peak should be coded in - # if ( - # intensity_power == 0 - # ): # This could potentially be a different intensity_power arg - # phase_peak_match_intensities[d] = 1 ** ( - # (tolerance_distance - distances[ind]) - # / tolerance_distance - # ) - # else: - # phase_peak_match_intensities[d] = bragg_peaks_fit[ - # "intensity" - # ][ind] ** ( - # (tolerance_distance - distances[ind]) - # / tolerance_distance - # ) - # else: - # ## This is probably where the false positives (peaks in crystal but not in experiment) should be handled - # continue - - # pointlist_peak_matches.append(phase_peak_match_intensities) - # pointlist_peak_intensity_matches = np.dstack(pointlist_peak_matches) - # pointlist_peak_intensity_matches = ( - # pointlist_peak_intensity_matches.reshape( - # pl_intensities.shape[0], - # pointlist_peak_intensity_matches.shape[-1], - # ) - # ) - - # if len(pointlist["qx"]) > 0: - # if mask_peaks is not None: - # for i in range(len(mask_peaks)): - # if mask_peaks[i] == None: # noqa: E711 - # continue - # inds_mask = np.where( - # pointlist_peak_intensity_matches[:, mask_peaks[i]] != 0 - # )[0] - # for mask in range(len(inds_mask)): - # pointlist_peak_intensity_matches[inds_mask[mask], i] = 0 - - # if method == "nnls": - # phase_weights, phase_residuals = nnls( - # pointlist_peak_intensity_matches, pl_intensities - # ) - - # elif method == "lstsq": - # phase_weights, phase_residuals, rank, singluar_vals = lstsq( - # pointlist_peak_intensity_matches, pl_intensities, rcond=-1 - # ) - # phase_residuals = np.sum(phase_residuals) - # else: - # raise ValueError(method + " Not yet implemented. Try nnls or lstsq.") - # else: - # phase_weights = np.zeros((pointlist_peak_intensity_matches.shape[1],)) - # phase_residuals = np.NaN - # return ( - # pointlist_peak_intensity_matches, - # phase_weights, - # phase_residuals, - # crystal_identity, - # ) - - # def plot_peak_matches( - # self, - # pointlistarray, - # position, - # tolerance_distance, - # ind_orientation, - # pointlist_peak_intensity_matches, - # ): - # """ - # A method to view how the tolerance distance impacts the peak matches associated with - # the quantify_phase_pointlist method. - - # Args: - # pointlistarray, - # position, - # tolerance_distance - # pointlist_peak_intensity_matches - # """ - # pointlist = pointlistarray.get_pointlist(position[0],position[1]) - - # for m in range(pointlist_peak_intensity_matches.shape[1]): - # bragg_peaks_fit = self.crystals[m].generate_diffraction_pattern( - # self.orientation_maps[m].get_orientation(position[0], position[1]), - # ind_orientation = ind_orientation - # ) - # peak_inds = np.where(bragg_peaks_fit.data['intensity'] == pointlist_peak_intensity_matches[:,m]) - - # fig, (ax1, ax2) = plt.subplots(2,1,figsize = figsize) - # ax1 = plot_diffraction_pattern(pointlist,) - # return From 9ada2d86fb44d6e21165aff2b4499cb3083b2b23 Mon Sep 17 00:00:00 2001 From: Colin Date: Thu, 27 Jun 2024 15:54:15 -0700 Subject: [PATCH 43/74] black --- py4DSTEM/process/diffraction/crystal.py | 74 +- py4DSTEM/process/diffraction/crystal_ACOM.py | 125 +-- py4DSTEM/process/diffraction/crystal_phase.py | 946 +++++++++--------- py4DSTEM/process/diffraction/crystal_viz.py | 30 +- py4DSTEM/process/polar/polar_fits.py | 9 +- 5 files changed, 620 insertions(+), 564 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index 9aa436cbc..5cea7afb4 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -714,7 +714,7 @@ def generate_diffraction_pattern( tol_excitation_error_mult: float = 3, tol_intensity: float = 1e-5, k_max: Optional[float] = None, - precession_angle_degrees = None, + precession_angle_degrees=None, keep_qz=False, return_orientation_matrix=False, ): @@ -730,49 +730,49 @@ def generate_diffraction_pattern( orientation (Orientation) an Orientation class object - ind_orientation + ind_orientation If input is an Orientation class object with multiple orientations, this input can be used to select a specific orientation. - orientation_matrix (3,3) numpy.array + orientation_matrix (3,3) numpy.array orientation matrix, where columns represent projection directions. - zone_axis_lattice (3,) numpy.array + zone_axis_lattice (3,) numpy.array projection direction in lattice indices - proj_x_lattice (3,) numpy.array + proj_x_lattice (3,) numpy.array x-axis direction in lattice indices - zone_axis_cartesian (3,) numpy.array + zone_axis_cartesian (3,) numpy.array cartesian projection direction - proj_x_cartesian (3,) numpy.array + proj_x_cartesian (3,) numpy.array cartesian projection direction - foil_normal + foil_normal 3 element foil normal - set to None to use zone_axis - proj_x_axis (3,) numpy.array + proj_x_axis (3,) numpy.array 3 element vector defining image x axis (vertical) - accel_voltage (float) + accel_voltage (float) Accelerating voltage in Volts. If not specified, we check to see if crystal already has voltage specified. sigma_excitation_error (float) sigma value for envelope applied to s_g (excitation errors) in units of inverse Angstroms tol_excitation_error_mult (float) tolerance in units of sigma for s_g inclusion - tol_intensity (numpy float) + tol_intensity (numpy float) tolerance in intensity units for inclusion of diffraction spots - k_max (float) + k_max (float) Maximum scattering vector precession_angle_degrees (float) Precession angle for library calculation. Set to None for no precession. - keep_qz (bool) + keep_qz (bool) Flag to return out-of-plane diffraction vectors return_orientation_matrix (bool) Return the orientation matrix Returns ---------- - bragg_peaks (PointList) + bragg_peaks (PointList) list of all Bragg peaks with fields [qx, qy, intensity, h, k, l] - orientation_matrix (array, optional) - 3x3 orientation matrix - + orientation_matrix (array, optional) + 3x3 orientation matrix + """ if not (hasattr(self, "wavelength") and hasattr(self, "accel_voltage")): @@ -809,7 +809,7 @@ def generate_diffraction_pattern( if foil_normal is None: sg = self.excitation_errors( g, - precession_angle_degrees = precession_angle_degrees, + precession_angle_degrees=precession_angle_degrees, ) else: foil_normal = ( @@ -818,8 +818,8 @@ def generate_diffraction_pattern( ).ravel() sg = self.excitation_errors( g, - foil_normal = foil_normal, - precession_angle_degrees = precession_angle_degrees, + foil_normal=foil_normal, + precession_angle_degrees=precession_angle_degrees, ) # Threshold for inclusion in diffraction pattern @@ -827,7 +827,7 @@ def generate_diffraction_pattern( if precession_angle_degrees is None: keep = np.abs(sg) <= sg_max else: - keep = np.min(np.abs(sg),axis=1) <= sg_max + keep = np.min(np.abs(sg), axis=1) <= sg_max # Maximum scattering angle cutoff if k_max is not None: @@ -843,10 +843,8 @@ def generate_diffraction_pattern( ) else: g_int = self.struct_factors_int[keep] * np.mean( - np.exp( - (sg[keep] ** 2) / (-2 * sigma_excitation_error**2) - ), - axis = 1, + np.exp((sg[keep] ** 2) / (-2 * sigma_excitation_error**2)), + axis=1, ) hkl = self.hkl[:, keep] @@ -1394,29 +1392,31 @@ def excitation_errors( t = np.deg2rad(precession_angle_degrees) p = np.linspace( 0, - 2.0*np.pi, + 2.0 * np.pi, precession_steps, - endpoint = False, + endpoint=False, ) if foil_normal is None: - foil_normal = np.array((0.0,0.0,-1.0)) + foil_normal = np.array((0.0, 0.0, -1.0)) k = np.reshape( - (-1/self.wavelength) * np.vstack(( - np.sin(t)*np.cos(p), - np.sin(t)*np.sin(p), - np.cos(t)*np.ones(p.size), - )), - (3,1,p.size), + (-1 / self.wavelength) + * np.vstack( + ( + np.sin(t) * np.cos(p), + np.sin(t) * np.sin(p), + np.cos(t) * np.ones(p.size), + ) + ), + (3, 1, p.size), ) - term1 = np.sum( (g[:,:,None] + k) * foil_normal[:,None,None], axis=0) - term2 = np.sum( (g[:,:,None] + 2*k) * g[:,:,None], axis=0) + term1 = np.sum((g[:, :, None] + k) * foil_normal[:, None, None], axis=0) + term2 = np.sum((g[:, :, None] + 2 * k) * g[:, :, None], axis=0) sg = np.sqrt(term1**2 - term2) - term1 return sg - def calculate_bragg_peak_histogram( self, bragg_peaks, diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index 74e842817..ae71c20bc 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -28,10 +28,10 @@ def orientation_plan( accel_voltage: float = 300e3, corr_kernel_size: float = 0.08, sigma_excitation_error: float = 0.02, - precession_angle_degrees = None, + precession_angle_degrees=None, power_radial: float = 1.0, - power_intensity: float = 0.25, - power_intensity_experiment: float = 0.25, + power_intensity: float = 0.25, + power_intensity_experiment: float = 0.25, calculate_correlation_array=True, tol_peak_delete=None, tol_distance: float = 0.01, @@ -56,53 +56,53 @@ def orientation_plan( Setting to 'fiber' as a string will make a spherical cap around a given vector. Setting to 'auto' will use pymatgen to determine the point group symmetry of the structure and choose an appropriate zone_axis_range - angle_step_zone_axis (float): + angle_step_zone_axis (float): Approximate angular step size for zone axis search [degrees] - angle_coarse_zone_axis (float): + angle_coarse_zone_axis (float): Coarse step size for zone axis search [degrees]. Setting to None uses the same value as angle_step_zone_axis. - angle_refine_range (float): + angle_refine_range (float): Range of angles to use for zone axis refinement. Setting to None uses same value as angle_coarse_zone_axis. - angle_step_in_plane (float): + angle_step_in_plane (float): Approximate angular step size for in-plane rotation [degrees] - accel_voltage (float): + accel_voltage (float): Accelerating voltage for electrons [Volts] - corr_kernel_size (float): - Correlation kernel size length. The size of the overlap kernel between the + corr_kernel_size (float): + Correlation kernel size length. The size of the overlap kernel between the measured Bragg peaks and diffraction library Bragg peaks. [1/Angstroms] sigma_excitation_error (float): The out of plane excitation error tolerance. [1/Angstroms] precession_angle_degrees (float) Tilt angle of illuminaiton cone in degrees for precession electron diffraction (PED). - - power_radial (float): + + power_radial (float): Power for scaling the correlation intensity as a function of the peak radius - power_intensity (float): + power_intensity (float): Power for scaling the correlation intensity as a function of simulated peak intensity - power_intensity_experiment (float): + power_intensity_experiment (float): Power for scaling the correlation intensity as a function of experimental peak intensity - calculate_correlation_array (bool): + calculate_correlation_array (bool): Set to false to skip calculating the correlation array. This is useful when we only want the angular range / rotation matrices. - tol_peak_delete (float): + tol_peak_delete (float): Distance to delete peaks for multiple matches. Default is kernel_size * 0.5 - tol_distance (float): + tol_distance (float): Distance tolerance for radial shell assignment [1/Angstroms] - fiber_axis (float): + fiber_axis (float): (3,) vector specifying the fiber axis - fiber_angles (float): + fiber_angles (float): (2,) vector specifying angle range from fiber axis, and in-plane angular range [degrees] - cartesian_directions (bool): + cartesian_directions (bool): When set to true, all zone axes and projection directions are specified in Cartesian directions. - figsize (float): + figsize (float): (2,) vector giving the figure size - CUDA (bool): + CUDA (bool): Use CUDA for the Fourier operations. - progress_bar (bool): + progress_bar (bool): If false no progress bar is displayed, """ @@ -120,8 +120,10 @@ def orientation_plan( if precession_angle_degrees is None: self.orientation_precession_angle_degrees = None else: - self.orientation_precession_angle_degrees = np.asarray(precession_angle_degrees) - self.orientation_precession_angle = np.deg2rad(np.asarray(precession_angle_degrees)) + self.orientation_precession_angle_degrees = np.asarray(precession_angle_degrees) + self.orientation_precession_angle = np.deg2rad( + np.asarray(precession_angle_degrees) + ) if tol_peak_delete is None: self.orientation_tol_peak_delete = self.orientation_kernel_size * 0.5 else: @@ -775,37 +777,46 @@ def orientation_plan( # calculate intensity of spots if precession_angle_degrees is None: - Ig = np.exp(sg[keep]**2/(-2*sigma_excitation_error**2)) + Ig = np.exp(sg[keep] ** 2 / (-2 * sigma_excitation_error**2)) else: # precession extension - prec = np.cos(np.linspace(0,2*np.pi,90,endpoint=False)) - dsg = np.tan(self.orientation_precession_angle) * np.sum(g[:2,keep]**2,axis=0) - Ig = np.mean(np.exp((sg[keep,None] + dsg[:,None]*prec[None,:])**2 \ - / (-2*sigma_excitation_error**2)), axis = 1) + prec = np.cos(np.linspace(0, 2 * np.pi, 90, endpoint=False)) + dsg = np.tan(self.orientation_precession_angle) * np.sum( + g[:2, keep] ** 2, axis=0 + ) + Ig = np.mean( + np.exp( + (sg[keep, None] + dsg[:, None] * prec[None, :]) ** 2 + / (-2 * sigma_excitation_error**2) + ), + axis=1, + ) # in-plane rotation angle phi = np.arctan2(g[1, keep], g[0, keep]) phi_ind = phi / self.orientation_gamma[1] # step size of annular bins - phi_floor = np.floor(phi_ind).astype('int') + phi_floor = np.floor(phi_ind).astype("int") dphi = phi_ind - phi_floor # write intensities into orientation plan slice radial_inds = self.orientation_shell_index[keep] - self.orientation_ref[a0, radial_inds, phi_floor] += \ - (1-dphi) * \ - np.power(self.struct_factors_int[keep] * Ig, power_intensity) * \ - np.power(self.orientation_shell_radii[radial_inds], power_radial) - self.orientation_ref[a0, radial_inds, np.mod(phi_floor+1,self.orientation_in_plane_steps)] += \ - dphi * \ - np.power(self.struct_factors_int[keep] * Ig, power_intensity) * \ - np.power(self.orientation_shell_radii[radial_inds], power_radial) - + self.orientation_ref[a0, radial_inds, phi_floor] += ( + (1 - dphi) + * np.power(self.struct_factors_int[keep] * Ig, power_intensity) + * np.power(self.orientation_shell_radii[radial_inds], power_radial) + ) + self.orientation_ref[ + a0, radial_inds, np.mod(phi_floor + 1, self.orientation_in_plane_steps) + ] += ( + dphi + * np.power(self.struct_factors_int[keep] * Ig, power_intensity) + * np.power(self.orientation_shell_radii[radial_inds], power_radial) + ) # # Loop over all peaks # for a1 in np.arange(self.g_vec_all.shape[1]): # if keep[a1]: - # for a1 in np.arange(self.g_vec_all.shape[1]): # ind_radial = self.orientation_shell_index[a1] @@ -1048,21 +1059,23 @@ def match_single_pattern( self.orientation_power_intensity_experiment, ) * np.exp( - (dqr[sub, None] ** 2 - + ( - ( - np.mod( - self.orientation_gamma[None, :] - - qphi[sub, None] - + np.pi, - 2 * np.pi, + ( + dqr[sub, None] ** 2 + + ( + ( + np.mod( + self.orientation_gamma[None, :] + - qphi[sub, None] + + np.pi, + 2 * np.pi, + ) + - np.pi ) - - np.pi + * radius ) - * radius + ** 2 ) - ** 2) - / (-2*self.orientation_kernel_size**2) + / (-2 * self.orientation_kernel_size**2) ), axis=0, ) @@ -1096,7 +1109,6 @@ def match_single_pattern( # axis=0, # ) - # im_polar[ind_radial, :] = np.sum( # np.power(radius, self.orientation_power_radial) # * np.power( @@ -1127,7 +1139,6 @@ def match_single_pattern( # axis=0, # ) - # normalization # im_polar -= np.mean(im_polar) @@ -1628,8 +1639,8 @@ def match_single_pattern( bragg_peaks_fit = self.generate_diffraction_pattern( orientation, ind_orientation=match_ind, - sigma_excitation_error = self.orientation_sigma_excitation_error, - precession_angle_degrees = self.orientation_precession_angle_degrees, + sigma_excitation_error=self.orientation_sigma_excitation_error, + precession_angle_degrees=self.orientation_precession_angle_degrees, ) remove = np.zeros_like(qx, dtype="bool") diff --git a/py4DSTEM/process/diffraction/crystal_phase.py b/py4DSTEM/process/diffraction/crystal_phase.py index 750ea4670..3db718555 100644 --- a/py4DSTEM/process/diffraction/crystal_phase.py +++ b/py4DSTEM/process/diffraction/crystal_phase.py @@ -10,7 +10,6 @@ from py4DSTEM.process.diffraction.crystal_viz import plot_diffraction_pattern - class Crystal_Phase: """ A class storing multiple crystal structures, and associated diffraction data. @@ -21,9 +20,9 @@ class Crystal_Phase: def __init__( self, crystals, - crystal_names = None, - orientation_maps = None, - name = None, + crystal_names=None, + orientation_maps=None, + name=None, ): """ Args: @@ -39,7 +38,9 @@ def __init__( # List of orientation maps if orientation_maps is None: - self.orientation_maps = [crystals[ind].orientation_map for ind in range(self.num_crystals)] + self.orientation_maps = [ + crystals[ind].orientation_map for ind in range(self.num_crystals) + ] else: if len(self.crystals) != len(orientation_maps): raise ValueError( @@ -49,60 +50,64 @@ def __init__( # Names of all crystal phases if crystal_names is None: - self.crystal_names = ['crystal' + str(ind) for ind in range(self.num_crystals)] + self.crystal_names = [ + "crystal" + str(ind) for ind in range(self.num_crystals) + ] else: self.crystal_names = crystal_names # Name of the phase map if name is None: - self.name = 'phase map' + self.name = "phase map" else: self.name = name # Get some attributes from crystals self.k_max = np.zeros(self.num_crystals) - self.num_matches = np.zeros(self.num_crystals, dtype='int') - self.crystal_identity = np.zeros((0,2), dtype='int') + self.num_matches = np.zeros(self.num_crystals, dtype="int") + self.crystal_identity = np.zeros((0, 2), dtype="int") for a0 in range(self.num_crystals): self.k_max[a0] = self.crystals[a0].k_max self.num_matches[a0] = self.crystals[a0].orientation_map.num_matches for a1 in range(self.num_matches[a0]): - self.crystal_identity = np.append(self.crystal_identity,np.array((a0,a1),dtype='int')[None,:], axis=0) + self.crystal_identity = np.append( + self.crystal_identity, + np.array((a0, a1), dtype="int")[None, :], + axis=0, + ) self.num_fits = np.sum(self.num_matches) - - def quantify_single_pattern( self, pointlistarray: PointListArray, - xy_position = (0,0), - corr_kernel_size = 0.04, + xy_position=(0, 0), + corr_kernel_size=0.04, sigma_excitation_error: float = 0.02, - precession_angle_degrees = None, - power_intensity: float = 0.25, - power_intensity_experiment: float = 0.25, - k_max = None, - max_number_patterns = 2, - single_phase = False, - allow_strain = False, - strain_iterations = 3, - strain_max = 0.02, - include_false_positives = True, - weight_false_positives = 1.0, - weight_unmatched_peaks = 1.0, - plot_result = True, - plot_only_nonzero_phases = True, - plot_unmatched_peaks = False, - plot_correlation_radius = False, - scale_markers_experiment = 40, - scale_markers_calculated = 200, - crystal_inds_plot = None, - phase_colors = None, - figsize = (10,7), - verbose = True, - returnfig = False, - ): + precession_angle_degrees=None, + power_intensity: float = 0.25, + power_intensity_experiment: float = 0.25, + k_max=None, + max_number_patterns=2, + single_phase=False, + allow_strain=False, + strain_iterations=3, + strain_max=0.02, + include_false_positives=True, + weight_false_positives=1.0, + weight_unmatched_peaks=1.0, + plot_result=True, + plot_only_nonzero_phases=True, + plot_unmatched_peaks=False, + plot_correlation_radius=False, + scale_markers_experiment=40, + scale_markers_calculated=200, + crystal_inds_plot=None, + phase_colors=None, + figsize=(10, 7), + verbose=True, + returnfig=False, + ): """ Quantify the phase for a single diffraction pattern. @@ -110,21 +115,21 @@ def quantify_single_pattern( Parameters ---------- - + pointlistarray: (PointListArray) Full array of all calibrated experimental bragg peaks, with shape = (num_x,num_y) xy_position: (int,int) The (x,y) or (row,column) position to be quantified. - corr_kernel_size: (float) - Correlation kernel size length. The size of the overlap kernel between the + corr_kernel_size: (float) + Correlation kernel size length. The size of the overlap kernel between the measured Bragg peaks and diffraction library Bragg peaks. [1/Angstroms] sigma_excitation_error: (float) The out of plane excitation error tolerance. [1/Angstroms] precession_angle_degrees: (float) Tilt angle of illuminaiton cone in degrees for precession electron diffraction (PED). - power_intensity: (float) + power_intensity: (float) Power for scaling the correlation intensity as a function of simulated peak intensity. - power_intensity_experiment: (float): + power_intensity_experiment: (float): Power for scaling the correlation intensity as a function of experimental peak intensity. k_max: (float) Max k values included in fits, for both x and y directions. @@ -171,7 +176,7 @@ def quantify_single_pattern( Returns ------- phase_weights: (np.array) - Estimated relative fraction of each phase for all probe positions. + Estimated relative fraction of each phase for all probe positions. shape = (num_x, num_y, num_orientations) where num_orientations is the total number of all orientations for all phases. phase_residual: (np.array) @@ -193,45 +198,49 @@ def quantify_single_pattern( tol2 = 1e-6 # calibrations - center = pointlistarray.calstate['center'] - ellipse = pointlistarray.calstate['ellipse'] - pixel = pointlistarray.calstate['pixel'] - rotate = pointlistarray.calstate['rotate'] + center = pointlistarray.calstate["center"] + ellipse = pointlistarray.calstate["ellipse"] + pixel = pointlistarray.calstate["pixel"] + rotate = pointlistarray.calstate["rotate"] if center is False: - raise ValueError('Bragg peaks must be center calibration') + raise ValueError("Bragg peaks must be center calibration") if pixel is False: - raise ValueError('Bragg peaks must have pixel size calibration') + raise ValueError("Bragg peaks must have pixel size calibration") # TODO - potentially warn the user if ellipse / rotate calibration not available if phase_colors is None: - phase_colors = np.array(( - (1.0,0.0,0.0,1.0), - (0.0,0.8,1.0,1.0), - (0.0,0.6,0.0,1.0), - (1.0,0.0,1.0,1.0), - (0.0,0.2,1.0,1.0), - (1.0,0.8,0.0,1.0), - )) + phase_colors = np.array( + ( + (1.0, 0.0, 0.0, 1.0), + (0.0, 0.8, 1.0, 1.0), + (0.0, 0.6, 0.0, 1.0), + (1.0, 0.0, 1.0, 1.0), + (0.0, 0.2, 1.0, 1.0), + (1.0, 0.8, 0.0, 1.0), + ) + ) # Experimental values bragg_peaks = pointlistarray.get_vectors( xy_position[0], xy_position[1], - center = center, - ellipse = ellipse, - pixel = pixel, - rotate = rotate) + center=center, + ellipse=ellipse, + pixel=pixel, + rotate=rotate, + ) # bragg_peaks = pointlistarray.get_pointlist(xy_position[0],xy_position[1]).copy() if k_max is None: - keep = bragg_peaks.data["qx"]**2 + bragg_peaks.data["qy"]**2 > tol2 + keep = bragg_peaks.data["qx"] ** 2 + bragg_peaks.data["qy"] ** 2 > tol2 else: - keep = np.logical_and.reduce(( - bragg_peaks.data["qx"]**2 + bragg_peaks.data["qy"]**2 > tol2, - np.abs(bragg_peaks.data["qx"]) < k_max, - np.abs(bragg_peaks.data["qy"]) < k_max, - )) + keep = np.logical_and.reduce( + ( + bragg_peaks.data["qx"] ** 2 + bragg_peaks.data["qy"] ** 2 > tol2, + np.abs(bragg_peaks.data["qx"]) < k_max, + np.abs(bragg_peaks.data["qy"]) < k_max, + ) + ) - # ind_center_beam = np.argmin( # bragg_peaks.data["qx"]**2 + bragg_peaks.data["qy"]**2) # mask = np.ones_like(bragg_peaks.data["qx"], dtype='bool') @@ -245,9 +254,14 @@ def quantify_single_pattern( intensity = np.ones_like(qx) intensity0 = np.ones_like(qx0) else: - intensity = bragg_peaks.data["intensity"][keep]**power_intensity_experiment - intensity0 = bragg_peaks.data["intensity"][np.logical_not(keep)]**power_intensity_experiment - int_total = np.sum(intensity) + intensity = ( + bragg_peaks.data["intensity"][keep] ** power_intensity_experiment + ) + intensity0 = ( + bragg_peaks.data["intensity"][np.logical_not(keep)] + ** power_intensity_experiment + ) + int_total = np.sum(intensity) # init basis array if include_false_positives: @@ -256,9 +270,9 @@ def quantify_single_pattern( else: basis = np.zeros((intensity.shape[0], self.num_fits)) if allow_strain: - m_strains = np.zeros((self.num_fits,2,2)) - m_strains[:,0,0] = 1.0 - m_strains[:,1,1] = 1.0 + m_strains = np.zeros((self.num_fits, 2, 2)) + m_strains[:, 0, 0] = 1.0 + m_strains[:, 1, 1] = 1.0 # kernel radius squared radius_max_2 = corr_kernel_size**2 @@ -271,8 +285,8 @@ def quantify_single_pattern( # Generate point list data, match to experimental peaks for a0 in range(self.num_fits): - c = self.crystal_identity[a0,0] - m = self.crystal_identity[a0,1] + c = self.crystal_identity[a0, 0] + m = self.crystal_identity[a0, 1] # for c in range(self.num_crystals): # for m in range(self.num_matches[c]): # ind_match += 1 @@ -282,40 +296,46 @@ def quantify_single_pattern( self.crystals[c].orientation_map.get_orientation( xy_position[0], xy_position[1] ), - ind_orientation = m, - sigma_excitation_error = sigma_excitation_error, - precession_angle_degrees = precession_angle_degrees, + ind_orientation=m, + sigma_excitation_error=sigma_excitation_error, + precession_angle_degrees=precession_angle_degrees, ) if k_max is None: - del_peak = bragg_peaks_fit.data["qx"]**2 \ - + bragg_peaks_fit.data["qy"]**2 < tol2 + del_peak = ( + bragg_peaks_fit.data["qx"] ** 2 + bragg_peaks_fit.data["qy"] ** 2 + < tol2 + ) else: - del_peak = np.logical_or.reduce(( - bragg_peaks_fit.data["qx"]**2 \ - + bragg_peaks_fit.data["qy"]**2 < tol2, - np.abs(bragg_peaks_fit.data["qx"]) > k_max, - np.abs(bragg_peaks_fit.data["qy"]) > k_max, - )) + del_peak = np.logical_or.reduce( + ( + bragg_peaks_fit.data["qx"] ** 2 + + bragg_peaks_fit.data["qy"] ** 2 + < tol2, + np.abs(bragg_peaks_fit.data["qx"]) > k_max, + np.abs(bragg_peaks_fit.data["qy"]) > k_max, + ) + ) bragg_peaks_fit.remove(del_peak) # peak intensities if power_intensity == 0: int_fit = np.ones_like(bragg_peaks_fit.data["qx"]) else: - int_fit = bragg_peaks_fit.data['intensity']**power_intensity - + int_fit = bragg_peaks_fit.data["intensity"] ** power_intensity + # Pair peaks to experiment if plot_result: - matches = np.zeros((bragg_peaks_fit.data.shape[0]),dtype='bool') + matches = np.zeros((bragg_peaks_fit.data.shape[0]), dtype="bool") if allow_strain: for a1 in range(strain_iterations): # Initial peak pairing to find best-fit strain distortion - pair_sub = np.zeros(bragg_peaks_fit.data.shape[0],dtype='bool') - pair_inds = np.zeros(bragg_peaks_fit.data.shape[0],dtype='int') + pair_sub = np.zeros(bragg_peaks_fit.data.shape[0], dtype="bool") + pair_inds = np.zeros(bragg_peaks_fit.data.shape[0], dtype="int") for a1 in range(bragg_peaks_fit.data.shape[0]): - dist2 = (bragg_peaks_fit.data['qx'][a1] - qx)**2 \ - + (bragg_peaks_fit.data['qy'][a1] - qy)**2 + dist2 = (bragg_peaks_fit.data["qx"][a1] - qx) ** 2 + ( + bragg_peaks_fit.data["qy"][a1] - qy + ) ** 2 ind_min = np.argmin(dist2) val_min = dist2[ind_min] @@ -327,19 +347,32 @@ def quantify_single_pattern( # requires at least 4 peak pairs if np.sum(pair_sub) >= 4: # pair_obs = bragg_peaks_fit.data[['qx','qy']][pair_sub] - pair_basis = np.vstack(( - bragg_peaks_fit.data['qx'][pair_sub], - bragg_peaks_fit.data['qy'][pair_sub], - )).T - pair_obs = np.vstack(( - qx[pair_inds[pair_sub]], - qy[pair_inds[pair_sub]], - )).T + pair_basis = np.vstack( + ( + bragg_peaks_fit.data["qx"][pair_sub], + bragg_peaks_fit.data["qy"][pair_sub], + ) + ).T + pair_obs = np.vstack( + ( + qx[pair_inds[pair_sub]], + qy[pair_inds[pair_sub]], + ) + ).T # weights dists = np.sqrt( - (bragg_peaks_fit.data['qx'][pair_sub] - qx[pair_inds[pair_sub]])**2 + \ - (bragg_peaks_fit.data['qx'][pair_sub] - qx[pair_inds[pair_sub]])**2) + ( + bragg_peaks_fit.data["qx"][pair_sub] + - qx[pair_inds[pair_sub]] + ) + ** 2 + + ( + bragg_peaks_fit.data["qx"][pair_sub] + - qx[pair_inds[pair_sub]] + ) + ** 2 + ) weights = np.sqrt( int_fit[pair_sub] * intensity[pair_inds[pair_sub]] ) * (1 - dists / corr_kernel_size) @@ -347,9 +380,9 @@ def quantify_single_pattern( # strain tensor m_strain = np.linalg.lstsq( - pair_basis * weights[:,None], - pair_obs * weights[:,None], - rcond = None, + pair_basis * weights[:, None], + pair_obs * weights[:, None], + rcond=None, )[0] # Clamp strains to be within the user-specified limit @@ -361,34 +394,42 @@ def quantify_single_pattern( m_strains[a0] *= m_strain # Transformed peak positions - qx_copy = bragg_peaks_fit.data['qx'] - qy_copy = bragg_peaks_fit.data['qy'] - bragg_peaks_fit.data['qx'] = qx_copy*m_strain[0,0] + qy_copy*m_strain[1,0] - bragg_peaks_fit.data['qy'] = qx_copy*m_strain[0,1] + qy_copy*m_strain[1,1] + qx_copy = bragg_peaks_fit.data["qx"] + qy_copy = bragg_peaks_fit.data["qy"] + bragg_peaks_fit.data["qx"] = ( + qx_copy * m_strain[0, 0] + qy_copy * m_strain[1, 0] + ) + bragg_peaks_fit.data["qy"] = ( + qx_copy * m_strain[0, 1] + qy_copy * m_strain[1, 1] + ) # Loop over all peaks, pair experiment to library for a1 in range(bragg_peaks_fit.data.shape[0]): - dist2 = (bragg_peaks_fit.data['qx'][a1] - qx)**2 \ - + (bragg_peaks_fit.data['qy'][a1] - qy)**2 + dist2 = (bragg_peaks_fit.data["qx"][a1] - qx) ** 2 + ( + bragg_peaks_fit.data["qy"][a1] - qy + ) ** 2 ind_min = np.argmin(dist2) val_min = dist2[ind_min] if include_false_positives: - weight = np.clip(1 - np.sqrt(dist2[ind_min]) / corr_kernel_size,0,1) - basis[ind_min,a0] = int_fit[a1] * weight - unpaired_peaks.append([ - a0, - int_fit[a1] * (1 - weight), - ]) + weight = np.clip( + 1 - np.sqrt(dist2[ind_min]) / corr_kernel_size, 0, 1 + ) + basis[ind_min, a0] = int_fit[a1] * weight + unpaired_peaks.append( + [ + a0, + int_fit[a1] * (1 - weight), + ] + ) if weight > 1e-8 and plot_result: matches[a1] = True else: if val_min < radius_max_2: - basis[ind_min,a0] = int_fit[a1] + basis[ind_min, a0] = int_fit[a1] if plot_result: matches[a1] = True - # if val_min < radius_max_2: # # weight = 1 - np.sqrt(dist2[ind_min]) / corr_kernel_size # # weight = 1 + corr_distance_scale * \ @@ -402,22 +443,22 @@ def quantify_single_pattern( # unpaired_peaks.append([a0,int_fit[a1]]) if plot_result: - library_peaks.append(bragg_peaks_fit) + library_peaks.append(bragg_peaks_fit) library_int.append(int_fit) library_matches.append(matches) # If needed, augment basis and observations with false positives if include_false_positives: - basis_aug = np.zeros((len(unpaired_peaks),self.num_fits)) + basis_aug = np.zeros((len(unpaired_peaks), self.num_fits)) for a0 in range(len(unpaired_peaks)): - basis_aug[a0,unpaired_peaks[a0][0]] = unpaired_peaks[a0][1] + basis_aug[a0, unpaired_peaks[a0][0]] = unpaired_peaks[a0][1] basis = np.vstack((basis, basis_aug * weight_false_positives)) obs = np.hstack((intensity, np.zeros(len(unpaired_peaks)))) else: obs = intensity - + # Solve for phase weight coefficients try: phase_weights = np.zeros(self.num_fits) @@ -428,16 +469,16 @@ def quantify_single_pattern( crystal_res = np.zeros(self.num_crystals) for a0 in range(self.num_crystals): - inds_solve = self.crystal_identity[:,0] == a0 + inds_solve = self.crystal_identity[:, 0] == a0 search = True while search is True: - basis_solve = basis[:,inds_solve] + basis_solve = basis[:, inds_solve] obs_solve = obs.copy() if weight_unmatched_peaks > 1.0: - sub_unmatched = np.sum(basis_solve,axis=1)<1e-8 + sub_unmatched = np.sum(basis_solve, axis=1) < 1e-8 obs_solve[sub_unmatched] *= weight_unmatched_peaks phase_weights_cand, phase_residual_cand = nnls( @@ -445,7 +486,10 @@ def quantify_single_pattern( obs_solve, ) - if np.count_nonzero(phase_weights_cand > 0.0) <= max_number_patterns: + if ( + np.count_nonzero(phase_weights_cand > 0.0) + <= max_number_patterns + ): phase_weights[inds_solve] = phase_weights_cand crystal_res[a0] = phase_residual_cand search = False @@ -457,9 +501,7 @@ def quantify_single_pattern( # ind_best_fit = np.argmax(phase_weights) phase_residual = crystal_res[ind_best_fit] - sub = np.logical_not( - self.crystal_identity[:,0] == ind_best_fit - ) + sub = np.logical_not(self.crystal_identity[:, 0] == ind_best_fit) phase_weights[sub] = 0.0 # Estimate reliability as difference between best fit and 2nd best fit @@ -468,15 +510,18 @@ def quantify_single_pattern( else: # Allow all crystals and orientation matches in the pattern - inds_solve = np.ones(self.num_fits,dtype='bool') + inds_solve = np.ones(self.num_fits, dtype="bool") search = True while search is True: phase_weights_cand, phase_residual_cand = nnls( - basis[:,inds_solve], + basis[:, inds_solve], obs, ) - if np.count_nonzero(phase_weights_cand > 0.0) <= max_number_patterns: + if ( + np.count_nonzero(phase_weights_cand > 0.0) + <= max_number_patterns + ): phase_weights[inds_solve] = phase_weights_cand phase_residual = phase_residual_cand search = False @@ -484,9 +529,8 @@ def quantify_single_pattern( inds = np.where(inds_solve)[0] inds_solve[inds[np.argmin(phase_weights_cand)]] = False - # Estimate the phase reliability - inds_solve = np.ones(self.num_fits,dtype='bool') + inds_solve = np.ones(self.num_fits, dtype="bool") inds_solve[phase_weights > 1e-8] = False if np.all(inds_solve == False): @@ -495,10 +539,13 @@ def quantify_single_pattern( search = True while search is True: phase_weights_cand, phase_residual_cand = nnls( - basis[:,inds_solve], + basis[:, inds_solve], obs, ) - if np.count_nonzero(phase_weights_cand > 0.0) <= max_number_patterns: + if ( + np.count_nonzero(phase_weights_cand > 0.0) + <= max_number_patterns + ): phase_residual_2nd = phase_residual_cand search = False else: @@ -506,7 +553,7 @@ def quantify_single_pattern( inds_solve[inds[np.argmin(phase_weights_cand)]] = False phase_weights_cand, phase_residual_cand = nnls( - basis[:,inds_solve], + basis[:, inds_solve], obs, ) phase_reliability = phase_residual_2nd - phase_residual @@ -516,34 +563,25 @@ def quantify_single_pattern( phase_residual = np.sqrt(np.sum(intensity**2)) phase_reliability = 0.0 - if verbose: ind_max = np.argmax(phase_weights) # print() - print('\033[1m' + 'phase_weight or_ind name' + '\033[0m') + print("\033[1m" + "phase_weight or_ind name" + "\033[0m") # print() for a0 in range(self.num_fits): - c = self.crystal_identity[a0,0] - m = self.crystal_identity[a0,1] - line = '{:>12} {:>8} {:<12}'.format( - f'{phase_weights[a0]:.2f}', - m, - self.crystal_names[c] - ) + c = self.crystal_identity[a0, 0] + m = self.crystal_identity[a0, 1] + line = "{:>12} {:>8} {:<12}".format( + f"{phase_weights[a0]:.2f}", m, self.crystal_names[c] + ) if a0 == ind_max: - print('\033[1m' + line + '\033[0m') + print("\033[1m" + line + "\033[0m") else: print(line) - print('----------------------------') - line = '{:>12} {:>15}'.format( - f'{sum(phase_weights):.2f}', - 'fit total' - ) - print('\033[1m' + line + '\033[0m') - line = '{:>12} {:>15}'.format( - f'{phase_residual:.2f}', - 'fit residual' - ) + print("----------------------------") + line = "{:>12} {:>15}".format(f"{sum(phase_weights):.2f}", "fit total") + print("\033[1m" + line + "\033[0m") + line = "{:>12} {:>15}".format(f"{phase_residual:.2f}", "fit residual") print(line) # Plotting @@ -558,39 +596,40 @@ def quantify_single_pattern( if plot_correlation_radius: # plot the experimental radii - t = np.linspace(0,2*np.pi,91,endpoint=True) + t = np.linspace(0, 2 * np.pi, 91, endpoint=True) ct = np.cos(t) * corr_kernel_size st = np.sin(t) * corr_kernel_size for a0 in range(qx.shape[0]): ax.plot( qy[a0] + st, qx[a0] + ct, - color = 'k', - linewidth = 1, - ) + color="k", + linewidth=1, + ) # plot the experimental peaks ax.scatter( qy0, qx0, # s = scale_markers_experiment * intensity0, - s = scale_markers_experiment * bragg_peaks.data["intensity"][np.logical_not(keep)], - marker = "o", - facecolor = [0.7, 0.7, 0.7], - ) + s=scale_markers_experiment + * bragg_peaks.data["intensity"][np.logical_not(keep)], + marker="o", + facecolor=[0.7, 0.7, 0.7], + ) ax.scatter( qy, qx, # s = scale_markers_experiment * intensity, - s = scale_markers_experiment * bragg_peaks.data["intensity"][keep], - marker = "o", - facecolor = [0.7, 0.7, 0.7], - ) + s=scale_markers_experiment * bragg_peaks.data["intensity"][keep], + marker="o", + facecolor=[0.7, 0.7, 0.7], + ) # legend if k_max is None: k_max = np.max(self.k_max) - dx_leg = -0.05*k_max - dy_leg = 0.04*k_max + dx_leg = -0.05 * k_max + dy_leg = 0.04 * k_max text_params = { "va": "center", "ha": "left", @@ -601,34 +640,25 @@ def quantify_single_pattern( } if plot_correlation_radius: ax_leg.plot( - 0 + st*0.5, - -dx_leg + ct*0.5, - color = 'k', - linewidth = 1, - ) + 0 + st * 0.5, + -dx_leg + ct * 0.5, + color="k", + linewidth=1, + ) ax_leg.scatter( 0, 0, - s = 200, - marker = "o", - facecolor = [0.7, 0.7, 0.7], - ) - ax_leg.text( - dy_leg, - 0, - 'Experimental peaks', - **text_params) + s=200, + marker="o", + facecolor=[0.7, 0.7, 0.7], + ) + ax_leg.text(dy_leg, 0, "Experimental peaks", **text_params) if plot_correlation_radius: - ax_leg.text( - dy_leg, - -dx_leg, - 'Correlation radius', - **text_params) - + ax_leg.text(dy_leg, -dx_leg, "Correlation radius", **text_params) # plot calculated diffraction patterns uvals = phase_colors.copy() - uvals[:,3] = 0.3 + uvals[:, 3] = 0.3 # uvals = np.array(( # (1.0,0.0,0.0,0.2), # (0.0,0.8,1.0,0.2), @@ -637,25 +667,35 @@ def quantify_single_pattern( # (0.0,0.2,1.0,0.2), # (1.0,0.8,0.0,0.2), # )) - mvals = ['v','^','<','>','d','s',] + mvals = [ + "v", + "^", + "<", + ">", + "d", + "s", + ] count_leg = 0 for a0 in range(self.num_fits): - c = self.crystal_identity[a0,0] - m = self.crystal_identity[a0,1] + c = self.crystal_identity[a0, 0] + m = self.crystal_identity[a0, 1] - if crystal_inds_plot == None or np.min(np.abs(c - crystal_inds_plot)) == 0: + if ( + crystal_inds_plot == None + or np.min(np.abs(c - crystal_inds_plot)) == 0 + ): - qx_fit = library_peaks[a0].data['qx'] - qy_fit = library_peaks[a0].data['qy'] + qx_fit = library_peaks[a0].data["qx"] + qy_fit = library_peaks[a0].data["qy"] if allow_strain: m_strain = m_strains[a0] # Transformed peak positions qx_copy = qx_fit.copy() qy_copy = qy_fit.copy() - qx_fit = qx_copy*m_strain[0,0] + qy_copy*m_strain[1,0] - qy_fit = qx_copy*m_strain[0,1] + qy_copy*m_strain[1,1] + qx_fit = qx_copy * m_strain[0, 0] + qy_copy * m_strain[1, 0] + qy_fit = qx_copy * m_strain[0, 1] + qy_copy * m_strain[1, 1] int_fit = library_int[a0] matches_fit = library_matches[a0] @@ -666,33 +706,35 @@ def quantify_single_pattern( ax.scatter( qy_fit[matches_fit], qx_fit[matches_fit], - s = scale_markers_calculated * int_fit[matches_fit], - marker = mvals[c], - facecolor = phase_colors[c,:], - ) + s=scale_markers_calculated * int_fit[matches_fit], + marker=mvals[c], + facecolor=phase_colors[c, :], + ) if plot_unmatched_peaks: ax.scatter( qy_fit[np.logical_not(matches_fit)], qx_fit[np.logical_not(matches_fit)], - s = scale_markers_calculated * int_fit[np.logical_not(matches_fit)], - marker = mvals[c], - facecolor = phase_colors[c,:], - ) + s=scale_markers_calculated + * int_fit[np.logical_not(matches_fit)], + marker=mvals[c], + facecolor=phase_colors[c, :], + ) # legend if m == 0: ax_leg.text( dy_leg, - (count_leg+1)*dx_leg, + (count_leg + 1) * dx_leg, self.crystal_names[c], - **text_params) + **text_params, + ) ax_leg.scatter( 0, - (count_leg+1) * dx_leg, - s = 200, - marker = mvals[c], - facecolor = phase_colors[c,:], - ) + (count_leg + 1) * dx_leg, + s=200, + marker=mvals[c], + facecolor=phase_colors[c, :], + ) count_leg += 1 # else: # ax.scatter( @@ -727,14 +769,12 @@ def quantify_single_pattern( # # facecolors = (1,1,1,0.5), # ) - - # appearance ax.set_xlim((-k_max, k_max)) ax.set_ylim((k_max, -k_max)) - ax_leg.set_xlim((-0.1*k_max, 0.4*k_max)) - ax_leg.set_ylim((-0.5*k_max, 0.5*k_max)) + ax_leg.set_xlim((-0.1 * k_max, 0.4 * k_max)) + ax_leg.set_ylim((-0.5 * k_max, 0.5 * k_max)) ax_leg.set_axis_off() if returnfig: @@ -745,21 +785,21 @@ def quantify_single_pattern( def quantify_phase( self, pointlistarray: PointListArray, - corr_kernel_size = 0.04, - sigma_excitation_error = 0.02, - precession_angle_degrees = None, - power_intensity: float = 0.25, - power_intensity_experiment: float = 0.25, - k_max = None, - max_number_patterns = 2, - single_phase = False, - allow_strain = True, - strain_iterations = 3, - strain_max = 0.02, - include_false_positives = True, - weight_false_positives = 1.0, - progress_bar = True, - ): + corr_kernel_size=0.04, + sigma_excitation_error=0.02, + precession_angle_degrees=None, + power_intensity: float = 0.25, + power_intensity_experiment: float = 0.25, + k_max=None, + max_number_patterns=2, + single_phase=False, + allow_strain=True, + strain_iterations=3, + strain_max=0.02, + include_false_positives=True, + weight_false_positives=1.0, + progress_bar=True, + ): """ Quantify phase of all diffraction patterns. @@ -767,16 +807,16 @@ def quantify_phase( ---------- pointlistarray: (PointListArray) Full array of all calibrated experimental bragg peaks, with shape = (num_x,num_y) - corr_kernel_size: (float) - Correlation kernel size length. The size of the overlap kernel between the + corr_kernel_size: (float) + Correlation kernel size length. The size of the overlap kernel between the measured Bragg peaks and diffraction library Bragg peaks. [1/Angstroms] sigma_excitation_error: (float) The out of plane excitation error tolerance. [1/Angstroms] precession_angle_degrees: (float) Tilt angle of illuminaiton cone in degrees for precession electron diffraction (PED). - power_intensity: (float) + power_intensity: (float) Power for scaling the correlation intensity as a function of simulated peak intensity. - power_intensity_experiment: (float): + power_intensity_experiment: (float): Power for scaling the correlation intensity as a function of experimental peak intensity. k_max: (float) Max k values included in fits, for both x and y directions. @@ -803,72 +843,79 @@ def quantify_phase( """ # init results arrays - self.phase_weights = np.zeros(( - pointlistarray.shape[0], - pointlistarray.shape[1], - self.num_fits, - )) - self.phase_residuals = np.zeros(( - pointlistarray.shape[0], - pointlistarray.shape[1], - )) - self.phase_reliability = np.zeros(( - pointlistarray.shape[0], - pointlistarray.shape[1], - )) - self.int_total = np.zeros(( - pointlistarray.shape[0], - pointlistarray.shape[1], - )) + self.phase_weights = np.zeros( + ( + pointlistarray.shape[0], + pointlistarray.shape[1], + self.num_fits, + ) + ) + self.phase_residuals = np.zeros( + ( + pointlistarray.shape[0], + pointlistarray.shape[1], + ) + ) + self.phase_reliability = np.zeros( + ( + pointlistarray.shape[0], + pointlistarray.shape[1], + ) + ) + self.int_total = np.zeros( + ( + pointlistarray.shape[0], + pointlistarray.shape[1], + ) + ) self.single_phase = single_phase - + for rx, ry in tqdmnd( - *pointlistarray.shape, - desc="Quantifying Phase", - unit=" PointList", - disable=not progress_bar, - ): + *pointlistarray.shape, + desc="Quantifying Phase", + unit=" PointList", + disable=not progress_bar, + ): # calculate phase weights - phase_weights, phase_residual, phase_reliability, int_peaks = self.quantify_single_pattern( - pointlistarray = pointlistarray, - xy_position = (rx,ry), - corr_kernel_size = corr_kernel_size, - sigma_excitation_error = sigma_excitation_error, - - precession_angle_degrees = precession_angle_degrees, - power_intensity = power_intensity, - power_intensity_experiment = power_intensity_experiment, - k_max = k_max, - - max_number_patterns = max_number_patterns, - single_phase = single_phase, - allow_strain = allow_strain, - strain_iterations = strain_iterations, - strain_max = strain_max, - include_false_positives = include_false_positives, - weight_false_positives = weight_false_positives, - plot_result = False, - verbose = False, - returnfig = False, + phase_weights, phase_residual, phase_reliability, int_peaks = ( + self.quantify_single_pattern( + pointlistarray=pointlistarray, + xy_position=(rx, ry), + corr_kernel_size=corr_kernel_size, + sigma_excitation_error=sigma_excitation_error, + precession_angle_degrees=precession_angle_degrees, + power_intensity=power_intensity, + power_intensity_experiment=power_intensity_experiment, + k_max=k_max, + max_number_patterns=max_number_patterns, + single_phase=single_phase, + allow_strain=allow_strain, + strain_iterations=strain_iterations, + strain_max=strain_max, + include_false_positives=include_false_positives, + weight_false_positives=weight_false_positives, + plot_result=False, + verbose=False, + returnfig=False, ) - self.phase_weights[rx,ry] = phase_weights - self.phase_residuals[rx,ry] = phase_residual - self.phase_reliability[rx,ry] = phase_reliability - self.int_total[rx,ry] = int_peaks - + ) + self.phase_weights[rx, ry] = phase_weights + self.phase_residuals[rx, ry] = phase_residual + self.phase_reliability[rx, ry] = phase_reliability + self.int_total[rx, ry] = int_peaks def plot_phase_weights( self, - weight_range = (0.0,1.0), - weight_normalize = False, - total_intensity_normalize = True, - cmap = 'gray', - show_ticks = False, - show_axes = True, - layout = 0, - figsize = (6,6), - returnfig = False, - ): + weight_range=(0.0, 1.0), + weight_normalize=False, + total_intensity_normalize=True, + cmap="gray", + show_ticks=False, + show_axes=True, + layout=0, + figsize=(6, 6), + returnfig=False, + ): """ Plot the individual phase weight maps and residuals. @@ -906,43 +953,45 @@ def plot_phase_weights( if total_intensity_normalize: sub = self.int_total > 0.0 for a0 in range(self.num_fits): - phase_weights[:,:,a0][sub] /= self.int_total[sub] + phase_weights[:, :, a0][sub] /= self.int_total[sub] phase_residuals[sub] /= self.int_total[sub] # intensity range for plotting if weight_normalize: - scale = np.median(np.max(phase_weights,axis=2)) + scale = np.median(np.max(phase_weights, axis=2)) else: scale = 1 weight_range = np.array(weight_range) * scale # plotting if layout == 0: - fig,ax = plt.subplots( + fig, ax = plt.subplots( 1, self.num_crystals + 1, - figsize=(figsize[0],(self.num_fits+1)*figsize[1])) + figsize=(figsize[0], (self.num_fits + 1) * figsize[1]), + ) elif layout == 1: - fig,ax = plt.subplots( + fig, ax = plt.subplots( self.num_crystals + 1, 1, - figsize=(figsize[0],(self.num_fits+1)*figsize[1])) + figsize=(figsize[0], (self.num_fits + 1) * figsize[1]), + ) for a0 in range(self.num_crystals): - sub = self.crystal_identity[:,0] == a0 - im = np.sum(phase_weights[:,:,sub],axis=2) + sub = self.crystal_identity[:, 0] == a0 + im = np.sum(phase_weights[:, :, sub], axis=2) im = np.clip( - (im - weight_range[0]) / (weight_range[1] - weight_range[0]), - 0,1) + (im - weight_range[0]) / (weight_range[1] - weight_range[0]), 0, 1 + ) ax[a0].imshow( im, - vmin = 0, - vmax = 1, - cmap = cmap, + vmin=0, + vmax=1, + cmap=cmap, ) ax[a0].set_title( self.crystal_names[a0], - fontsize = 16, + fontsize=16, ) if not show_ticks: ax[a0].set_xticks([]) @@ -952,18 +1001,19 @@ def plot_phase_weights( # plot residuals im = np.clip( - (phase_residuals - weight_range[0]) \ - / (weight_range[1] - weight_range[0]), - 0,1) + (phase_residuals - weight_range[0]) / (weight_range[1] - weight_range[0]), + 0, + 1, + ) ax[self.num_crystals].imshow( im, - vmin = 0, - vmax = 1, - cmap = cmap, + vmin=0, + vmax=1, + cmap=cmap, ) ax[self.num_crystals].set_title( - 'Residuals', - fontsize = 16, + "Residuals", + fontsize=16, ) if not show_ticks: ax[self.num_crystals].set_xticks([]) @@ -974,23 +1024,22 @@ def plot_phase_weights( if returnfig: return fig, ax - def plot_phase_maps( self, - weight_threshold = 0.5, - weight_normalize = True, - total_intensity_normalize = True, - plot_combine = False, - crystal_inds_plot = None, - phase_colors = None, - show_ticks = False, - show_axes = True, - layout = 0, - figsize = (6,6), - return_phase_estimate = False, - return_rgb_images = False, - returnfig = False, - ): + weight_threshold=0.5, + weight_normalize=True, + total_intensity_normalize=True, + plot_combine=False, + crystal_inds_plot=None, + phase_colors=None, + show_ticks=False, + show_axes=True, + layout=0, + figsize=(6, 6), + return_phase_estimate=False, + return_rgb_images=False, + returnfig=False, + ): """ Plot the individual phase weight maps and residuals. @@ -1035,69 +1084,77 @@ def plot_phase_maps( """ if phase_colors is None: - phase_colors = np.array(( - (1.0,0.0,0.0), - (0.0,0.8,1.0), - (0.0,0.8,0.0), - (1.0,0.0,1.0), - (0.0,0.4,1.0), - (1.0,0.8,0.0), - )) + phase_colors = np.array( + ( + (1.0, 0.0, 0.0), + (0.0, 0.8, 1.0), + (0.0, 0.8, 0.0), + (1.0, 0.0, 1.0), + (0.0, 0.4, 1.0), + (1.0, 0.8, 0.0), + ) + ) phase_weights = self.phase_weights.copy() if total_intensity_normalize: sub = self.int_total > 0.0 for a0 in range(self.num_fits): - phase_weights[:,:,a0][sub] /= self.int_total[sub] + phase_weights[:, :, a0][sub] /= self.int_total[sub] # intensity range for plotting if weight_normalize: - scale = np.median(np.max(phase_weights,axis=2)) + scale = np.median(np.max(phase_weights, axis=2)) else: scale = 1 weight_threshold = weight_threshold * scale # init - im_all = np.zeros(( - self.num_crystals, - self.phase_weights.shape[0], - self.phase_weights.shape[1])) - im_rgb_all = np.zeros(( - self.num_crystals, - self.phase_weights.shape[0], - self.phase_weights.shape[1], - 3)) + im_all = np.zeros( + ( + self.num_crystals, + self.phase_weights.shape[0], + self.phase_weights.shape[1], + ) + ) + im_rgb_all = np.zeros( + ( + self.num_crystals, + self.phase_weights.shape[0], + self.phase_weights.shape[1], + 3, + ) + ) # phase weights over threshold for a0 in range(self.num_crystals): - sub = self.crystal_identity[:,0] == a0 - im = np.sum(phase_weights[:,:,sub],axis=2) + sub = self.crystal_identity[:, 0] == a0 + im = np.sum(phase_weights[:, :, sub], axis=2) im_all[a0] = np.maximum(im - weight_threshold, 0) # estimate compositions - im_sum = np.sum(im_all, axis = 0) + im_sum = np.sum(im_all, axis=0) sub = im_sum > 0.0 for a0 in range(self.num_crystals): im_all[a0][sub] /= im_sum[sub] for a1 in range(3): - im_rgb_all[a0,:,:,a1] = im_all[a0] * phase_colors[a0,a1] + im_rgb_all[a0, :, :, a1] = im_all[a0] * phase_colors[a0, a1] if plot_combine: if crystal_inds_plot is None: - im_rgb = np.sum(im_rgb_all, axis = 0) + im_rgb = np.sum(im_rgb_all, axis=0) else: - im_rgb = np.sum(im_rgb_all[np.array(crystal_inds_plot)], axis = 0) + im_rgb = np.sum(im_rgb_all[np.array(crystal_inds_plot)], axis=0) - im_rgb = np.clip(im_rgb,0,1) + im_rgb = np.clip(im_rgb, 0, 1) - fig,ax = plt.subplots(1,1,figsize=figsize) + fig, ax = plt.subplots(1, 1, figsize=figsize) ax.imshow( im_rgb, ) ax.set_title( - 'Phase Maps', - fontsize = 16, + "Phase Maps", + fontsize=16, ) if not show_ticks: ax.set_xticks([]) @@ -1108,24 +1165,26 @@ def plot_phase_maps( else: # plotting if layout == 0: - fig,ax = plt.subplots( + fig, ax = plt.subplots( 1, self.num_crystals, - figsize=(figsize[0],(self.num_fits+1)*figsize[1])) + figsize=(figsize[0], (self.num_fits + 1) * figsize[1]), + ) elif layout == 1: - fig,ax = plt.subplots( + fig, ax = plt.subplots( self.num_crystals, 1, - figsize=(figsize[0],(self.num_fits+1)*figsize[1])) - + figsize=(figsize[0], (self.num_fits + 1) * figsize[1]), + ) + for a0 in range(self.num_crystals): - + ax[a0].imshow( im_rgb_all[a0], ) ax[a0].set_title( self.crystal_names[a0], - fontsize = 16, + fontsize=16, ) if not show_ticks: ax[a0].set_xticks([]) @@ -1154,20 +1213,19 @@ def plot_phase_maps( if returnfig: return fig, ax - def plot_dominant_phase( self, - use_correlation_scores = False, - reliability_range = (0.0,1.0), - sigma = 0.0, - phase_colors = None, - ticks = True, - figsize = (6,6), - legend_add = True, - legend_fraction = 0.2, - print_fractions = False, - returnfig = True, - ): + use_correlation_scores=False, + reliability_range=(0.0, 1.0), + sigma=0.0, + phase_colors=None, + ticks=True, + figsize=(6, 6), + legend_add=True, + legend_fraction=0.2, + print_fractions=False, + returnfig=True, + ): """ Plot a combined figure showing the primary phase at each probe position. Mask by the reliability index (best match minus 2nd best match). @@ -1201,148 +1259,142 @@ def plot_dominant_phase( Figure and axes handles. """ - + if phase_colors is None: - phase_colors = np.array([ - [1.0,0.9,0.6], - [1,0,0], - [0,0.7,0], - [0,0.7,1], - [1,0,1], - ]) - + phase_colors = np.array( + [ + [1.0, 0.9, 0.6], + [1, 0, 0], + [0, 0.7, 0], + [0, 0.7, 1], + [1, 0, 1], + ] + ) # init arrays scan_shape = self.phase_weights.shape[:2] phase_map = np.zeros(scan_shape) phase_corr = np.zeros(scan_shape) phase_corr_2nd = np.zeros(scan_shape) - phase_sig = np.zeros((self.num_crystals,scan_shape[0],scan_shape[1])) + phase_sig = np.zeros((self.num_crystals, scan_shape[0], scan_shape[1])) if use_correlation_scores: # Calculate scores from highest correlation match for a0 in range(self.num_crystals): phase_sig[a0] = np.maximum( phase_sig[a0], - np.max(self.crystals[a0].orientation_map.corr,axis=2), + np.max(self.crystals[a0].orientation_map.corr, axis=2), ) else: # sum up phase weights by crystal type for a0 in range(self.num_fits): - ind = self.crystal_identity[a0,0] - phase_sig[ind] += self.phase_weights[:,:,a0] + ind = self.crystal_identity[a0, 0] + phase_sig[ind] += self.phase_weights[:, :, a0] # smoothing of the outputs if sigma > 0.0: for a0 in range(self.num_crystals): phase_sig[a0] = gaussian_filter( phase_sig[a0], - sigma = sigma, - mode = 'nearest', - ) + sigma=sigma, + mode="nearest", + ) # find highest correlation score for each crystal and match index for a0 in range(self.num_crystals): - sub = phase_sig[a0] > phase_corr + sub = phase_sig[a0] > phase_corr phase_map[sub] = a0 phase_corr[sub] = phase_sig[a0][sub] - + if self.single_phase: phase_scale = np.clip( - (self.phase_reliability - reliability_range[0]) / (reliability_range[1] - reliability_range[0]), + (self.phase_reliability - reliability_range[0]) + / (reliability_range[1] - reliability_range[0]), 0, - 1) + 1, + ) else: # find the second correlation score for each crystal and match index for a0 in range(self.num_crystals): corr = phase_sig[a0].copy() - corr[phase_map==a0] = 0.0 + corr[phase_map == a0] = 0.0 sub = corr > phase_corr_2nd phase_corr_2nd[sub] = corr[sub] - + # Estimate the reliability phase_rel = phase_corr - phase_corr_2nd phase_scale = np.clip( - (phase_rel - reliability_range[0]) / (reliability_range[1] - reliability_range[0]), + (phase_rel - reliability_range[0]) + / (reliability_range[1] - reliability_range[0]), 0, - 1) + 1, + ) # Print the total area of fraction of each phase if print_fractions: phase_mask = phase_scale >= 0.5 phase_total = np.sum(phase_mask) - print('Phase Fractions') - print('---------------') + print("Phase Fractions") + print("---------------") for a0 in range(self.num_crystals): phase_frac = np.sum((phase_map == a0) * phase_mask) / phase_total - print(self.crystal_names[a0] + ' - ' + f'{phase_frac*100:.4f}' + '%') - + print(self.crystal_names[a0] + " - " + f"{phase_frac*100:.4f}" + "%") - self.phase_rgb = np.zeros((scan_shape[0],scan_shape[1],3)) + self.phase_rgb = np.zeros((scan_shape[0], scan_shape[1], 3)) for a0 in range(self.num_crystals): - sub = phase_map==a0 + sub = phase_map == a0 for a1 in range(3): - self.phase_rgb[:,:,a1][sub] = phase_colors[a0,a1] * phase_scale[sub] + self.phase_rgb[:, :, a1][sub] = phase_colors[a0, a1] * phase_scale[sub] # normalize # self.phase_rgb = np.clip( # (self.phase_rgb - rel_range[0]) / (rel_range[1] - rel_range[0]), # 0,1) - - - + fig = plt.figure(figsize=figsize) if legend_add: width = 1 - ax = fig.add_axes((0,legend_fraction,1,1-legend_fraction)) - ax_leg = fig.add_axes((0,0,1,legend_fraction)) + ax = fig.add_axes((0, legend_fraction, 1, 1 - legend_fraction)) + ax_leg = fig.add_axes((0, 0, 1, legend_fraction)) for a0 in range(self.num_crystals): ax_leg.scatter( - a0*width, + a0 * width, 0, - s = 200, - marker = 's', - edgecolor = (0,0,0,1), - facecolor = phase_colors[a0], + s=200, + marker="s", + edgecolor=(0, 0, 0, 1), + facecolor=phase_colors[a0], ) ax_leg.text( - a0*width+0.1, + a0 * width + 0.1, 0, self.crystal_names[a0], - fontsize = 16, - verticalalignment = 'center', + fontsize=16, + verticalalignment="center", + ) + ax_leg.axis("off") + ax_leg.set_xlim( + ( + width * -0.5, + width * (self.num_crystals + 0.5), ) - ax_leg.axis('off') - ax_leg.set_xlim(( - width * -0.5, - width * (self.num_crystals+0.5), - )) + ) else: - ax = fig.add_axes((0,0,1,1)) + ax = fig.add_axes((0, 0, 1, 1)) ax.imshow( self.phase_rgb, - # vmin = 0, - # vmax = 5, - # self.phase_rgb, - # phase_corr - phase_corr_2nd, - # cmap = 'turbo', - # vmin = 0, - # vmax = 3, - # cmap = 'gray', ) if ticks is False: ax.set_xticks([]) ax.set_yticks([]) - if returnfig: - return fig,ax - + return fig, ax diff --git a/py4DSTEM/process/diffraction/crystal_viz.py b/py4DSTEM/process/diffraction/crystal_viz.py index 649f26faa..56d4520ee 100644 --- a/py4DSTEM/process/diffraction/crystal_viz.py +++ b/py4DSTEM/process/diffraction/crystal_viz.py @@ -848,7 +848,7 @@ def plot_orientation_plan( bragg_peaks = self.generate_diffraction_pattern( orientation_matrix=self.orientation_rotation_matrices[index_plot, :], sigma_excitation_error=self.orientation_kernel_size / 3, - precession_angle_degrees = self.orientation_precession_angle_degrees, + precession_angle_degrees=self.orientation_precession_angle_degrees, ) plot_diffraction_pattern( @@ -949,19 +949,16 @@ def plot_diffraction_pattern( scale_markers_compare: Optional[float] = None, power_markers: float = 1, power_markers_compare: float = 1, - - color = (0.0, 0.0, 0.0), - color_compare = None, - facecolor = None, - facecolor_compare = (0.0, 0.7, 1.0), - edgecolor = None, - edgecolor_compare = None, - linewidth = 1, - linewidth_compare = 1, - - marker = '+', - marker_compare = 'o', - + color=(0.0, 0.0, 0.0), + color_compare=None, + facecolor=None, + facecolor_compare=(0.0, 0.7, 1.0), + edgecolor=None, + edgecolor_compare=None, + linewidth=1, + linewidth_compare=1, + marker="+", + marker_compare="o", plot_range_kx_ky: Optional[Union[list, tuple, np.ndarray]] = None, add_labels: bool = True, shift_labels: float = 0.08, @@ -1004,9 +1001,7 @@ def plot_diffraction_pattern( if power_markers == 2: marker_size = scale_markers * bragg_peaks.data["intensity"] else: - marker_size = scale_markers * ( - bragg_peaks.data["intensity"] ** power_markers - ) + marker_size = scale_markers * (bragg_peaks.data["intensity"] ** power_markers) # Apply marker size limits to primary plot marker_size = np.clip(marker_size, min_marker_size, max_marker_size) @@ -1061,7 +1056,6 @@ def plot_diffraction_pattern( ax.get_xaxis().set_visible(False) ax.get_yaxis().set_visible(False) - if plot_range_kx_ky is not None: plot_range_kx_ky = np.array(plot_range_kx_ky) if plot_range_kx_ky.ndim == 0: diff --git a/py4DSTEM/process/polar/polar_fits.py b/py4DSTEM/process/polar/polar_fits.py index ea83cc866..34630dbc2 100644 --- a/py4DSTEM/process/polar/polar_fits.py +++ b/py4DSTEM/process/polar/polar_fits.py @@ -17,8 +17,8 @@ def fit_amorphous_ring( fit_all_images=False, maxfev=None, robust=False, - robust_steps = 3, - robust_thresh = 1.0, + robust_steps=3, + robust_thresh=1.0, verbose=False, plot_result=True, plot_log_scale=False, @@ -229,7 +229,7 @@ def fit_amorphous_ring( if maxfev is None: coefs = curve_fit( amorphous_model, - basis[:,sub_fit], + basis[:, sub_fit], vals[sub_fit] / int_mean, p0=coefs, xtol=1e-8, @@ -238,7 +238,7 @@ def fit_amorphous_ring( else: coefs = curve_fit( amorphous_model, - basis[:,sub_fit], + basis[:, sub_fit], vals[sub_fit] / int_mean, p0=coefs, xtol=1e-8, @@ -250,7 +250,6 @@ def fit_amorphous_ring( # Scale intensity coefficients coefs[5:8] *= int_mean - # Perform the fit on each individual diffration pattern if fit_all_images: coefs_all = np.zeros((datacube.shape[0], datacube.shape[1], coefs.size)) From 307f27dcf5aab391d5d2a457e1caddac25c0f973 Mon Sep 17 00:00:00 2001 From: gvarnavi Date: Thu, 4 Jul 2024 09:33:21 -0700 Subject: [PATCH 44/74] lite refactor commits from dev --- py4DSTEM/__init__.py | 10 ++++++++-- py4DSTEM/braggvectors/diskdetection.py | 5 ++++- py4DSTEM/datacube/virtualimage.py | 5 +++-- py4DSTEM/io/filereaders/__init__.py | 5 ++++- py4DSTEM/io/importfile.py | 2 +- py4DSTEM/process/__init__.py | 10 ++++++++-- py4DSTEM/process/diffraction/crystal_viz.py | 4 +++- py4DSTEM/process/phase/__init__.py | 5 ++++- py4DSTEM/process/phase/utils.py | 2 +- py4DSTEM/process/polar/polar_analysis.py | 3 ++- py4DSTEM/process/polar/polar_peaks.py | 3 ++- py4DSTEM/utils/configuration_checker.py | 8 ++------ 12 files changed, 42 insertions(+), 20 deletions(-) diff --git a/py4DSTEM/__init__.py b/py4DSTEM/__init__.py index d5df63f5e..712095abd 100644 --- a/py4DSTEM/__init__.py +++ b/py4DSTEM/__init__.py @@ -52,7 +52,10 @@ BraggVectorMap, ) -from py4DSTEM.process import classification +try: + from py4DSTEM.process import classification +except (ImportError, ModuleNotFoundError): + pass # diffraction @@ -70,7 +73,10 @@ # strain from py4DSTEM.process.strain.strain import StrainMap -from py4DSTEM.process import wholepatternfit +try: + from py4DSTEM.process import wholepatternfit +except (ImportError, ModuleNotFoundError): + pass ### more submodules diff --git a/py4DSTEM/braggvectors/diskdetection.py b/py4DSTEM/braggvectors/diskdetection.py index 99818b75e..e329d193e 100644 --- a/py4DSTEM/braggvectors/diskdetection.py +++ b/py4DSTEM/braggvectors/diskdetection.py @@ -10,7 +10,10 @@ from py4DSTEM.datacube import DataCube from py4DSTEM.preprocess.utils import get_maxima_2D from py4DSTEM.process.utils.cross_correlate import get_cross_correlation_FT -from py4DSTEM.braggvectors.diskdetection_aiml import find_Bragg_disks_aiml +try: + from py4DSTEM.braggvectors.diskdetection_aiml import find_Bragg_disks_aiml +except (ImportError, ModuleNotFoundError): + pass def find_Bragg_disks( diff --git a/py4DSTEM/datacube/virtualimage.py b/py4DSTEM/datacube/virtualimage.py index 627223d23..d4fe15241 100644 --- a/py4DSTEM/datacube/virtualimage.py +++ b/py4DSTEM/datacube/virtualimage.py @@ -5,7 +5,6 @@ # for bragg virtual imaging methods, goto diskdetection.virtualimage.py import numpy as np -import dask.array as da from typing import Optional import inspect @@ -220,7 +219,9 @@ def get_virtual_image( virtual_image[rx, ry] = np.sum(self.data[rx, ry] * mask) # dask - if dask is True: + if dask: + import dask.array as da + # set up a generalized universal function for dask distribution def _apply_mask_dask(self, mask): virtual_image = np.sum( diff --git a/py4DSTEM/io/filereaders/__init__.py b/py4DSTEM/io/filereaders/__init__.py index b6f4eb0a2..eb5cd02e1 100644 --- a/py4DSTEM/io/filereaders/__init__.py +++ b/py4DSTEM/io/filereaders/__init__.py @@ -2,5 +2,8 @@ from py4DSTEM.io.filereaders.read_K2 import read_gatan_K2_bin from py4DSTEM.io.filereaders.empad import read_empad from py4DSTEM.io.filereaders.read_mib import load_mib -from py4DSTEM.io.filereaders.read_arina import read_arina +try: + from py4DSTEM.io.filereaders.read_arina import read_arina +except (ImportError, ModuleNotFoundError): + pass from py4DSTEM.io.filereaders.read_abTEM import read_abTEM diff --git a/py4DSTEM/io/importfile.py b/py4DSTEM/io/importfile.py index ff3d1c37c..73e402de0 100644 --- a/py4DSTEM/io/importfile.py +++ b/py4DSTEM/io/importfile.py @@ -7,7 +7,6 @@ from py4DSTEM.io.filereaders import ( load_mib, read_abTEM, - read_arina, read_dm, read_empad, read_gatan_K2_bin, @@ -90,6 +89,7 @@ def import_file( elif filetype == "mib": data = load_mib(filepath, mem=mem, binfactor=binfactor, **kwargs) elif filetype == "arina": + from py4DSTEM.io.filereaders import read_arina data = read_arina(filepath, mem=mem, binfactor=binfactor, **kwargs) elif filetype == "abTEM": data = read_abTEM(filepath, mem=mem, binfactor=binfactor, **kwargs) diff --git a/py4DSTEM/process/__init__.py b/py4DSTEM/process/__init__.py index 0509d181e..8acee4e3c 100644 --- a/py4DSTEM/process/__init__.py +++ b/py4DSTEM/process/__init__.py @@ -4,6 +4,12 @@ from py4DSTEM.process import phase from py4DSTEM.process import calibration from py4DSTEM.process import utils -from py4DSTEM.process import classification +try: + from py4DSTEM.process import classification +except (ImportError, ModuleNotFoundError): + pass from py4DSTEM.process import diffraction -from py4DSTEM.process import wholepatternfit +try: + from py4DSTEM.process import wholepatternfit +except (ImportError, ModuleNotFoundError): + pass diff --git a/py4DSTEM/process/diffraction/crystal_viz.py b/py4DSTEM/process/diffraction/crystal_viz.py index 47df2e6ca..d3e4b682a 100644 --- a/py4DSTEM/process/diffraction/crystal_viz.py +++ b/py4DSTEM/process/diffraction/crystal_viz.py @@ -8,7 +8,6 @@ from scipy.signal import medfilt from scipy.ndimage import gaussian_filter from scipy.ndimage import distance_transform_edt -from skimage.morphology import dilation, erosion import warnings import numpy as np @@ -1884,7 +1883,10 @@ def plot_clusters( for a0 in range(self.cluster_sizes.shape[0]): if self.cluster_sizes[a0] >= area_min: + if outline_grains: + from skimage.morphology import erosion + im_grain[:] = False im_grain[ self.cluster_inds[a0][0, :], diff --git a/py4DSTEM/process/phase/__init__.py b/py4DSTEM/process/phase/__init__.py index ecfeaa1d2..927e5a08a 100644 --- a/py4DSTEM/process/phase/__init__.py +++ b/py4DSTEM/process/phase/__init__.py @@ -11,6 +11,9 @@ from py4DSTEM.process.phase.parallax import Parallax from py4DSTEM.process.phase.ptychographic_tomography import PtychographicTomography from py4DSTEM.process.phase.singleslice_ptychography import SingleslicePtychography -from py4DSTEM.process.phase.parameter_optimize import OptimizationParameter, PtychographyOptimizer +try: + from py4DSTEM.process.phase.parameter_optimize import OptimizationParameter, PtychographyOptimizer +except (ImportError, ModuleNotFoundError): + pass # fmt: on diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index 5742ff7e7..2834e8c2e 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -22,7 +22,6 @@ def get_array_module(*args): from py4DSTEM.process.utils import get_CoM from py4DSTEM.process.utils.cross_correlate import align_and_shift_images from py4DSTEM.process.utils.utils import electron_wavelength_angstrom -from skimage.restoration import unwrap_phase # fmt: off @@ -1755,6 +1754,7 @@ def unwrap_phase_2d(array, weights=None, gauge=None, corner_centered=True, xp=np def unwrap_phase_2d_skimage(array, corner_centered=True, xp=np): + from skimage.restoration import unwrap_phase if xp is np: array = array.astype(np.float64) unwrapped_array = unwrap_phase(array, wrap_around=corner_centered).astype( diff --git a/py4DSTEM/process/polar/polar_analysis.py b/py4DSTEM/process/polar/polar_analysis.py index 78a95c24a..dbf16ff97 100644 --- a/py4DSTEM/process/polar/polar_analysis.py +++ b/py4DSTEM/process/polar/polar_analysis.py @@ -5,7 +5,6 @@ import matplotlib.pyplot as plt from scipy.optimize import curve_fit from scipy.ndimage import gaussian_filter -from sklearn.decomposition import PCA from emdfile import tqdmnd @@ -980,6 +979,8 @@ def background_pca( radial PCA component selected """ + from sklearn.decomposition import PCA + # PCA decomposition shape = self.radial_all.shape A = np.reshape(self.radial_all, (shape[0] * shape[1], shape[2])) diff --git a/py4DSTEM/process/polar/polar_peaks.py b/py4DSTEM/process/polar/polar_peaks.py index 535ae7143..6e18e753c 100644 --- a/py4DSTEM/process/polar/polar_peaks.py +++ b/py4DSTEM/process/polar/polar_peaks.py @@ -3,7 +3,6 @@ from scipy.ndimage import gaussian_filter, gaussian_filter1d from scipy.signal import peak_prominences -from skimage.feature import peak_local_max from scipy.optimize import curve_fit, leastsq import warnings @@ -105,6 +104,8 @@ def find_peaks_single_pattern( """ + from skimage.feature import peak_local_max + # if needed, generate mask from Bragg peaks if bragg_peaks is not None: mask_bragg = self._datacube.get_braggmask( diff --git a/py4DSTEM/utils/configuration_checker.py b/py4DSTEM/utils/configuration_checker.py index b50a21de2..6edf44ffc 100644 --- a/py4DSTEM/utils/configuration_checker.py +++ b/py4DSTEM/utils/configuration_checker.py @@ -7,10 +7,6 @@ # need a mapping of pypi/conda names to import names import_mapping_dict = { - "scikit-image": "skimage", - "scikit-learn": "sklearn", - "scikit-optimize": "skopt", - "mp-api": "mp_api", } @@ -88,7 +84,8 @@ def get_modules_dict(): # module_depenencies = get_modules_dict() -modules = get_modules_list() +# modules = get_modules_list() +modules = [] #### Class and Functions to Create Coloured Strings #### @@ -527,7 +524,6 @@ def print_no_extra_checks(m: str): # dict of extra check functions funcs_dict = { - "cupy": check_cupy_gpu, } From fdce7a920a1c33c1fcdd55f5d580bdde7d8c9de3 Mon Sep 17 00:00:00 2001 From: gvarnavi Date: Thu, 4 Jul 2024 09:54:26 -0700 Subject: [PATCH 45/74] black --- py4DSTEM/braggvectors/diskdetection.py | 1 + py4DSTEM/io/filereaders/__init__.py | 1 + py4DSTEM/io/importfile.py | 1 + py4DSTEM/process/__init__.py | 2 ++ py4DSTEM/process/diffraction/crystal_viz.py | 4 ++-- py4DSTEM/process/phase/utils.py | 5 ++++- py4DSTEM/utils/configuration_checker.py | 6 ++---- 7 files changed, 13 insertions(+), 7 deletions(-) diff --git a/py4DSTEM/braggvectors/diskdetection.py b/py4DSTEM/braggvectors/diskdetection.py index e329d193e..a66ff62ea 100644 --- a/py4DSTEM/braggvectors/diskdetection.py +++ b/py4DSTEM/braggvectors/diskdetection.py @@ -10,6 +10,7 @@ from py4DSTEM.datacube import DataCube from py4DSTEM.preprocess.utils import get_maxima_2D from py4DSTEM.process.utils.cross_correlate import get_cross_correlation_FT + try: from py4DSTEM.braggvectors.diskdetection_aiml import find_Bragg_disks_aiml except (ImportError, ModuleNotFoundError): diff --git a/py4DSTEM/io/filereaders/__init__.py b/py4DSTEM/io/filereaders/__init__.py index eb5cd02e1..372275f11 100644 --- a/py4DSTEM/io/filereaders/__init__.py +++ b/py4DSTEM/io/filereaders/__init__.py @@ -2,6 +2,7 @@ from py4DSTEM.io.filereaders.read_K2 import read_gatan_K2_bin from py4DSTEM.io.filereaders.empad import read_empad from py4DSTEM.io.filereaders.read_mib import load_mib + try: from py4DSTEM.io.filereaders.read_arina import read_arina except (ImportError, ModuleNotFoundError): diff --git a/py4DSTEM/io/importfile.py b/py4DSTEM/io/importfile.py index 73e402de0..b3002c77e 100644 --- a/py4DSTEM/io/importfile.py +++ b/py4DSTEM/io/importfile.py @@ -90,6 +90,7 @@ def import_file( data = load_mib(filepath, mem=mem, binfactor=binfactor, **kwargs) elif filetype == "arina": from py4DSTEM.io.filereaders import read_arina + data = read_arina(filepath, mem=mem, binfactor=binfactor, **kwargs) elif filetype == "abTEM": data = read_abTEM(filepath, mem=mem, binfactor=binfactor, **kwargs) diff --git a/py4DSTEM/process/__init__.py b/py4DSTEM/process/__init__.py index 8acee4e3c..b04451913 100644 --- a/py4DSTEM/process/__init__.py +++ b/py4DSTEM/process/__init__.py @@ -4,11 +4,13 @@ from py4DSTEM.process import phase from py4DSTEM.process import calibration from py4DSTEM.process import utils + try: from py4DSTEM.process import classification except (ImportError, ModuleNotFoundError): pass from py4DSTEM.process import diffraction + try: from py4DSTEM.process import wholepatternfit except (ImportError, ModuleNotFoundError): diff --git a/py4DSTEM/process/diffraction/crystal_viz.py b/py4DSTEM/process/diffraction/crystal_viz.py index d3e4b682a..b5a22d609 100644 --- a/py4DSTEM/process/diffraction/crystal_viz.py +++ b/py4DSTEM/process/diffraction/crystal_viz.py @@ -453,7 +453,8 @@ def plot_scattering_intensity( int_sf_plot = calc_1D_profile( k, self.g_vec_leng, - (self.struct_factors_int**int_power_scale) * (self.g_vec_leng**k_power_scale), + (self.struct_factors_int**int_power_scale) + * (self.g_vec_leng**k_power_scale), remove_origin=True, k_broadening=k_broadening, int_scale=int_scale, @@ -1883,7 +1884,6 @@ def plot_clusters( for a0 in range(self.cluster_sizes.shape[0]): if self.cluster_sizes[a0] >= area_min: - if outline_grains: from skimage.morphology import erosion diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index 2834e8c2e..0ed82f703 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -202,7 +202,9 @@ def evaluate_gaussian_envelope( self, alpha: Union[float, np.ndarray] ) -> Union[float, np.ndarray]: xp = self._xp - return xp.exp(-0.5 * self._gaussian_spread**2 * alpha**2 / self._wavelength**2) + return xp.exp( + -0.5 * self._gaussian_spread**2 * alpha**2 / self._wavelength**2 + ) def evaluate_spatial_envelope( self, alpha: Union[float, np.ndarray], phi: Union[float, np.ndarray] @@ -1755,6 +1757,7 @@ def unwrap_phase_2d(array, weights=None, gauge=None, corner_centered=True, xp=np def unwrap_phase_2d_skimage(array, corner_centered=True, xp=np): from skimage.restoration import unwrap_phase + if xp is np: array = array.astype(np.float64) unwrapped_array = unwrap_phase(array, wrap_around=corner_centered).astype( diff --git a/py4DSTEM/utils/configuration_checker.py b/py4DSTEM/utils/configuration_checker.py index 6edf44ffc..cbdbd2881 100644 --- a/py4DSTEM/utils/configuration_checker.py +++ b/py4DSTEM/utils/configuration_checker.py @@ -6,8 +6,7 @@ from importlib.util import find_spec # need a mapping of pypi/conda names to import names -import_mapping_dict = { -} +import_mapping_dict = {} # programatically get all possible requirements in the import name style @@ -523,8 +522,7 @@ def print_no_extra_checks(m: str): # dict of extra check functions -funcs_dict = { -} +funcs_dict = {} #### main function used to check the configuration of the installation From b84847cc0969aef74d3bb258103ea85c0aac58af Mon Sep 17 00:00:00 2001 From: gvarnavi Date: Thu, 4 Jul 2024 10:03:46 -0700 Subject: [PATCH 46/74] one more try, new black --- 24 | 16 ++++++++++++++++ py4DSTEM/process/diffraction/crystal_viz.py | 3 +-- py4DSTEM/process/phase/utils.py | 4 +--- 3 files changed, 18 insertions(+), 5 deletions(-) create mode 100644 24 diff --git a/24 b/24 new file mode 100644 index 000000000..16bdba3c0 --- /dev/null +++ b/24 @@ -0,0 +1,16 @@ +conda-forge/linux-64 Using cache +conda-forge/noarch Using cache +Transaction + + Prefix: /home/george/mambaforge/envs/py4dstem-dev + + All requested packages already installed + + +Looking for: ['black'] + + +Pinned packages: + - python 3.12.* + + diff --git a/py4DSTEM/process/diffraction/crystal_viz.py b/py4DSTEM/process/diffraction/crystal_viz.py index b5a22d609..ecc6fed0e 100644 --- a/py4DSTEM/process/diffraction/crystal_viz.py +++ b/py4DSTEM/process/diffraction/crystal_viz.py @@ -453,8 +453,7 @@ def plot_scattering_intensity( int_sf_plot = calc_1D_profile( k, self.g_vec_leng, - (self.struct_factors_int**int_power_scale) - * (self.g_vec_leng**k_power_scale), + (self.struct_factors_int**int_power_scale) * (self.g_vec_leng**k_power_scale), remove_origin=True, k_broadening=k_broadening, int_scale=int_scale, diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index 0ed82f703..f932b40b5 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -202,9 +202,7 @@ def evaluate_gaussian_envelope( self, alpha: Union[float, np.ndarray] ) -> Union[float, np.ndarray]: xp = self._xp - return xp.exp( - -0.5 * self._gaussian_spread**2 * alpha**2 / self._wavelength**2 - ) + return xp.exp(-0.5 * self._gaussian_spread**2 * alpha**2 / self._wavelength**2) def evaluate_spatial_envelope( self, alpha: Union[float, np.ndarray], phi: Union[float, np.ndarray] From 7b60b13f8e33cfc9e0a17b10275fd20811a67021 Mon Sep 17 00:00:00 2001 From: gvarnavi Date: Thu, 4 Jul 2024 10:07:40 -0700 Subject: [PATCH 47/74] odd mamba file made its way in --- 24 | 16 ---------------- 1 file changed, 16 deletions(-) delete mode 100644 24 diff --git a/24 b/24 deleted file mode 100644 index 16bdba3c0..000000000 --- a/24 +++ /dev/null @@ -1,16 +0,0 @@ -conda-forge/linux-64 Using cache -conda-forge/noarch Using cache -Transaction - - Prefix: /home/george/mambaforge/envs/py4dstem-dev - - All requested packages already installed - - -Looking for: ['black'] - - -Pinned packages: - - python 3.12.* - - From 1e8278f8c53cc39a4e6a5fa82a0a3be8a587df46 Mon Sep 17 00:00:00 2001 From: gvarnavi Date: Wed, 10 Jul 2024 01:59:00 -0700 Subject: [PATCH 48/74] is_package_lite demo --- py4DSTEM/__init__.py | 14 +++++++++----- setup.py | 2 +- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/py4DSTEM/__init__.py b/py4DSTEM/__init__.py index 712095abd..695d2b55d 100644 --- a/py4DSTEM/__init__.py +++ b/py4DSTEM/__init__.py @@ -1,6 +1,9 @@ from py4DSTEM.version import __version__ from emdfile import tqdmnd +from importlib.metadata import metadata + +is_package_lite = "lite" in metadata("py4DSTEM")["Keywords"].lower().split(",") ### io @@ -54,9 +57,9 @@ try: from py4DSTEM.process import classification -except (ImportError, ModuleNotFoundError): - pass - +except (ImportError, ModuleNotFoundError) as exc: + if not is_package_lite: + raise exc # diffraction from py4DSTEM.process.diffraction import Crystal, Orientation @@ -75,8 +78,9 @@ try: from py4DSTEM.process import wholepatternfit -except (ImportError, ModuleNotFoundError): - pass +except (ImportError, ModuleNotFoundError) as exc: + if not is_package_lite: + raise exc ### more submodules diff --git a/setup.py b/setup.py index 3a853cc9d..28b16692a 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,7 @@ author="Benjamin H. Savitzky", author_email="ben.savitzky@gmail.com", license="GNU GPLv3", - keywords="STEM 4DSTEM", + keywords="STEM,4DSTEM", python_requires=">=3.10", install_requires=[ "numpy >= 1.19", From e872c79f98e2c7c557d510cb4b35e70298b79264 Mon Sep 17 00:00:00 2001 From: gvarnavi Date: Wed, 10 Jul 2024 02:06:52 -0700 Subject: [PATCH 49/74] switching to __package__ calls instead --- py4DSTEM/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/__init__.py b/py4DSTEM/__init__.py index 695d2b55d..2fd84dc0a 100644 --- a/py4DSTEM/__init__.py +++ b/py4DSTEM/__init__.py @@ -3,7 +3,7 @@ from importlib.metadata import metadata -is_package_lite = "lite" in metadata("py4DSTEM")["Keywords"].lower().split(",") +is_package_lite = "lite" in metadata(__package__)["Keywords"].lower().split(",") ### io From 44c5e1ac5d602eed9d6023e2cefa47cb26b0ff2d Mon Sep 17 00:00:00 2001 From: gvarnavi Date: Wed, 10 Jul 2024 02:20:50 -0700 Subject: [PATCH 50/74] I clearly dont understand __package__ --- py4DSTEM/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/py4DSTEM/__init__.py b/py4DSTEM/__init__.py index 2fd84dc0a..0189cb2be 100644 --- a/py4DSTEM/__init__.py +++ b/py4DSTEM/__init__.py @@ -3,6 +3,9 @@ from importlib.metadata import metadata +package_spec = __spec__ +package_package = __package__ +package_name = __name__ is_package_lite = "lite" in metadata(__package__)["Keywords"].lower().split(",") ### io From 33d20b6fed62f5b5057cdcd4469e67835878fae9 Mon Sep 17 00:00:00 2001 From: gvarnavi Date: Wed, 10 Jul 2024 02:30:35 -0700 Subject: [PATCH 51/74] reverting, doesnt seem like metadata would work - will continue on plane --- py4DSTEM/__init__.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/py4DSTEM/__init__.py b/py4DSTEM/__init__.py index 0189cb2be..2fd84dc0a 100644 --- a/py4DSTEM/__init__.py +++ b/py4DSTEM/__init__.py @@ -3,9 +3,6 @@ from importlib.metadata import metadata -package_spec = __spec__ -package_package = __package__ -package_name = __name__ is_package_lite = "lite" in metadata(__package__)["Keywords"].lower().split(",") ### io From 1693c18805eb4122b9569bd0a2596d5a7d26fb73 Mon Sep 17 00:00:00 2001 From: gvarnavi Date: Wed, 10 Jul 2024 19:09:57 -0700 Subject: [PATCH 52/74] alright - this hack ought to work --- py4DSTEM/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/py4DSTEM/__init__.py b/py4DSTEM/__init__.py index 2fd84dc0a..0f4490eb2 100644 --- a/py4DSTEM/__init__.py +++ b/py4DSTEM/__init__.py @@ -1,9 +1,9 @@ from py4DSTEM.version import __version__ from emdfile import tqdmnd -from importlib.metadata import metadata +from importlib.metadata import packages_distributions -is_package_lite = "lite" in metadata(__package__)["Keywords"].lower().split(",") +is_package_lite = "py4DSTEM-lite" in packages_distributions()["py4DSTEM"] ### io From f75992cecde8923347640cc4ed5e8ff1ff9dcd39 Mon Sep 17 00:00:00 2001 From: gvarnavi Date: Wed, 10 Jul 2024 19:27:16 -0700 Subject: [PATCH 53/74] updating all __init__ files to be lite_aware --- py4DSTEM/braggvectors/diskdetection.py | 6 ++++-- py4DSTEM/io/filereaders/__init__.py | 8 +++++--- py4DSTEM/process/__init__.py | 12 ++++++++---- py4DSTEM/process/phase/__init__.py | 6 ++++-- py4DSTEM/utils/configuration_checker.py | 23 ++++++++++++++++++++--- 5 files changed, 41 insertions(+), 14 deletions(-) diff --git a/py4DSTEM/braggvectors/diskdetection.py b/py4DSTEM/braggvectors/diskdetection.py index a66ff62ea..500dbd2e9 100644 --- a/py4DSTEM/braggvectors/diskdetection.py +++ b/py4DSTEM/braggvectors/diskdetection.py @@ -5,6 +5,7 @@ from scipy.ndimage import gaussian_filter from emdfile import tqdmnd +from py4DSTEM import is_package_lite from py4DSTEM.braggvectors.braggvectors import BraggVectors from py4DSTEM.data import QPoints from py4DSTEM.datacube import DataCube @@ -13,8 +14,9 @@ try: from py4DSTEM.braggvectors.diskdetection_aiml import find_Bragg_disks_aiml -except (ImportError, ModuleNotFoundError): - pass +except (ImportError, ModuleNotFoundError) as exc: + if not is_package_lite: + raise exc def find_Bragg_disks( diff --git a/py4DSTEM/io/filereaders/__init__.py b/py4DSTEM/io/filereaders/__init__.py index 372275f11..4b89175a7 100644 --- a/py4DSTEM/io/filereaders/__init__.py +++ b/py4DSTEM/io/filereaders/__init__.py @@ -1,10 +1,12 @@ +from py4DSTEM import is_package_lite +from py4DSTEM.io.filereaders.empad import read_empad from py4DSTEM.io.filereaders.read_dm import read_dm from py4DSTEM.io.filereaders.read_K2 import read_gatan_K2_bin -from py4DSTEM.io.filereaders.empad import read_empad from py4DSTEM.io.filereaders.read_mib import load_mib try: from py4DSTEM.io.filereaders.read_arina import read_arina -except (ImportError, ModuleNotFoundError): - pass +except (ImportError, ModuleNotFoundError) as exc: + if not is_package_lite: + raise exc from py4DSTEM.io.filereaders.read_abTEM import read_abTEM diff --git a/py4DSTEM/process/__init__.py b/py4DSTEM/process/__init__.py index b04451913..6d7d36b28 100644 --- a/py4DSTEM/process/__init__.py +++ b/py4DSTEM/process/__init__.py @@ -1,3 +1,4 @@ +from py4DSTEM import is_package_lite from py4DSTEM.process.polar import PolarDatacube from py4DSTEM.process.strain.strain import StrainMap @@ -7,11 +8,14 @@ try: from py4DSTEM.process import classification -except (ImportError, ModuleNotFoundError): - pass +except (ImportError, ModuleNotFoundError) as exc: + if not is_package_lite: + raise exc + from py4DSTEM.process import diffraction try: from py4DSTEM.process import wholepatternfit -except (ImportError, ModuleNotFoundError): - pass +except (ImportError, ModuleNotFoundError) as exc: + if not is_package_lite: + raise exc diff --git a/py4DSTEM/process/phase/__init__.py b/py4DSTEM/process/phase/__init__.py index 927e5a08a..68dac8989 100644 --- a/py4DSTEM/process/phase/__init__.py +++ b/py4DSTEM/process/phase/__init__.py @@ -2,6 +2,7 @@ _emd_hook = True +from py4DSTEM import is_package_lite from py4DSTEM.process.phase.dpc import DPC from py4DSTEM.process.phase.magnetic_ptychographic_tomography import MagneticPtychographicTomography from py4DSTEM.process.phase.magnetic_ptychography import MagneticPtychography @@ -13,7 +14,8 @@ from py4DSTEM.process.phase.singleslice_ptychography import SingleslicePtychography try: from py4DSTEM.process.phase.parameter_optimize import OptimizationParameter, PtychographyOptimizer -except (ImportError, ModuleNotFoundError): - pass +except (ImportError, ModuleNotFoundError) as exc: + if not is_package_lite: + raise exc # fmt: on diff --git a/py4DSTEM/utils/configuration_checker.py b/py4DSTEM/utils/configuration_checker.py index cbdbd2881..2d292618e 100644 --- a/py4DSTEM/utils/configuration_checker.py +++ b/py4DSTEM/utils/configuration_checker.py @@ -5,8 +5,19 @@ import re from importlib.util import find_spec +from py4DSTEM import is_package_lite + # need a mapping of pypi/conda names to import names -import_mapping_dict = {} +import_mapping_dict = ( + {} + if is_package_lite + else { + "scikit-image": "skimage", + "scikit-learn": "sklearn", + "scikit-optimize": "skopt", + "mp-api": "mp_api", + } +) # programatically get all possible requirements in the import name style @@ -84,7 +95,7 @@ def get_modules_dict(): # module_depenencies = get_modules_dict() # modules = get_modules_list() -modules = [] +modules = [] if is_package_lite else get_modules_list() #### Class and Functions to Create Coloured Strings #### @@ -522,7 +533,13 @@ def print_no_extra_checks(m: str): # dict of extra check functions -funcs_dict = {} +funcs_dict = ( + {} + if is_package_lite + else { + "cupy": check_cupy_gpu, + } +) #### main function used to check the configuration of the installation From 88994879694c43680926ea7d9bddb649edf75412 Mon Sep 17 00:00:00 2001 From: cophus Date: Wed, 17 Jul 2024 20:18:04 -0700 Subject: [PATCH 54/74] black formatting, updated variable name --- .../diffraction/WK_scattering_factors.py | 4 +- py4DSTEM/process/diffraction/crystal_ACOM.py | 1 - py4DSTEM/process/diffraction/crystal_phase.py | 68 ++++++++++--------- py4DSTEM/process/diffraction/crystal_viz.py | 3 +- py4DSTEM/process/fit/fit.py | 3 +- .../magnetic_ptychographic_tomography.py | 28 ++++---- .../process/phase/magnetic_ptychography.py | 28 ++++---- py4DSTEM/process/phase/parallax.py | 24 +++---- .../process/phase/ptychographic_methods.py | 4 +- .../process/phase/ptychographic_tomography.py | 28 ++++---- py4DSTEM/process/phase/utils.py | 4 +- py4DSTEM/process/utils/utils.py | 18 ++++- 12 files changed, 117 insertions(+), 96 deletions(-) diff --git a/py4DSTEM/process/diffraction/WK_scattering_factors.py b/py4DSTEM/process/diffraction/WK_scattering_factors.py index eb964de96..70110a977 100644 --- a/py4DSTEM/process/diffraction/WK_scattering_factors.py +++ b/py4DSTEM/process/diffraction/WK_scattering_factors.py @@ -221,7 +221,9 @@ def RI1(BI, BJ, G): ri1[sub] = np.pi * (BI * np.log((BI + BJ) / BI) + BJ * np.log((BI + BJ) / BJ)) sub = np.logical_and(eps <= 0.1, G > 0.0) - temp = 0.5 * BI**2 * np.log(BI / (BI + BJ)) + 0.5 * BJ**2 * np.log(BJ / (BI + BJ)) + temp = 0.5 * BI**2 * np.log(BI / (BI + BJ)) + 0.5 * BJ**2 * np.log( + BJ / (BI + BJ) + ) temp += 0.75 * (BI**2 + BJ**2) - 0.25 * (BI + BJ) ** 2 temp -= 0.5 * (BI - BJ) ** 2 ri1[sub] += np.pi * G[sub] ** 2 * temp diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index ae71c20bc..f7a4a20db 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -1052,7 +1052,6 @@ def match_single_pattern( sub = dqr < self.orientation_kernel_size if np.any(sub): - im_polar[ind_radial, :] = np.sum( np.power( np.maximum(intensity[sub, None], 0.0), diff --git a/py4DSTEM/process/diffraction/crystal_phase.py b/py4DSTEM/process/diffraction/crystal_phase.py index 3db718555..c161304f9 100644 --- a/py4DSTEM/process/diffraction/crystal_phase.py +++ b/py4DSTEM/process/diffraction/crystal_phase.py @@ -10,7 +10,7 @@ from py4DSTEM.process.diffraction.crystal_viz import plot_diffraction_pattern -class Crystal_Phase: +class CrystalPhase: """ A class storing multiple crystal structures, and associated diffraction data. Must be initialized after matching orientations to a pointlistarray??? @@ -194,8 +194,8 @@ def quantify_single_pattern( """ - # tolerance - tol2 = 1e-6 + # tolerance for separating the origin peak. + tolerance_origin_2 = 1e-6 # calibrations center = pointlistarray.calstate["center"] @@ -231,11 +231,15 @@ def quantify_single_pattern( ) # bragg_peaks = pointlistarray.get_pointlist(xy_position[0],xy_position[1]).copy() if k_max is None: - keep = bragg_peaks.data["qx"] ** 2 + bragg_peaks.data["qy"] ** 2 > tol2 + keep = ( + bragg_peaks.data["qx"] ** 2 + bragg_peaks.data["qy"] ** 2 + > tolerance_origin_2 + ) else: keep = np.logical_and.reduce( ( - bragg_peaks.data["qx"] ** 2 + bragg_peaks.data["qy"] ** 2 > tol2, + bragg_peaks.data["qx"] ** 2 + bragg_peaks.data["qy"] ** 2 + > tolerance_origin_2, np.abs(bragg_peaks.data["qx"]) < k_max, np.abs(bragg_peaks.data["qy"]) < k_max, ) @@ -303,14 +307,14 @@ def quantify_single_pattern( if k_max is None: del_peak = ( bragg_peaks_fit.data["qx"] ** 2 + bragg_peaks_fit.data["qy"] ** 2 - < tol2 + < tolerance_origin_2 ) else: del_peak = np.logical_or.reduce( ( bragg_peaks_fit.data["qx"] ** 2 + bragg_peaks_fit.data["qy"] ** 2 - < tol2, + < tolerance_origin_2, np.abs(bragg_peaks_fit.data["qx"]) > k_max, np.abs(bragg_peaks_fit.data["qy"]) > k_max, ) @@ -473,7 +477,6 @@ def quantify_single_pattern( search = True while search is True: - basis_solve = basis[:, inds_solve] obs_solve = obs.copy() @@ -685,7 +688,6 @@ def quantify_single_pattern( crystal_inds_plot == None or np.min(np.abs(c - crystal_inds_plot)) == 0 ): - qx_fit = library_peaks[a0].data["qx"] qy_fit = library_peaks[a0].data["qy"] @@ -701,7 +703,6 @@ def quantify_single_pattern( matches_fit = library_matches[a0] if plot_only_nonzero_phases is False or phase_weights[a0] > 0: - # if np.mod(m,2) == 0: ax.scatter( qy_fit[matches_fit], @@ -877,27 +878,30 @@ def quantify_phase( disable=not progress_bar, ): # calculate phase weights - phase_weights, phase_residual, phase_reliability, int_peaks = ( - self.quantify_single_pattern( - pointlistarray=pointlistarray, - xy_position=(rx, ry), - corr_kernel_size=corr_kernel_size, - sigma_excitation_error=sigma_excitation_error, - precession_angle_degrees=precession_angle_degrees, - power_intensity=power_intensity, - power_intensity_experiment=power_intensity_experiment, - k_max=k_max, - max_number_patterns=max_number_patterns, - single_phase=single_phase, - allow_strain=allow_strain, - strain_iterations=strain_iterations, - strain_max=strain_max, - include_false_positives=include_false_positives, - weight_false_positives=weight_false_positives, - plot_result=False, - verbose=False, - returnfig=False, - ) + ( + phase_weights, + phase_residual, + phase_reliability, + int_peaks, + ) = self.quantify_single_pattern( + pointlistarray=pointlistarray, + xy_position=(rx, ry), + corr_kernel_size=corr_kernel_size, + sigma_excitation_error=sigma_excitation_error, + precession_angle_degrees=precession_angle_degrees, + power_intensity=power_intensity, + power_intensity_experiment=power_intensity_experiment, + k_max=k_max, + max_number_patterns=max_number_patterns, + single_phase=single_phase, + allow_strain=allow_strain, + strain_iterations=strain_iterations, + strain_max=strain_max, + include_false_positives=include_false_positives, + weight_false_positives=weight_false_positives, + plot_result=False, + verbose=False, + returnfig=False, ) self.phase_weights[rx, ry] = phase_weights self.phase_residuals[rx, ry] = phase_residual @@ -1178,7 +1182,6 @@ def plot_phase_maps( ) for a0 in range(self.num_crystals): - ax[a0].imshow( im_rgb_all[a0], ) @@ -1315,7 +1318,6 @@ def plot_dominant_phase( ) else: - # find the second correlation score for each crystal and match index for a0 in range(self.num_crystals): corr = phase_sig[a0].copy() diff --git a/py4DSTEM/process/diffraction/crystal_viz.py b/py4DSTEM/process/diffraction/crystal_viz.py index 56d4520ee..1663a6198 100644 --- a/py4DSTEM/process/diffraction/crystal_viz.py +++ b/py4DSTEM/process/diffraction/crystal_viz.py @@ -454,7 +454,8 @@ def plot_scattering_intensity( int_sf_plot = calc_1D_profile( k, self.g_vec_leng, - (self.struct_factors_int**int_power_scale) * (self.g_vec_leng**k_power_scale), + (self.struct_factors_int**int_power_scale) + * (self.g_vec_leng**k_power_scale), remove_origin=True, k_broadening=k_broadening, int_scale=int_scale, diff --git a/py4DSTEM/process/fit/fit.py b/py4DSTEM/process/fit/fit.py index 5c2d56a3c..9973ff79f 100644 --- a/py4DSTEM/process/fit/fit.py +++ b/py4DSTEM/process/fit/fit.py @@ -169,7 +169,8 @@ def polar_gaussian_2D( # t2 = np.min(np.vstack([t,1-t])) t2 = np.square(t - mu_t) return ( - I0 * np.exp(-(t2 / (2 * sigma_t**2) + (q - mu_q) ** 2 / (2 * sigma_q**2))) + C + I0 * np.exp(-(t2 / (2 * sigma_t**2) + (q - mu_q) ** 2 / (2 * sigma_q**2))) + + C ) diff --git a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py index be6332c74..9ffcc6212 100644 --- a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py +++ b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py @@ -1196,20 +1196,20 @@ def reconstruct( # position correction if not fix_positions and a0 > 0: - self._positions_px_all[batch_indices] = ( - self._position_correction( - object_sliced, - vectorized_patch_indices_row, - vectorized_patch_indices_col, - shifted_probes, - overlap, - amplitudes_device, - positions_px, - positions_px_initial, - positions_step_size, - max_position_update_distance, - max_position_total_distance, - ) + self._positions_px_all[ + batch_indices + ] = self._position_correction( + object_sliced, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + overlap, + amplitudes_device, + positions_px, + positions_px_initial, + positions_step_size, + max_position_update_distance, + max_position_total_distance, ) measurement_error += batch_error diff --git a/py4DSTEM/process/phase/magnetic_ptychography.py b/py4DSTEM/process/phase/magnetic_ptychography.py index 19f306188..0949a02eb 100644 --- a/py4DSTEM/process/phase/magnetic_ptychography.py +++ b/py4DSTEM/process/phase/magnetic_ptychography.py @@ -1497,20 +1497,20 @@ def reconstruct( # position correction if not fix_positions and a0 > 0: - self._positions_px_all[batch_indices] = ( - self._position_correction( - self._object, - vectorized_patch_indices_row, - vectorized_patch_indices_col, - shifted_probes, - overlap, - amplitudes_device, - positions_px, - positions_px_initial, - positions_step_size, - max_position_update_distance, - max_position_total_distance, - ) + self._positions_px_all[ + batch_indices + ] = self._position_correction( + self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + overlap, + amplitudes_device, + positions_px, + positions_px_initial, + positions_step_size, + max_position_update_distance, + max_position_total_distance, ) measurement_error += batch_error diff --git a/py4DSTEM/process/phase/parallax.py b/py4DSTEM/process/phase/parallax.py index e12d6a133..3d8c09635 100644 --- a/py4DSTEM/process/phase/parallax.py +++ b/py4DSTEM/process/phase/parallax.py @@ -2202,16 +2202,16 @@ def score_CTF(coefs): measured_shifts_sx = xp.zeros( self._region_of_interest_shape, dtype=xp.float32 ) - measured_shifts_sx[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( - self._xy_shifts_Ang[:, 0] - ) + measured_shifts_sx[ + self._xy_inds[:, 0], self._xy_inds[:, 1] + ] = self._xy_shifts_Ang[:, 0] measured_shifts_sy = xp.zeros( self._region_of_interest_shape, dtype=xp.float32 ) - measured_shifts_sy[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( - self._xy_shifts_Ang[:, 1] - ) + measured_shifts_sy[ + self._xy_inds[:, 0], self._xy_inds[:, 1] + ] = self._xy_shifts_Ang[:, 1] fitted_shifts = ( xp.tensordot(gradients, xp.array(self._aberrations_coefs), axes=1) @@ -2222,16 +2222,16 @@ def score_CTF(coefs): fitted_shifts_sx = xp.zeros( self._region_of_interest_shape, dtype=xp.float32 ) - fitted_shifts_sx[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( - fitted_shifts[:, 0] - ) + fitted_shifts_sx[ + self._xy_inds[:, 0], self._xy_inds[:, 1] + ] = fitted_shifts[:, 0] fitted_shifts_sy = xp.zeros( self._region_of_interest_shape, dtype=xp.float32 ) - fitted_shifts_sy[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( - fitted_shifts[:, 1] - ) + fitted_shifts_sy[ + self._xy_inds[:, 0], self._xy_inds[:, 1] + ] = fitted_shifts[:, 1] max_shift = xp.max( xp.array( diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py index 283ddb1ba..cbc0b2fde 100644 --- a/py4DSTEM/process/phase/ptychographic_methods.py +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -356,7 +356,9 @@ def _precompute_propagator_arrays( propagators[i] = xp.exp( 1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz) ) - propagators[i] *= xp.exp(1.0j * (-(ky**2)[None] * np.pi * wavelength * dz)) + propagators[i] *= xp.exp( + 1.0j * (-(ky**2)[None] * np.pi * wavelength * dz) + ) if theta_x is not None: propagators[i] *= xp.exp( diff --git a/py4DSTEM/process/phase/ptychographic_tomography.py b/py4DSTEM/process/phase/ptychographic_tomography.py index 3639096dc..e1dc33abb 100644 --- a/py4DSTEM/process/phase/ptychographic_tomography.py +++ b/py4DSTEM/process/phase/ptychographic_tomography.py @@ -1088,20 +1088,20 @@ def reconstruct( # position correction if not fix_positions: - self._positions_px_all[batch_indices] = ( - self._position_correction( - object_sliced, - vectorized_patch_indices_row, - vectorized_patch_indices_col, - shifted_probes, - overlap, - amplitudes_device, - positions_px, - positions_px_initial, - positions_step_size, - max_position_update_distance, - max_position_total_distance, - ) + self._positions_px_all[ + batch_indices + ] = self._position_correction( + object_sliced, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + overlap, + amplitudes_device, + positions_px, + positions_px_initial, + positions_step_size, + max_position_update_distance, + max_position_total_distance, ) measurement_error += batch_error diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index 5742ff7e7..1b12da6a0 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -203,7 +203,9 @@ def evaluate_gaussian_envelope( self, alpha: Union[float, np.ndarray] ) -> Union[float, np.ndarray]: xp = self._xp - return xp.exp(-0.5 * self._gaussian_spread**2 * alpha**2 / self._wavelength**2) + return xp.exp( + -0.5 * self._gaussian_spread**2 * alpha**2 / self._wavelength**2 + ) def evaluate_spatial_envelope( self, alpha: Union[float, np.ndarray], phi: Union[float, np.ndarray] diff --git a/py4DSTEM/process/utils/utils.py b/py4DSTEM/process/utils/utils.py index ddeeb2c36..60da616d1 100644 --- a/py4DSTEM/process/utils/utils.py +++ b/py4DSTEM/process/utils/utils.py @@ -93,7 +93,12 @@ def electron_wavelength_angstrom(E_eV): c = 299792458 h = 6.62607 * 10**-34 - lam = h / ma.sqrt(2 * m * e * E_eV) / ma.sqrt(1 + e * E_eV / 2 / m / c**2) * 10**10 + lam = ( + h + / ma.sqrt(2 * m * e * E_eV) + / ma.sqrt(1 + e * E_eV / 2 / m / c**2) + * 10**10 + ) return lam @@ -102,8 +107,15 @@ def electron_interaction_parameter(E_eV): e = 1.602177 * 10**-19 c = 299792458 h = 6.62607 * 10**-34 - lam = h / ma.sqrt(2 * m * e * E_eV) / ma.sqrt(1 + e * E_eV / 2 / m / c**2) * 10**10 - sigma = (2 * np.pi / lam / E_eV) * (m * c**2 + e * E_eV) / (2 * m * c**2 + e * E_eV) + lam = ( + h + / ma.sqrt(2 * m * e * E_eV) + / ma.sqrt(1 + e * E_eV / 2 / m / c**2) + * 10**10 + ) + sigma = ( + (2 * np.pi / lam / E_eV) * (m * c**2 + e * E_eV) / (2 * m * c**2 + e * E_eV) + ) return sigma From 0ecf427c287b217d635b6200e436c3e61782c22c Mon Sep 17 00:00:00 2001 From: cophus Date: Wed, 17 Jul 2024 20:20:39 -0700 Subject: [PATCH 55/74] updating black version --- .../diffraction/WK_scattering_factors.py | 4 +-- py4DSTEM/process/diffraction/crystal_viz.py | 3 +- py4DSTEM/process/fit/fit.py | 3 +- .../magnetic_ptychographic_tomography.py | 28 +++++++++---------- .../process/phase/magnetic_ptychography.py | 28 +++++++++---------- py4DSTEM/process/phase/parallax.py | 24 ++++++++-------- .../process/phase/ptychographic_methods.py | 4 +-- .../process/phase/ptychographic_tomography.py | 28 +++++++++---------- py4DSTEM/process/phase/utils.py | 4 +-- py4DSTEM/process/utils/utils.py | 18 ++---------- 10 files changed, 62 insertions(+), 82 deletions(-) diff --git a/py4DSTEM/process/diffraction/WK_scattering_factors.py b/py4DSTEM/process/diffraction/WK_scattering_factors.py index 70110a977..eb964de96 100644 --- a/py4DSTEM/process/diffraction/WK_scattering_factors.py +++ b/py4DSTEM/process/diffraction/WK_scattering_factors.py @@ -221,9 +221,7 @@ def RI1(BI, BJ, G): ri1[sub] = np.pi * (BI * np.log((BI + BJ) / BI) + BJ * np.log((BI + BJ) / BJ)) sub = np.logical_and(eps <= 0.1, G > 0.0) - temp = 0.5 * BI**2 * np.log(BI / (BI + BJ)) + 0.5 * BJ**2 * np.log( - BJ / (BI + BJ) - ) + temp = 0.5 * BI**2 * np.log(BI / (BI + BJ)) + 0.5 * BJ**2 * np.log(BJ / (BI + BJ)) temp += 0.75 * (BI**2 + BJ**2) - 0.25 * (BI + BJ) ** 2 temp -= 0.5 * (BI - BJ) ** 2 ri1[sub] += np.pi * G[sub] ** 2 * temp diff --git a/py4DSTEM/process/diffraction/crystal_viz.py b/py4DSTEM/process/diffraction/crystal_viz.py index 1663a6198..56d4520ee 100644 --- a/py4DSTEM/process/diffraction/crystal_viz.py +++ b/py4DSTEM/process/diffraction/crystal_viz.py @@ -454,8 +454,7 @@ def plot_scattering_intensity( int_sf_plot = calc_1D_profile( k, self.g_vec_leng, - (self.struct_factors_int**int_power_scale) - * (self.g_vec_leng**k_power_scale), + (self.struct_factors_int**int_power_scale) * (self.g_vec_leng**k_power_scale), remove_origin=True, k_broadening=k_broadening, int_scale=int_scale, diff --git a/py4DSTEM/process/fit/fit.py b/py4DSTEM/process/fit/fit.py index 9973ff79f..5c2d56a3c 100644 --- a/py4DSTEM/process/fit/fit.py +++ b/py4DSTEM/process/fit/fit.py @@ -169,8 +169,7 @@ def polar_gaussian_2D( # t2 = np.min(np.vstack([t,1-t])) t2 = np.square(t - mu_t) return ( - I0 * np.exp(-(t2 / (2 * sigma_t**2) + (q - mu_q) ** 2 / (2 * sigma_q**2))) - + C + I0 * np.exp(-(t2 / (2 * sigma_t**2) + (q - mu_q) ** 2 / (2 * sigma_q**2))) + C ) diff --git a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py index 9ffcc6212..be6332c74 100644 --- a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py +++ b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py @@ -1196,20 +1196,20 @@ def reconstruct( # position correction if not fix_positions and a0 > 0: - self._positions_px_all[ - batch_indices - ] = self._position_correction( - object_sliced, - vectorized_patch_indices_row, - vectorized_patch_indices_col, - shifted_probes, - overlap, - amplitudes_device, - positions_px, - positions_px_initial, - positions_step_size, - max_position_update_distance, - max_position_total_distance, + self._positions_px_all[batch_indices] = ( + self._position_correction( + object_sliced, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + overlap, + amplitudes_device, + positions_px, + positions_px_initial, + positions_step_size, + max_position_update_distance, + max_position_total_distance, + ) ) measurement_error += batch_error diff --git a/py4DSTEM/process/phase/magnetic_ptychography.py b/py4DSTEM/process/phase/magnetic_ptychography.py index 0949a02eb..19f306188 100644 --- a/py4DSTEM/process/phase/magnetic_ptychography.py +++ b/py4DSTEM/process/phase/magnetic_ptychography.py @@ -1497,20 +1497,20 @@ def reconstruct( # position correction if not fix_positions and a0 > 0: - self._positions_px_all[ - batch_indices - ] = self._position_correction( - self._object, - vectorized_patch_indices_row, - vectorized_patch_indices_col, - shifted_probes, - overlap, - amplitudes_device, - positions_px, - positions_px_initial, - positions_step_size, - max_position_update_distance, - max_position_total_distance, + self._positions_px_all[batch_indices] = ( + self._position_correction( + self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + overlap, + amplitudes_device, + positions_px, + positions_px_initial, + positions_step_size, + max_position_update_distance, + max_position_total_distance, + ) ) measurement_error += batch_error diff --git a/py4DSTEM/process/phase/parallax.py b/py4DSTEM/process/phase/parallax.py index 3d8c09635..e12d6a133 100644 --- a/py4DSTEM/process/phase/parallax.py +++ b/py4DSTEM/process/phase/parallax.py @@ -2202,16 +2202,16 @@ def score_CTF(coefs): measured_shifts_sx = xp.zeros( self._region_of_interest_shape, dtype=xp.float32 ) - measured_shifts_sx[ - self._xy_inds[:, 0], self._xy_inds[:, 1] - ] = self._xy_shifts_Ang[:, 0] + measured_shifts_sx[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( + self._xy_shifts_Ang[:, 0] + ) measured_shifts_sy = xp.zeros( self._region_of_interest_shape, dtype=xp.float32 ) - measured_shifts_sy[ - self._xy_inds[:, 0], self._xy_inds[:, 1] - ] = self._xy_shifts_Ang[:, 1] + measured_shifts_sy[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( + self._xy_shifts_Ang[:, 1] + ) fitted_shifts = ( xp.tensordot(gradients, xp.array(self._aberrations_coefs), axes=1) @@ -2222,16 +2222,16 @@ def score_CTF(coefs): fitted_shifts_sx = xp.zeros( self._region_of_interest_shape, dtype=xp.float32 ) - fitted_shifts_sx[ - self._xy_inds[:, 0], self._xy_inds[:, 1] - ] = fitted_shifts[:, 0] + fitted_shifts_sx[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( + fitted_shifts[:, 0] + ) fitted_shifts_sy = xp.zeros( self._region_of_interest_shape, dtype=xp.float32 ) - fitted_shifts_sy[ - self._xy_inds[:, 0], self._xy_inds[:, 1] - ] = fitted_shifts[:, 1] + fitted_shifts_sy[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( + fitted_shifts[:, 1] + ) max_shift = xp.max( xp.array( diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py index cbc0b2fde..283ddb1ba 100644 --- a/py4DSTEM/process/phase/ptychographic_methods.py +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -356,9 +356,7 @@ def _precompute_propagator_arrays( propagators[i] = xp.exp( 1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz) ) - propagators[i] *= xp.exp( - 1.0j * (-(ky**2)[None] * np.pi * wavelength * dz) - ) + propagators[i] *= xp.exp(1.0j * (-(ky**2)[None] * np.pi * wavelength * dz)) if theta_x is not None: propagators[i] *= xp.exp( diff --git a/py4DSTEM/process/phase/ptychographic_tomography.py b/py4DSTEM/process/phase/ptychographic_tomography.py index e1dc33abb..3639096dc 100644 --- a/py4DSTEM/process/phase/ptychographic_tomography.py +++ b/py4DSTEM/process/phase/ptychographic_tomography.py @@ -1088,20 +1088,20 @@ def reconstruct( # position correction if not fix_positions: - self._positions_px_all[ - batch_indices - ] = self._position_correction( - object_sliced, - vectorized_patch_indices_row, - vectorized_patch_indices_col, - shifted_probes, - overlap, - amplitudes_device, - positions_px, - positions_px_initial, - positions_step_size, - max_position_update_distance, - max_position_total_distance, + self._positions_px_all[batch_indices] = ( + self._position_correction( + object_sliced, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + overlap, + amplitudes_device, + positions_px, + positions_px_initial, + positions_step_size, + max_position_update_distance, + max_position_total_distance, + ) ) measurement_error += batch_error diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index 1b12da6a0..5742ff7e7 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -203,9 +203,7 @@ def evaluate_gaussian_envelope( self, alpha: Union[float, np.ndarray] ) -> Union[float, np.ndarray]: xp = self._xp - return xp.exp( - -0.5 * self._gaussian_spread**2 * alpha**2 / self._wavelength**2 - ) + return xp.exp(-0.5 * self._gaussian_spread**2 * alpha**2 / self._wavelength**2) def evaluate_spatial_envelope( self, alpha: Union[float, np.ndarray], phi: Union[float, np.ndarray] diff --git a/py4DSTEM/process/utils/utils.py b/py4DSTEM/process/utils/utils.py index 60da616d1..ddeeb2c36 100644 --- a/py4DSTEM/process/utils/utils.py +++ b/py4DSTEM/process/utils/utils.py @@ -93,12 +93,7 @@ def electron_wavelength_angstrom(E_eV): c = 299792458 h = 6.62607 * 10**-34 - lam = ( - h - / ma.sqrt(2 * m * e * E_eV) - / ma.sqrt(1 + e * E_eV / 2 / m / c**2) - * 10**10 - ) + lam = h / ma.sqrt(2 * m * e * E_eV) / ma.sqrt(1 + e * E_eV / 2 / m / c**2) * 10**10 return lam @@ -107,15 +102,8 @@ def electron_interaction_parameter(E_eV): e = 1.602177 * 10**-19 c = 299792458 h = 6.62607 * 10**-34 - lam = ( - h - / ma.sqrt(2 * m * e * E_eV) - / ma.sqrt(1 + e * E_eV / 2 / m / c**2) - * 10**10 - ) - sigma = ( - (2 * np.pi / lam / E_eV) * (m * c**2 + e * E_eV) / (2 * m * c**2 + e * E_eV) - ) + lam = h / ma.sqrt(2 * m * e * E_eV) / ma.sqrt(1 + e * E_eV / 2 / m / c**2) * 10**10 + sigma = (2 * np.pi / lam / E_eV) * (m * c**2 + e * E_eV) / (2 * m * c**2 + e * E_eV) return sigma From 7c2eb0419e25eaee2d5186038b6188aa6a27e87f Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Thu, 18 Jul 2024 13:27:20 -0700 Subject: [PATCH 56/74] switch to centered alignment bins by default --- py4DSTEM/process/phase/parallax.py | 5 ++++- py4DSTEM/process/phase/parameter_optimize.py | 16 +++++++++------- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/py4DSTEM/process/phase/parallax.py b/py4DSTEM/process/phase/parallax.py index e12d6a133..fa86f4dc2 100644 --- a/py4DSTEM/process/phase/parallax.py +++ b/py4DSTEM/process/phase/parallax.py @@ -743,6 +743,7 @@ def reconstruct( min_alignment_bin: int = 1, num_iter_at_min_bin: int = 2, alignment_bin_values: list = None, + centered_alignment_bins: bool = True, cross_correlation_upsample_factor: int = 8, regularizer_matrix_size: Tuple[int, int] = (1, 1), regularize_shifts: bool = False, @@ -879,6 +880,8 @@ def reconstruct( (bin_vals, np.repeat(bin_vals[-1], num_iter_at_min_bin - 1)) ) + bin_shift = 0 if centered_alignment_bins else 0.5 + if plot_aligned_bf: num_plots = bin_vals.shape[0] nrows = int(np.sqrt(num_plots)) @@ -913,7 +916,7 @@ def reconstruct( G_ref = xp.fft.fft2(self._recon_BF) # Segment the virtual images with current binning values - xy_inds = xp.round(xy_center / bin_vals[a0] + 0.5).astype("int") + xy_inds = xp.round(xy_center / bin_vals[a0] + bin_shift).astype("int") xy_vals = np.unique( asnumpy(xy_inds), axis=0 ) # axis is not yet supported in cupy diff --git a/py4DSTEM/process/phase/parameter_optimize.py b/py4DSTEM/process/phase/parameter_optimize.py index fc1e84f11..e51d3cd2c 100644 --- a/py4DSTEM/process/phase/parameter_optimize.py +++ b/py4DSTEM/process/phase/parameter_optimize.py @@ -458,13 +458,15 @@ def _split_static_and_optimization_vars(self, argdict): return static_args, optimization_args def _get_scan_positions(self, affine_transform, dataset): - R_pixel_size = dataset.calibration.get_R_pixel_size() - x, y = ( - np.arange(dataset.R_Nx) * R_pixel_size, - np.arange(dataset.R_Ny) * R_pixel_size, - ) - x, y = np.meshgrid(x, y, indexing="ij") - scan_positions = np.stack((x.ravel(), y.ravel()), axis=1) + scan_positions = self._init_static_args.pop("initial_scan_positions",None) + if scan_positions is None: + R_pixel_size = dataset.calibration.get_R_pixel_size() + x, y = ( + np.arange(dataset.R_Nx) * R_pixel_size, + np.arange(dataset.R_Ny) * R_pixel_size, + ) + x, y = np.meshgrid(x, y, indexing="ij") + scan_positions = np.stack((x.ravel(), y.ravel()), axis=1) scan_positions = scan_positions @ affine_transform.asarray() return scan_positions From 394c0daa12f2ef9b45edb76366354bb8e8a0813d Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Thu, 18 Jul 2024 13:28:59 -0700 Subject: [PATCH 57/74] forgot to black small BO change --- py4DSTEM/process/phase/parameter_optimize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/process/phase/parameter_optimize.py b/py4DSTEM/process/phase/parameter_optimize.py index e51d3cd2c..149e8143b 100644 --- a/py4DSTEM/process/phase/parameter_optimize.py +++ b/py4DSTEM/process/phase/parameter_optimize.py @@ -458,7 +458,7 @@ def _split_static_and_optimization_vars(self, argdict): return static_args, optimization_args def _get_scan_positions(self, affine_transform, dataset): - scan_positions = self._init_static_args.pop("initial_scan_positions",None) + scan_positions = self._init_static_args.pop("initial_scan_positions", None) if scan_positions is None: R_pixel_size = dataset.calibration.get_R_pixel_size() x, y = ( From 564999ced7732db821b8d628e466be52c89d684e Mon Sep 17 00:00:00 2001 From: cophus Date: Thu, 18 Jul 2024 15:07:42 -0700 Subject: [PATCH 58/74] testing fix for drawing legend --- py4DSTEM/process/diffraction/crystal_viz.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_viz.py b/py4DSTEM/process/diffraction/crystal_viz.py index 56d4520ee..ab4b259f1 100644 --- a/py4DSTEM/process/diffraction/crystal_viz.py +++ b/py4DSTEM/process/diffraction/crystal_viz.py @@ -1416,8 +1416,8 @@ def plot_orientation_maps( # Triangulate faces p = self.orientation_vecs[:, (1, 0, 2)] tri = mtri.Triangulation( - self.orientation_inds[:, 1] - self.orientation_inds[:, 0] * 1e-3, - self.orientation_inds[:, 0] - self.orientation_inds[:, 1] * 1e-3, + self.orientation_inds[:, 1].astype('float') - self.orientation_inds[:, 0].astype('float') * 1e-3, + self.orientation_inds[:, 0].astype('float') - self.orientation_inds[:, 1].astype('float') * 1e-3, ) # convert rgb values of pixels to faces rgb_faces = ( From fece355625672eb90dd2a8e987fd668d46b91be2 Mon Sep 17 00:00:00 2001 From: cophus Date: Thu, 18 Jul 2024 15:21:35 -0700 Subject: [PATCH 59/74] Revert "testing fix for drawing legend" This reverts commit 564999ced7732db821b8d628e466be52c89d684e. --- py4DSTEM/process/diffraction/crystal_viz.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_viz.py b/py4DSTEM/process/diffraction/crystal_viz.py index ab4b259f1..56d4520ee 100644 --- a/py4DSTEM/process/diffraction/crystal_viz.py +++ b/py4DSTEM/process/diffraction/crystal_viz.py @@ -1416,8 +1416,8 @@ def plot_orientation_maps( # Triangulate faces p = self.orientation_vecs[:, (1, 0, 2)] tri = mtri.Triangulation( - self.orientation_inds[:, 1].astype('float') - self.orientation_inds[:, 0].astype('float') * 1e-3, - self.orientation_inds[:, 0].astype('float') - self.orientation_inds[:, 1].astype('float') * 1e-3, + self.orientation_inds[:, 1] - self.orientation_inds[:, 0] * 1e-3, + self.orientation_inds[:, 0] - self.orientation_inds[:, 1] * 1e-3, ) # convert rgb values of pixels to faces rgb_faces = ( From a70970588fd1322f17feefb52ecff7b0f8458462 Mon Sep 17 00:00:00 2001 From: cophus Date: Thu, 18 Jul 2024 15:38:54 -0700 Subject: [PATCH 60/74] adding option to hide legend --- py4DSTEM/process/diffraction/crystal_viz.py | 437 ++++++++++---------- 1 file changed, 221 insertions(+), 216 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_viz.py b/py4DSTEM/process/diffraction/crystal_viz.py index 56d4520ee..88cc16853 100644 --- a/py4DSTEM/process/diffraction/crystal_viz.py +++ b/py4DSTEM/process/diffraction/crystal_viz.py @@ -538,6 +538,7 @@ def plot_orientation_zones( proj_dir_cartesian: Optional[Union[list, tuple, np.ndarray]] = None, tol_den=10, marker_size: float = 20, + plot_zone_axis_labels: bool = True, plot_limit: Union[list, tuple, np.ndarray] = np.array([-1.1, 1.1]), figsize: Union[list, tuple, np.ndarray] = (8, 8), returnfig: bool = False, @@ -553,6 +554,7 @@ def plot_orientation_zones( dir_proj (float): projection direction, either [elev azim] or normal vector Default is mean vector of self.orientation_zone_axis_range rows marker_size (float): size of markers + plot_zone_axis_labels (bool): plot the zone axis labels plot_limit (float): x y z plot limits, default is [0, 1.05] figsize (2 element float): size scaling of figure axes returnfig (bool): set to True to return figure and axes handles @@ -727,47 +729,47 @@ def plot_orientation_zones( zorder=10, ) - text_scale_pos = 1.2 - text_params = { - "va": "center", - "family": "sans-serif", - "fontweight": "normal", - "color": "k", - "size": 16, - } - # 'ha': 'center', - - ax.text( - self.orientation_vecs[inds[0], 1] * text_scale_pos, - self.orientation_vecs[inds[0], 0] * text_scale_pos, - self.orientation_vecs[inds[0], 2] * text_scale_pos, - label_0, - None, - zorder=11, - ha="center", - **text_params, - ) - if self.orientation_full is False and self.orientation_half is False: - ax.text( - self.orientation_vecs[inds[1], 1] * text_scale_pos, - self.orientation_vecs[inds[1], 0] * text_scale_pos, - self.orientation_vecs[inds[1], 2] * text_scale_pos, - label_1, - None, - zorder=12, - ha="center", - **text_params, - ) + if plot_zone_axis_labels: + text_scale_pos = 1.2 + text_params = { + "va": "center", + "family": "sans-serif", + "fontweight": "normal", + "color": "k", + "size": 16, + } + # 'ha': 'center', ax.text( - self.orientation_vecs[inds[2], 1] * text_scale_pos, - self.orientation_vecs[inds[2], 0] * text_scale_pos, - self.orientation_vecs[inds[2], 2] * text_scale_pos, - label_2, + self.orientation_vecs[inds[0], 1] * text_scale_pos, + self.orientation_vecs[inds[0], 0] * text_scale_pos, + self.orientation_vecs[inds[0], 2] * text_scale_pos, + label_0, None, - zorder=13, + zorder=11, ha="center", **text_params, ) + if self.orientation_full is False and self.orientation_half is False: + ax.text( + self.orientation_vecs[inds[1], 1] * text_scale_pos, + self.orientation_vecs[inds[1], 0] * text_scale_pos, + self.orientation_vecs[inds[1], 2] * text_scale_pos, + label_1, + None, + zorder=12, + ha="center", + **text_params, + ) + ax.text( + self.orientation_vecs[inds[2], 1] * text_scale_pos, + self.orientation_vecs[inds[2], 0] * text_scale_pos, + self.orientation_vecs[inds[2], 2] * text_scale_pos, + label_2, + None, + zorder=13, + ha="center", + **text_params, + ) # ax.scatter( # xs=self.g_vec_all[0,:], @@ -1118,6 +1120,7 @@ def plot_orientation_maps( dir_in_plane_degrees: float = 0.0, corr_range: np.ndarray = np.array([0, 5]), corr_normalize: bool = True, + show_legend: bool = True, scale_legend: bool = None, figsize: Union[list, tuple, np.ndarray] = (16, 5), figbound: Union[list, tuple, np.ndarray] = (0.01, 0.005), @@ -1139,6 +1142,7 @@ def plot_orientation_maps( dir_in_plane_degrees (float): In-plane angle to plot in degrees. Default is 0 / x-axis / vertical down. corr_range (np.ndarray): Correlation intensity range for the plot corr_normalize (bool): If true, set mean correlation to 1. + show_legend (bool): Show the legend scale_legend (float): 2 elements, x and y scaling of legend panel figsize (array): 2 elements defining figure size figbound (array): 2 elements defining figure boundary @@ -1413,189 +1417,190 @@ def plot_orientation_maps( ax_x.axis("off") ax_z.axis("off") - # Triangulate faces - p = self.orientation_vecs[:, (1, 0, 2)] - tri = mtri.Triangulation( - self.orientation_inds[:, 1] - self.orientation_inds[:, 0] * 1e-3, - self.orientation_inds[:, 0] - self.orientation_inds[:, 1] * 1e-3, - ) - # convert rgb values of pixels to faces - rgb_faces = ( - rgb_legend[tri.triangles[:, 0], :] - + rgb_legend[tri.triangles[:, 1], :] - + rgb_legend[tri.triangles[:, 2], :] - ) / 3 - # Add triangulated surface plot to axes - pc = art3d.Poly3DCollection( - p[tri.triangles], - facecolors=rgb_faces, - alpha=1, - ) - pc.set_antialiased(False) - ax_l.add_collection(pc) - - if plot_limit is None: - plot_limit = np.array( - [ - [np.min(p[:, 0]), np.min(p[:, 1]), np.min(p[:, 2])], - [np.max(p[:, 0]), np.max(p[:, 1]), np.max(p[:, 2])], - ] + if show_legend: + # Triangulate faces + p = self.orientation_vecs[:, (1, 0, 2)] + tri = mtri.Triangulation( + self.orientation_inds[:, 1] - self.orientation_inds[:, 0] * 1e-3, + self.orientation_inds[:, 0] - self.orientation_inds[:, 1] * 1e-3, ) - # plot_limit = (plot_limit - np.mean(plot_limit, axis=0)) * 1.5 + np.mean( - # plot_limit, axis=0 - # ) - plot_limit[:, 0] = ( - plot_limit[:, 0] - np.mean(plot_limit[:, 0]) - ) * 1.5 + np.mean(plot_limit[:, 0]) - plot_limit[:, 1] = ( - plot_limit[:, 2] - np.mean(plot_limit[:, 1]) - ) * 1.5 + np.mean(plot_limit[:, 1]) - plot_limit[:, 2] = ( - plot_limit[:, 1] - np.mean(plot_limit[:, 2]) - ) * 1.1 + np.mean(plot_limit[:, 2]) - - # ax_l.view_init(elev=el, azim=az) - # Appearance - ax_l.invert_yaxis() - if swap_axes_xy_limits: - ax_l.axes.set_xlim3d(left=plot_limit[0, 0], right=plot_limit[1, 0]) - ax_l.axes.set_ylim3d(bottom=plot_limit[0, 1], top=plot_limit[1, 1]) - ax_l.axes.set_zlim3d(bottom=plot_limit[0, 2], top=plot_limit[1, 2]) - else: - ax_l.axes.set_xlim3d(left=plot_limit[0, 1], right=plot_limit[1, 1]) - ax_l.axes.set_ylim3d(bottom=plot_limit[0, 0], top=plot_limit[1, 0]) - ax_l.axes.set_zlim3d(bottom=plot_limit[0, 2], top=plot_limit[1, 2]) - axisEqual3D(ax_l) - if camera_dist is not None: - ax_l.dist = camera_dist - ax_l.axis("off") - - # Add text labels - text_scale_pos = 0.1 - text_params = { - "va": "center", - "family": "sans-serif", - "fontweight": "normal", - "color": "k", - "size": 14, - } - format_labels = "{0:.2g}" - vec = self.orientation_vecs[inds_legend[0], :] - cam_dir - vec = vec / np.linalg.norm(vec) - if np.abs(self.cell[5] - 120.0) > 1e-6: - ax_l.text( - self.orientation_vecs[inds_legend[0], 1] + vec[1] * text_scale_pos, - self.orientation_vecs[inds_legend[0], 0] + vec[0] * text_scale_pos, - self.orientation_vecs[inds_legend[0], 2] + vec[2] * text_scale_pos, - "[" - + format_labels.format(label_0[0]) - + " " - + format_labels.format(label_0[1]) - + " " - + format_labels.format(label_0[2]) - + "]", - None, - zorder=11, - ha="center", - **text_params, - ) - else: - ax_l.text( - self.orientation_vecs[inds_legend[0], 1] + vec[1] * text_scale_pos, - self.orientation_vecs[inds_legend[0], 0] + vec[0] * text_scale_pos, - self.orientation_vecs[inds_legend[0], 2] + vec[2] * text_scale_pos, - "[" - + format_labels.format(label_0[0]) - + " " - + format_labels.format(label_0[1]) - + " " - + format_labels.format(label_0[2]) - + " " - + format_labels.format(label_0[3]) - + "]", - None, - zorder=11, - ha="center", - **text_params, - ) - vec = self.orientation_vecs[inds_legend[1], :] - cam_dir - vec = vec / np.linalg.norm(vec) - if np.abs(self.cell[5] - 120.0) > 1e-6: - ax_l.text( - self.orientation_vecs[inds_legend[1], 1] + vec[1] * text_scale_pos, - self.orientation_vecs[inds_legend[1], 0] + vec[0] * text_scale_pos, - self.orientation_vecs[inds_legend[1], 2] + vec[2] * text_scale_pos, - "[" - + format_labels.format(label_1[0]) - + " " - + format_labels.format(label_1[1]) - + " " - + format_labels.format(label_1[2]) - + "]", - None, - zorder=12, - ha=ha_1, - **text_params, - ) - else: - ax_l.text( - self.orientation_vecs[inds_legend[1], 1] + vec[1] * text_scale_pos, - self.orientation_vecs[inds_legend[1], 0] + vec[0] * text_scale_pos, - self.orientation_vecs[inds_legend[1], 2] + vec[2] * text_scale_pos, - "[" - + format_labels.format(label_1[0]) - + " " - + format_labels.format(label_1[1]) - + " " - + format_labels.format(label_1[2]) - + " " - + format_labels.format(label_1[3]) - + "]", - None, - zorder=12, - ha=ha_1, - **text_params, - ) - vec = self.orientation_vecs[inds_legend[2], :] - cam_dir - vec = vec / np.linalg.norm(vec) - if np.abs(self.cell[5] - 120.0) > 1e-6: - ax_l.text( - self.orientation_vecs[inds_legend[2], 1] + vec[1] * text_scale_pos, - self.orientation_vecs[inds_legend[2], 0] + vec[0] * text_scale_pos, - self.orientation_vecs[inds_legend[2], 2] + vec[2] * text_scale_pos, - "[" - + format_labels.format(label_2[0]) - + " " - + format_labels.format(label_2[1]) - + " " - + format_labels.format(label_2[2]) - + "]", - None, - zorder=13, - ha=ha_2, - **text_params, - ) - else: - ax_l.text( - self.orientation_vecs[inds_legend[2], 1] + vec[1] * text_scale_pos, - self.orientation_vecs[inds_legend[2], 0] + vec[0] * text_scale_pos, - self.orientation_vecs[inds_legend[2], 2] + vec[2] * text_scale_pos, - "[" - + format_labels.format(label_2[0]) - + " " - + format_labels.format(label_2[1]) - + " " - + format_labels.format(label_2[2]) - + " " - + format_labels.format(label_2[3]) - + "]", - None, - zorder=13, - ha=ha_2, - **text_params, + # convert rgb values of pixels to faces + rgb_faces = ( + rgb_legend[tri.triangles[:, 0], :] + + rgb_legend[tri.triangles[:, 1], :] + + rgb_legend[tri.triangles[:, 2], :] + ) / 3 + # Add triangulated surface plot to axes + pc = art3d.Poly3DCollection( + p[tri.triangles], + facecolors=rgb_faces, + alpha=1, ) + pc.set_antialiased(False) + ax_l.add_collection(pc) + + if plot_limit is None: + plot_limit = np.array( + [ + [np.min(p[:, 0]), np.min(p[:, 1]), np.min(p[:, 2])], + [np.max(p[:, 0]), np.max(p[:, 1]), np.max(p[:, 2])], + ] + ) + # plot_limit = (plot_limit - np.mean(plot_limit, axis=0)) * 1.5 + np.mean( + # plot_limit, axis=0 + # ) + plot_limit[:, 0] = ( + plot_limit[:, 0] - np.mean(plot_limit[:, 0]) + ) * 1.5 + np.mean(plot_limit[:, 0]) + plot_limit[:, 1] = ( + plot_limit[:, 2] - np.mean(plot_limit[:, 1]) + ) * 1.5 + np.mean(plot_limit[:, 1]) + plot_limit[:, 2] = ( + plot_limit[:, 1] - np.mean(plot_limit[:, 2]) + ) * 1.1 + np.mean(plot_limit[:, 2]) + + # ax_l.view_init(elev=el, azim=az) + # Appearance + ax_l.invert_yaxis() + if swap_axes_xy_limits: + ax_l.axes.set_xlim3d(left=plot_limit[0, 0], right=plot_limit[1, 0]) + ax_l.axes.set_ylim3d(bottom=plot_limit[0, 1], top=plot_limit[1, 1]) + ax_l.axes.set_zlim3d(bottom=plot_limit[0, 2], top=plot_limit[1, 2]) + else: + ax_l.axes.set_xlim3d(left=plot_limit[0, 1], right=plot_limit[1, 1]) + ax_l.axes.set_ylim3d(bottom=plot_limit[0, 0], top=plot_limit[1, 0]) + ax_l.axes.set_zlim3d(bottom=plot_limit[0, 2], top=plot_limit[1, 2]) + axisEqual3D(ax_l) + if camera_dist is not None: + ax_l.dist = camera_dist + ax_l.axis("off") + + # Add text labels + text_scale_pos = 0.1 + text_params = { + "va": "center", + "family": "sans-serif", + "fontweight": "normal", + "color": "k", + "size": 14, + } + format_labels = "{0:.2g}" + vec = self.orientation_vecs[inds_legend[0], :] - cam_dir + vec = vec / np.linalg.norm(vec) + if np.abs(self.cell[5] - 120.0) > 1e-6: + ax_l.text( + self.orientation_vecs[inds_legend[0], 1] + vec[1] * text_scale_pos, + self.orientation_vecs[inds_legend[0], 0] + vec[0] * text_scale_pos, + self.orientation_vecs[inds_legend[0], 2] + vec[2] * text_scale_pos, + "[" + + format_labels.format(label_0[0]) + + " " + + format_labels.format(label_0[1]) + + " " + + format_labels.format(label_0[2]) + + "]", + None, + zorder=11, + ha="center", + **text_params, + ) + else: + ax_l.text( + self.orientation_vecs[inds_legend[0], 1] + vec[1] * text_scale_pos, + self.orientation_vecs[inds_legend[0], 0] + vec[0] * text_scale_pos, + self.orientation_vecs[inds_legend[0], 2] + vec[2] * text_scale_pos, + "[" + + format_labels.format(label_0[0]) + + " " + + format_labels.format(label_0[1]) + + " " + + format_labels.format(label_0[2]) + + " " + + format_labels.format(label_0[3]) + + "]", + None, + zorder=11, + ha="center", + **text_params, + ) + vec = self.orientation_vecs[inds_legend[1], :] - cam_dir + vec = vec / np.linalg.norm(vec) + if np.abs(self.cell[5] - 120.0) > 1e-6: + ax_l.text( + self.orientation_vecs[inds_legend[1], 1] + vec[1] * text_scale_pos, + self.orientation_vecs[inds_legend[1], 0] + vec[0] * text_scale_pos, + self.orientation_vecs[inds_legend[1], 2] + vec[2] * text_scale_pos, + "[" + + format_labels.format(label_1[0]) + + " " + + format_labels.format(label_1[1]) + + " " + + format_labels.format(label_1[2]) + + "]", + None, + zorder=12, + ha=ha_1, + **text_params, + ) + else: + ax_l.text( + self.orientation_vecs[inds_legend[1], 1] + vec[1] * text_scale_pos, + self.orientation_vecs[inds_legend[1], 0] + vec[0] * text_scale_pos, + self.orientation_vecs[inds_legend[1], 2] + vec[2] * text_scale_pos, + "[" + + format_labels.format(label_1[0]) + + " " + + format_labels.format(label_1[1]) + + " " + + format_labels.format(label_1[2]) + + " " + + format_labels.format(label_1[3]) + + "]", + None, + zorder=12, + ha=ha_1, + **text_params, + ) + vec = self.orientation_vecs[inds_legend[2], :] - cam_dir + vec = vec / np.linalg.norm(vec) + if np.abs(self.cell[5] - 120.0) > 1e-6: + ax_l.text( + self.orientation_vecs[inds_legend[2], 1] + vec[1] * text_scale_pos, + self.orientation_vecs[inds_legend[2], 0] + vec[0] * text_scale_pos, + self.orientation_vecs[inds_legend[2], 2] + vec[2] * text_scale_pos, + "[" + + format_labels.format(label_2[0]) + + " " + + format_labels.format(label_2[1]) + + " " + + format_labels.format(label_2[2]) + + "]", + None, + zorder=13, + ha=ha_2, + **text_params, + ) + else: + ax_l.text( + self.orientation_vecs[inds_legend[2], 1] + vec[1] * text_scale_pos, + self.orientation_vecs[inds_legend[2], 0] + vec[0] * text_scale_pos, + self.orientation_vecs[inds_legend[2], 2] + vec[2] * text_scale_pos, + "[" + + format_labels.format(label_2[0]) + + " " + + format_labels.format(label_2[1]) + + " " + + format_labels.format(label_2[2]) + + " " + + format_labels.format(label_2[3]) + + "]", + None, + zorder=13, + ha=ha_2, + **text_params, + ) - plt.show() + plt.show() images_orientation = np.zeros((orientation_map.num_x, orientation_map.num_y, 3, 2)) if self.pymatgen_available: From f7bde15b2dc97b9a619ff572e8bb3f7e6bf5be13 Mon Sep 17 00:00:00 2001 From: cophus Date: Thu, 18 Jul 2024 15:59:38 -0700 Subject: [PATCH 61/74] option to hide legend axes --- py4DSTEM/process/diffraction/crystal_viz.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/py4DSTEM/process/diffraction/crystal_viz.py b/py4DSTEM/process/diffraction/crystal_viz.py index 88cc16853..402f6f964 100644 --- a/py4DSTEM/process/diffraction/crystal_viz.py +++ b/py4DSTEM/process/diffraction/crystal_viz.py @@ -1601,6 +1601,8 @@ def plot_orientation_maps( ) plt.show() + else: + ax_l.set_axis_off() images_orientation = np.zeros((orientation_map.num_x, orientation_map.num_y, 3, 2)) if self.pymatgen_available: From b9e125b26a24a6105d2ed5483de3859bcf31d876 Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Thu, 25 Jul 2024 13:30:18 -0400 Subject: [PATCH 62/74] versions to 0.14.16 --- py4DSTEM/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/version.py b/py4DSTEM/version.py index bbe70be03..a9f73baeb 100644 --- a/py4DSTEM/version.py +++ b/py4DSTEM/version.py @@ -1 +1 @@ -__version__ = "0.14.15" +__version__ = "0.14.16" From 0c6aa4d247a7df4824b80c9cc5c6d948d6e9d925 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 5 Aug 2024 16:32:03 -0700 Subject: [PATCH 63/74] silly BO bug --- py4DSTEM/process/phase/parameter_optimize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/process/phase/parameter_optimize.py b/py4DSTEM/process/phase/parameter_optimize.py index 149e8143b..17101256a 100644 --- a/py4DSTEM/process/phase/parameter_optimize.py +++ b/py4DSTEM/process/phase/parameter_optimize.py @@ -458,7 +458,7 @@ def _split_static_and_optimization_vars(self, argdict): return static_args, optimization_args def _get_scan_positions(self, affine_transform, dataset): - scan_positions = self._init_static_args.pop("initial_scan_positions", None) + scan_positions = self._init_static_args.get("initial_scan_positions", None) if scan_positions is None: R_pixel_size = dataset.calibration.get_R_pixel_size() x, y = ( From 502c25f8ae135fee53377d7b068e8653a17aac90 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Tue, 6 Aug 2024 14:37:28 -0700 Subject: [PATCH 64/74] make TV and entropy measures array dims invariant --- py4DSTEM/process/phase/parameter_optimize.py | 50 +++++++++++++++----- 1 file changed, 38 insertions(+), 12 deletions(-) diff --git a/py4DSTEM/process/phase/parameter_optimize.py b/py4DSTEM/process/phase/parameter_optimize.py index 17101256a..886fd7972 100644 --- a/py4DSTEM/process/phase/parameter_optimize.py +++ b/py4DSTEM/process/phase/parameter_optimize.py @@ -135,7 +135,7 @@ def grid_search( Parameters ---------- - n_initial_points: int + n_points: int Number of uniformly spaced trial points to run on a grid error_metric: Callable or str Function used to compute the reconstruction error. @@ -233,7 +233,7 @@ def evaluation_callback(ptycho): ax.imshow(res[0], cmap=cmap) title_substrings = [ - f"{param.name}: {val}" + f"{param.name}: {val:.3e}" for param, val in zip(self._parameter_list, params) ] title_substrings.append(f"error: {res[1]:.3e}") @@ -485,8 +485,10 @@ def _get_error_metric(self, error_metric: Union[Callable, str]) -> Callable: "log-converged", "linear-converged", "TV", + "TV-phase", "std", "std-phase", + "entropy", "entropy-phase", ), f"Error metric {error_metric} not recognized." @@ -519,10 +521,20 @@ def f(ptycho): elif error_metric == "TV": def f(ptycho): - gx, gy = np.gradient(ptycho.object_cropped, axis=(-2, -1)) - obj_mag = np.sum(np.abs(ptycho.object_cropped)) + array = np.abs(ptycho.object_cropped) + gx = array[..., 1:, :] - array[..., -1:, :] + gy = array[..., :, 1:] - array[..., :, -1:] + tv = np.sum(np.abs(gx)) + np.sum(np.abs(gy)) + return tv / array.size + + elif error_metric == "TV-phase": + + def f(ptycho): + array = np.angle(ptycho.object_cropped) + gx = array[..., 1:, :] - array[..., -1:, :] + gy = array[..., :, 1:] - array[..., :, -1:] tv = np.sum(np.abs(gx)) + np.sum(np.abs(gy)) - return tv / obj_mag + return tv / array.size elif error_metric == "std": @@ -534,16 +546,30 @@ def f(ptycho): def f(ptycho): return -np.std(np.angle(ptycho.object_cropped)) + elif error_metric == "entropy": + + def f(ptycho): + array = np.abs(ptycho.object_cropped) + normalized_array = (array - np.min(array)) / np.ptp(array) + # gx = normalized_array[..., 1:, :] - normalized_array[..., -1:, :] + # gy = normalized_array[..., :, 1:] - normalized_array[..., :, -1:] + gx, gy = np.gradient(normalized_array, axis=(-2, -1)) + ghist, _, _ = np.histogram2d(gx.ravel(), gy.ravel(), bins=array.shape) + ghist = ghist[ghist > 0] / array.size + S = np.sum(ghist * np.log2(ghist)) + return S + elif error_metric == "entropy-phase": def f(ptycho): - obj = np.angle(ptycho.object_cropped) - gx, gy = np.gradient(obj) - ghist, _, _ = np.histogram2d( - gx.ravel(), gy.ravel(), bins=obj.shape, density=True - ) - nz = ghist > 0 - S = np.sum(ghist[nz] * np.log2(ghist[nz])) + array = np.angle(ptycho.object_cropped) + normalized_array = (array - np.min(array)) / np.ptp(array) + # gx = normalized_array[..., 1:, :] - normalized_array[..., -1:, :] + # gy = normalized_array[..., :, 1:] - normalized_array[..., :, -1:] + gx, gy = np.gradient(normalized_array, axis=(-2, -1)) + ghist, _, _ = np.histogram2d(gx.ravel(), gy.ravel(), bins=array.shape) + ghist = ghist[ghist > 0] / array.size + S = np.sum(ghist * np.log2(ghist)) return S else: From dcfa44ab04719ae5113f1cb8e64a2ffd380ed8ec Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Tue, 6 Aug 2024 16:47:30 -0700 Subject: [PATCH 65/74] in place datacube mod --- .../magnetic_ptychographic_tomography.py | 6 ++ .../process/phase/magnetic_ptychography.py | 6 ++ .../mixedstate_multislice_ptychography.py | 10 ++- .../process/phase/mixedstate_ptychography.py | 10 ++- .../process/phase/multislice_ptychography.py | 10 ++- py4DSTEM/process/phase/phase_base_class.py | 74 +++++++++---------- .../process/phase/ptychographic_methods.py | 3 +- .../process/phase/ptychographic_tomography.py | 5 ++ .../process/phase/singleslice_ptychography.py | 10 ++- .../phase/xray_magnetic_ptychography.py | 5 ++ 10 files changed, 94 insertions(+), 45 deletions(-) diff --git a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py index be6332c74..8265c1325 100644 --- a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py +++ b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py @@ -225,6 +225,7 @@ def preprocess( padded_diffraction_intensities_shape: Tuple[int, int] = None, region_of_interest_shape: Tuple[int, int] = None, dp_mask: np.ndarray = None, + in_place_datacube_modification: bool = False, fit_function: str = "plane", plot_probe_overlaps: bool = True, rotation_real_space_degrees: float = None, @@ -266,6 +267,9 @@ def preprocess( at the diffraction plane to allow comparison with experimental data dp_mask: ndarray, optional Mask for datacube intensities (Qx,Qy) + in_place_datacube_modification: bool, optional + If True, the datacube will be preprocessed in-place. Note this is not possible + when either crop_patterns or positions_mask are used. fit_function: str, optional 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' plot_probe_overlaps: bool, optional @@ -479,12 +483,14 @@ def preprocess( amplitudes, mean_diffraction_intensity_temp, self._crop_mask, + self._crop_mask_shape, ) = self._normalize_diffraction_intensities( intensities, com_fitted_x, com_fitted_y, self._positions_mask[index], crop_patterns, + in_place_datacube_modification, ) self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp) diff --git a/py4DSTEM/process/phase/magnetic_ptychography.py b/py4DSTEM/process/phase/magnetic_ptychography.py index aa89d09b3..ff9b2feea 100644 --- a/py4DSTEM/process/phase/magnetic_ptychography.py +++ b/py4DSTEM/process/phase/magnetic_ptychography.py @@ -208,6 +208,7 @@ def preprocess( padded_diffraction_intensities_shape: Tuple[int, int] = None, region_of_interest_shape: Tuple[int, int] = None, dp_mask: np.ndarray = None, + in_place_datacube_modification: bool = False, fit_function: str = "plane", plot_rotation: bool = True, maximize_divergence: bool = False, @@ -259,6 +260,9 @@ def preprocess( at the diffraction plane to allow comparison with experimental data dp_mask: ndarray, optional Mask for datacube intensities (Qx,Qy) + in_place_datacube_modification: bool, optional + If True, the datacube will be preprocessed in-place. Note this is not possible + when either crop_patterns or positions_mask are used. fit_function: str, optional 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' plot_rotation: bool, optional @@ -551,12 +555,14 @@ def preprocess( amplitudes, mean_diffraction_intensity_temp, self._crop_mask, + self._crop_mask_shape, ) = self._normalize_diffraction_intensities( intensities, com_fitted_x, com_fitted_y, self._positions_mask[index], crop_patterns, + in_place_datacube_modification, ) self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp) diff --git a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py index d82a37eb4..3bacf1870 100644 --- a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py @@ -267,6 +267,7 @@ def preprocess( padded_diffraction_intensities_shape: Tuple[int, int] = None, region_of_interest_shape: Tuple[int, int] = None, dp_mask: np.ndarray = None, + in_place_datacube_modification: bool = False, fit_function: str = "plane", plot_center_of_mass: str = "default", plot_rotation: bool = True, @@ -318,6 +319,9 @@ def preprocess( at the diffraction plane to allow comparison with experimental data dp_mask: ndarray, optional Mask for datacube intensities (Qx,Qy) + in_place_datacube_modification: bool, optional + If True, the datacube will be preprocessed in-place. Note this is not possible + when either crop_patterns or positions_mask are used. fit_function: str, optional 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' plot_center_of_mass: str, optional @@ -486,17 +490,21 @@ def preprocess( self._amplitudes, self._mean_diffraction_intensity, self._crop_mask, + self._crop_mask_shape, ) = self._normalize_diffraction_intensities( _intensities, self._com_fitted_x, self._com_fitted_y, self._positions_mask, crop_patterns, + in_place_datacube_modification, ) # explicitly transfer arrays to storage + if not in_place_datacube_modification: + del _intensities + self._amplitudes = copy_to_device(self._amplitudes, storage) - del _intensities self._num_diffraction_patterns = self._amplitudes.shape[0] self._amplitudes_shape = np.array(self._amplitudes.shape[-2:]) diff --git a/py4DSTEM/process/phase/mixedstate_ptychography.py b/py4DSTEM/process/phase/mixedstate_ptychography.py index 7bbadf114..9b12d09e0 100644 --- a/py4DSTEM/process/phase/mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_ptychography.py @@ -213,6 +213,7 @@ def preprocess( padded_diffraction_intensities_shape: Tuple[int, int] = None, region_of_interest_shape: Tuple[int, int] = None, dp_mask: np.ndarray = None, + in_place_datacube_modification: bool = False, fit_function: str = "plane", plot_center_of_mass: str = "default", plot_rotation: bool = True, @@ -264,6 +265,9 @@ def preprocess( at the diffraction plane to allow comparison with experimental data dp_mask: ndarray, optional Mask for datacube intensities (Qx,Qy) + in_place_datacube_modification: bool, optional + If True, the datacube will be preprocessed in-place. Note this is not possible + when either crop_patterns or positions_mask are used. fit_function: str, optional 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' plot_center_of_mass: str, optional @@ -432,17 +436,21 @@ def preprocess( self._amplitudes, self._mean_diffraction_intensity, self._crop_mask, + self._crop_mask_shape, ) = self._normalize_diffraction_intensities( _intensities, self._com_fitted_x, self._com_fitted_y, self._positions_mask, crop_patterns, + in_place_datacube_modification, ) # explicitly transfer arrays to storage + if not in_place_datacube_modification: + del _intensities + self._amplitudes = copy_to_device(self._amplitudes, storage) - del _intensities self._num_diffraction_patterns = self._amplitudes.shape[0] self._amplitudes_shape = np.array(self._amplitudes.shape[-2:]) diff --git a/py4DSTEM/process/phase/multislice_ptychography.py b/py4DSTEM/process/phase/multislice_ptychography.py index 87e3c1fe4..65a347b83 100644 --- a/py4DSTEM/process/phase/multislice_ptychography.py +++ b/py4DSTEM/process/phase/multislice_ptychography.py @@ -241,6 +241,7 @@ def preprocess( padded_diffraction_intensities_shape: Tuple[int, int] = None, region_of_interest_shape: Tuple[int, int] = None, dp_mask: np.ndarray = None, + in_place_datacube_modification: bool = False, fit_function: str = "plane", plot_center_of_mass: str = "default", plot_rotation: bool = True, @@ -292,6 +293,9 @@ def preprocess( at the diffraction plane to allow comparison with experimental data dp_mask: ndarray, optional Mask for datacube intensities (Qx,Qy) + in_place_datacube_modification: bool, optional + If True, the datacube will be preprocessed in-place. Note this is not possible + when either crop_patterns or positions_mask are used. fit_function: str, optional 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' plot_center_of_mass: str, optional @@ -460,17 +464,21 @@ def preprocess( self._amplitudes, self._mean_diffraction_intensity, self._crop_mask, + self._crop_mask_shape, ) = self._normalize_diffraction_intensities( _intensities, self._com_fitted_x, self._com_fitted_y, self._positions_mask, crop_patterns, + in_place_datacube_modification, ) # explicitly transfer arrays to storage + if not in_place_datacube_modification: + del _intensities + self._amplitudes = copy_to_device(self._amplitudes, storage) - del _intensities self._num_diffraction_patterns = self._amplitudes.shape[0] self._amplitudes_shape = np.array(self._amplitudes.shape[-2:]) diff --git a/py4DSTEM/process/phase/phase_base_class.py b/py4DSTEM/process/phase/phase_base_class.py index 27476cb43..d89ffd086 100644 --- a/py4DSTEM/process/phase/phase_base_class.py +++ b/py4DSTEM/process/phase/phase_base_class.py @@ -1351,6 +1351,7 @@ def _normalize_diffraction_intensities( com_fitted_y, positions_mask, crop_patterns, + in_place_datacube_modification, ): """ Fix diffraction intensities CoM, shift to origin, and take square root @@ -1363,78 +1364,67 @@ def _normalize_diffraction_intensities( Best fit horizontal center of mass gradient com_fitted_y: (Rx,Ry) xp.ndarray Best fit vertical center of mass gradient - positions_mask: np.ndarray, optional + positions_mask: np.ndarray Boolean real space mask to select positions in datacube to skip for reconstruction crop_patterns: bool - if True, crop patterns to avoid wrap around of patterns - when centering + If True, patterns are cropped to avoid wrap around of patterns + in_place_datacube_modification: bool + If True, the diffraction intensities are modified in-place Returns ------- - amplitudes: (Rx * Ry, Sx, Sy) np.ndarray + diffraction_intensities: (Rx * Ry, Sx, Sy) np.ndarray Flat array of normalized diffraction amplitudes mean_intensity: float Mean intensity value + crop_mask + Mask to crop diffraction patterns with """ # explicit read-only self attributes up-front asnumpy = self._asnumpy mean_intensity = 0 - - diffraction_intensities = asnumpy(diffraction_intensities) com_fitted_x = asnumpy(com_fitted_x) com_fitted_y = asnumpy(com_fitted_y) - if positions_mask is not None: - number_of_patterns = np.count_nonzero(positions_mask.ravel()) + if in_place_datacube_modification: + diff_intensities = diffraction_intensities else: - number_of_patterns = np.prod(diffraction_intensities.shape[:2]) + diff_intensities = diffraction_intensities.copy() # Aggressive cropping for when off-centered high scattering angle data was recorded if crop_patterns: crop_x = int( np.minimum( - diffraction_intensities.shape[2] - com_fitted_x.max(), + diff_intensities.shape[2] - com_fitted_x.max(), com_fitted_x.min(), ) ) crop_y = int( np.minimum( - diffraction_intensities.shape[3] - com_fitted_y.max(), + diff_intensities.shape[3] - com_fitted_y.max(), com_fitted_y.min(), ) ) crop_w = np.minimum(crop_y, crop_x) - diffraction_intensities_shape_crop = (crop_w * 2, crop_w * 2) - amplitudes = np.zeros( - ( - number_of_patterns, - crop_w * 2, - crop_w * 2, - ), - dtype=np.float32, - ) - crop_mask = np.zeros(diffraction_intensities.shape[-2:], dtype=np.bool_) + crop_mask = np.zeros(diff_intensities.shape[-2:], dtype="bool") crop_mask[:crop_w, :crop_w] = True crop_mask[-crop_w:, :crop_w] = True crop_mask[:crop_w:, -crop_w:] = True crop_mask[-crop_w:, -crop_w:] = True + crop_mask_shape = (crop_w * 2, crop_w * 2) + else: crop_mask = None - diffraction_intensities_shape_crop = diffraction_intensities.shape[-2:] - amplitudes = np.zeros( - (number_of_patterns,) + diffraction_intensities_shape_crop, - dtype=np.float32, - ) + crop_mask_shape = diff_intensities.shape[-2:] - counter = 0 for rx, ry in tqdmnd( - diffraction_intensities.shape[0], - diffraction_intensities.shape[1], + diff_intensities.shape[0], + diff_intensities.shape[1], desc="Normalizing amplitudes", unit="probe position", disable=not self._verbose, @@ -1442,28 +1432,32 @@ def _normalize_diffraction_intensities( if positions_mask is not None: if not positions_mask[rx, ry]: continue + intensities = get_shifted_ar( - diffraction_intensities[rx, ry], + diff_intensities[rx, ry], -com_fitted_x[rx, ry], -com_fitted_y[rx, ry], bilinear=True, device="cpu", ) - if crop_patterns: - intensities = intensities[crop_mask].reshape( - diffraction_intensities_shape_crop - ) - mean_intensity += np.sum(intensities) - amplitudes[counter] = np.sqrt(np.maximum(intensities, 0)) - counter += 1 + diff_intensities[rx, ry] = np.sqrt(np.maximum(intensities, 0)) - mean_intensity /= amplitudes.shape[0] + if positions_mask is not None: + diff_intensities = diff_intensities[positions_mask] + else: + qx, qy = diff_intensities.shape[-2:] + diff_intensities = diff_intensities.reshape((-1, qx, qy)) + + if crop_patterns: + diff_intensities = diff_intensities[:, crop_mask].reshape( + (-1,) + crop_mask_shape + ) - self._diffraction_intensities_shape_crop = diffraction_intensities_shape_crop + mean_intensity /= diff_intensities.shape[0] - return amplitudes, mean_intensity, crop_mask + return diff_intensities, mean_intensity, crop_mask, crop_mask_shape def show_complex_CoM( self, diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py index b9eae9385..2e47a5e23 100644 --- a/py4DSTEM/process/phase/ptychographic_methods.py +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -1022,6 +1022,7 @@ def _initialize_probe( device = self._device crop_mask = self._crop_mask + crop_mask_shape = self._crop_mask_shape region_of_interest_shape = self._region_of_interest_shape sampling = self.sampling energy = self._energy @@ -1049,7 +1050,7 @@ def _initialize_probe( if crop_patterns: vacuum_probe_intensity = vacuum_probe_intensity[crop_mask].reshape( - self._diffraction_intensities_shape_crop + crop_mask_shape ) sx, sy = vacuum_probe_intensity.shape diff --git a/py4DSTEM/process/phase/ptychographic_tomography.py b/py4DSTEM/process/phase/ptychographic_tomography.py index 3639096dc..037ef4849 100644 --- a/py4DSTEM/process/phase/ptychographic_tomography.py +++ b/py4DSTEM/process/phase/ptychographic_tomography.py @@ -219,6 +219,7 @@ def preprocess( padded_diffraction_intensities_shape: Tuple[int, int] = None, region_of_interest_shape: Tuple[int, int] = None, dp_mask: np.ndarray = None, + in_place_datacube_modification: bool = False, fit_function: str = "plane", plot_probe_overlaps: bool = True, rotation_real_space_degrees: float = None, @@ -261,6 +262,9 @@ def preprocess( at the diffraction plane to allow comparison with experimental data dp_mask: ndarray, optional Mask for datacube intensities (Qx,Qy) + in_place_datacube_modification: bool, optional + If True, the datacube will be preprocessed in-place. Note this is not possible + when either crop_patterns or positions_mask are used. fit_function: str, optional 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' plot_probe_overlaps: bool, optional @@ -478,6 +482,7 @@ def preprocess( amplitudes, mean_diffraction_intensity_temp, self._crop_mask, + self._crop_mask_shape, ) = self._normalize_diffraction_intensities( intensities, com_fitted_x, diff --git a/py4DSTEM/process/phase/singleslice_ptychography.py b/py4DSTEM/process/phase/singleslice_ptychography.py index b220ba741..d391dd293 100644 --- a/py4DSTEM/process/phase/singleslice_ptychography.py +++ b/py4DSTEM/process/phase/singleslice_ptychography.py @@ -195,6 +195,7 @@ def preprocess( padded_diffraction_intensities_shape: Tuple[int, int] = None, region_of_interest_shape: Tuple[int, int] = None, dp_mask: np.ndarray = None, + in_place_datacube_modification: bool = False, fit_function: str = "plane", plot_center_of_mass: str = "default", plot_rotation: bool = True, @@ -236,6 +237,9 @@ def preprocess( at the diffraction plane to allow comparison with experimental data dp_mask: ndarray, optional Mask for datacube intensities (Qx,Qy) + in_place_datacube_modification: bool, optional + If True, the datacube will be preprocessed in-place. Note this is not possible + when either crop_patterns or positions_mask are used. fit_function: str, optional 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' plot_center_of_mass: str, optional @@ -405,17 +409,21 @@ def preprocess( self._amplitudes, self._mean_diffraction_intensity, self._crop_mask, + self._crop_mask_shape, ) = self._normalize_diffraction_intensities( _intensities, self._com_fitted_x, self._com_fitted_y, self._positions_mask, crop_patterns, + in_place_datacube_modification, ) # explicitly transfer arrays to storage + if not in_place_datacube_modification: + del _intensities + self._amplitudes = copy_to_device(self._amplitudes, storage) - del _intensities self._num_diffraction_patterns = self._amplitudes.shape[0] self._amplitudes_shape = np.array(self._amplitudes.shape[-2:]) diff --git a/py4DSTEM/process/phase/xray_magnetic_ptychography.py b/py4DSTEM/process/phase/xray_magnetic_ptychography.py index 91c0bbfaa..1d9120dbf 100644 --- a/py4DSTEM/process/phase/xray_magnetic_ptychography.py +++ b/py4DSTEM/process/phase/xray_magnetic_ptychography.py @@ -206,6 +206,7 @@ def preprocess( padded_diffraction_intensities_shape: Tuple[int, int] = None, region_of_interest_shape: Tuple[int, int] = None, dp_mask: np.ndarray = None, + in_place_datacube_modification: bool = False, fit_function: str = "plane", plot_rotation: bool = True, maximize_divergence: bool = False, @@ -257,6 +258,9 @@ def preprocess( at the diffraction plane to allow comparison with experimental data dp_mask: ndarray, optional Mask for datacube intensities (Qx,Qy) + in_place_datacube_modification: bool, optional + If True, the datacube will be preprocessed in-place. Note this is not possible + when either crop_patterns or positions_mask are used. fit_function: str, optional 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' plot_rotation: bool, optional @@ -549,6 +553,7 @@ def preprocess( amplitudes, mean_diffraction_intensity_temp, self._crop_mask, + self._crop_mask_shape, ) = self._normalize_diffraction_intensities( intensities, com_fitted_x, From e8c04ac8b61ef83e5cc77e634472354907a41d99 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Thu, 8 Aug 2024 11:43:47 -0700 Subject: [PATCH 66/74] small bugfixes --- py4DSTEM/process/phase/phase_base_class.py | 1 + py4DSTEM/process/phase/xray_magnetic_ptychography.py | 1 + 2 files changed, 2 insertions(+) diff --git a/py4DSTEM/process/phase/phase_base_class.py b/py4DSTEM/process/phase/phase_base_class.py index d89ffd086..c571dcd3d 100644 --- a/py4DSTEM/process/phase/phase_base_class.py +++ b/py4DSTEM/process/phase/phase_base_class.py @@ -1551,6 +1551,7 @@ def to_h5(self, group): "semiangle_cutoff": self._semiangle_cutoff, "rolloff": self._rolloff, "object_padding_px": self._object_padding_px, + "object_fov_ang": self._object_fov_ang, "object_type": self._object_type, "verbose": self._verbose, "device": self._device, diff --git a/py4DSTEM/process/phase/xray_magnetic_ptychography.py b/py4DSTEM/process/phase/xray_magnetic_ptychography.py index 1d9120dbf..91128743d 100644 --- a/py4DSTEM/process/phase/xray_magnetic_ptychography.py +++ b/py4DSTEM/process/phase/xray_magnetic_ptychography.py @@ -560,6 +560,7 @@ def preprocess( com_fitted_y, self._positions_mask[index], crop_patterns, + in_place_datacube_modification, ) self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp) From 60f74d94b8076af48dc5687188544acc34b862de Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Thu, 8 Aug 2024 20:08:19 -0400 Subject: [PATCH 67/74] versions dev to 0.14.17 --- py4DSTEM/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/version.py b/py4DSTEM/version.py index a9f73baeb..b9decac47 100644 --- a/py4DSTEM/version.py +++ b/py4DSTEM/version.py @@ -1 +1 @@ -__version__ = "0.14.16" +__version__ = "0.14.17" From 39ba85e537e0bd67975af32515eabefa52e2ce92 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 10 Aug 2024 09:07:28 -0700 Subject: [PATCH 68/74] remove equal datacube shapes check for magnetic ptychography --- py4DSTEM/process/phase/magnetic_ptychography.py | 4 ---- py4DSTEM/process/phase/xray_magnetic_ptychography.py | 4 ---- 2 files changed, 8 deletions(-) diff --git a/py4DSTEM/process/phase/magnetic_ptychography.py b/py4DSTEM/process/phase/magnetic_ptychography.py index ff9b2feea..975f6ac84 100644 --- a/py4DSTEM/process/phase/magnetic_ptychography.py +++ b/py4DSTEM/process/phase/magnetic_ptychography.py @@ -377,10 +377,6 @@ def preprocess( f"datacube must be the same length as magnetic_contribution_sign, not length {len(self._datacube)}." ) - dc_shapes = [dc.shape for dc in self._datacube] - if dc_shapes.count(dc_shapes[0]) != self._num_measurements: - raise ValueError("datacube intensities must be the same size.") - if self._positions_mask is not None: self._positions_mask = np.asarray(self._positions_mask, dtype="bool") diff --git a/py4DSTEM/process/phase/xray_magnetic_ptychography.py b/py4DSTEM/process/phase/xray_magnetic_ptychography.py index 91128743d..b1b8a5862 100644 --- a/py4DSTEM/process/phase/xray_magnetic_ptychography.py +++ b/py4DSTEM/process/phase/xray_magnetic_ptychography.py @@ -375,10 +375,6 @@ def preprocess( f"datacube must be the same length as magnetic_contribution_sign, not length {len(self._datacube)}." ) - dc_shapes = [dc.shape for dc in self._datacube] - if dc_shapes.count(dc_shapes[0]) != self._num_measurements: - raise ValueError("datacube intensities must be the same size.") - if self._positions_mask is not None: self._positions_mask = np.asarray(self._positions_mask, dtype="bool") From 65b59825d3c3ea1ce9f40b89db7f094f8751759e Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 10 Aug 2024 12:16:08 -0700 Subject: [PATCH 69/74] parallax improvements: added batch_size, fixed inv_mask, added aligned_bf_guess --- py4DSTEM/process/phase/parallax.py | 157 +++++++++++++++++++---------- 1 file changed, 102 insertions(+), 55 deletions(-) diff --git a/py4DSTEM/process/phase/parallax.py b/py4DSTEM/process/phase/parallax.py index fa86f4dc2..019ab6e6a 100644 --- a/py4DSTEM/process/phase/parallax.py +++ b/py4DSTEM/process/phase/parallax.py @@ -19,6 +19,7 @@ AffineTransform, bilinear_kernel_density_estimate, bilinearly_interpolate_array, + generate_batches, lanczos_interpolate_array, lanczos_kernel_density_estimate, pixel_rolling_kernel_density_estimate, @@ -260,12 +261,14 @@ def preprocess( descan_correction_fit_function: str = None, defocus_guess: float = None, rotation_guess: float = None, + aligned_bf_image_guess: np.ndarray = None, plot_average_bf: bool = True, realspace_mask: np.ndarray = None, apply_realspace_mask_to_stack: bool = True, vectorized_com_calculation: bool = True, device: str = None, clear_fft_cache: bool = None, + max_batch_size: int = None, store_initial_arrays: bool = True, **kwargs, ): @@ -308,6 +311,8 @@ def preprocess( if not none, overwrites self._device to set device preprocess will be perfomed on. clear_fft_cache: bool, optional If True, and device = 'gpu', clears the cached fft plan at the end of function calls + max_batch_size: int, optional + Max number of virtual BF images to use at once in computing cross-correlation store_initial_arrays: bool, optional If True, stores a copy of the arrays necessary to reinitialize in reconstruct @@ -474,7 +479,6 @@ def preprocess( self._stack_BF_unshifted = xp.ones(stack_shape, xp.float32) if normalize_order == 0: - # all_bfs /= xp.mean(all_bfs, axis=(1, 2))[:, None, None] weights = xp.average( all_bfs.reshape((self._num_bf_images, -1)), weights=self._window_edge.ravel(), @@ -517,7 +521,6 @@ def preprocess( weights = np.sqrt(self._window_edge).ravel() for a0 in range(all_bfs.shape[0]): - # coefs = np.linalg.lstsq(basis, all_bfs[a0].ravel(), rcond=None) # weighted least squares coefs = np.linalg.lstsq( weights[:, None] * basis, @@ -574,7 +577,6 @@ def preprocess( weights = np.sqrt(self._window_edge).ravel() for a0 in range(all_bfs.shape[0]): - # coefs = np.linalg.lstsq(basis, all_bfs[a0].ravel(), rcond=None) # weighted least squares coefs = np.linalg.lstsq( weights[:, None] * basis, @@ -645,49 +647,84 @@ def preprocess( # Initialization utilities self._stack_mask = xp.tile(self._window_pad[None], (self._num_bf_images, 1, 1)) + + if max_batch_size is None: + max_batch_size = self._num_bf_images + + self._xy_shifts = xp.zeros((self._num_bf_images, 2), dtype=xp.float32) + if defocus_guess is not None: - Gs = xp.fft.fft2(self._stack_BF_shifted) + for start, end in generate_batches( + self._num_bf_images, max_batch=max_batch_size + ): + shifted_BFs = self._stack_BF_shifted[start:end] + probe_angles = self._probe_angles[start:end] + stack_mask = self._stack_mask[start:end] - self._xy_shifts = ( - -self._probe_angles - * defocus_guess - / xp.array(self._scan_sampling, dtype=xp.float32) - ) + Gs = xp.fft.fft2(shifted_BFs) - if rotation_guess: - angle = xp.deg2rad(rotation_guess) - rotation_matrix = xp.array( - [[np.cos(angle), np.sin(angle)], [-np.sin(angle), np.cos(angle)]], - dtype=xp.float32, + xy_shifts = ( + -probe_angles + * defocus_guess + / xp.array(self._scan_sampling, dtype=xp.float32) ) - self._xy_shifts = xp.dot(self._xy_shifts, rotation_matrix) - dx = self._xy_shifts[:, 0] - dy = self._xy_shifts[:, 1] + if rotation_guess is not None: + angle = xp.deg2rad(rotation_guess) + rotation_matrix = xp.array( + [ + [np.cos(angle), np.sin(angle)], + [-np.sin(angle), np.cos(angle)], + ], + dtype=xp.float32, + ) + xy_shifts = xp.dot(xy_shifts, rotation_matrix) - shift_op = xp.exp( - self._qx_shift[None] * dx[:, None, None] - + self._qy_shift[None] * dy[:, None, None] - ) - self._stack_BF_shifted = xp.real(xp.fft.ifft2(Gs * shift_op)) - self._stack_mask = xp.real( - xp.fft.ifft2(xp.fft.fft2(self._stack_mask) * shift_op) - ) + dx = xy_shifts[:, 0] + dy = xy_shifts[:, 1] - del Gs - else: - self._xy_shifts = xp.zeros((self._num_bf_images, 2), dtype=xp.float32) + shift_op = xp.exp( + self._qx_shift[None] * dx[:, None, None] + + self._qy_shift[None] * dy[:, None, None] + ) + stack_BF_shifted = xp.real(xp.fft.ifft2(Gs * shift_op)) + stack_mask = xp.real(xp.fft.ifft2(xp.fft.fft2(stack_mask) * shift_op)) + + self._xy_shifts[start:end] = xy_shifts + self._stack_BF_shifted[start:end] = stack_BF_shifted + self._stack_mask[start:end] = stack_mask + + del Gs self._stack_mean = xp.mean(self._stack_BF_shifted) self._mask_sum = xp.sum(self._window_edge) * self._num_bf_images - self._recon_mask = xp.sum(self._stack_mask, axis=0) + self._recon_mask = xp.mean(self._stack_mask, axis=0) mask_inv = 1 - xp.clip(self._recon_mask, 0, 1) - self._recon_BF = ( - self._stack_mean * mask_inv - + xp.sum(self._stack_BF_shifted * self._stack_mask, axis=0) - ) / (self._recon_mask + mask_inv) + if aligned_bf_image_guess is not None: + aligned_bf_image_guess = xp.asarray(aligned_bf_image_guess) + if normalize_images: + self._recon_BF = xp.ones(stack_shape[-2:], dtype=xp.float32) + aligned_bf_image_guess /= aligned_bf_image_guess.mean() + else: + self._recon_BF = xp.full(stack_shape[-2:], self._stack_mean) + + self._recon_BF[ + self._object_padding_px[0] // 2 : self._grid_scan_shape[0] + + self._object_padding_px[0] // 2, + self._object_padding_px[1] // 2 : self._grid_scan_shape[1] + + self._object_padding_px[1] // 2, + ] = ( + self._window_inv * self._stack_mean + + self._window_edge * aligned_bf_image_guess + ) + + else: + self._recon_BF = ( + self._stack_mean * mask_inv + + xp.mean(self._stack_BF_shifted * self._stack_mask, axis=0) + ) / (self._recon_mask + mask_inv) self._recon_error = ( xp.atleast_1d( @@ -697,6 +734,7 @@ def preprocess( ) ) / self._mask_sum + / self._stack_mean ) if store_initial_arrays: @@ -754,6 +792,7 @@ def reconstruct( reset: bool = None, device: str = None, clear_fft_cache: bool = None, + max_batch_size: int = None, **kwargs, ): """ @@ -788,6 +827,8 @@ def reconstruct( If True, the reconstruction is reset device: str, optional if not none, overwrites self._device to set device preprocess will be perfomed on. + max_batch_size: int, optional + Max number of virtual BF images to use at once in computing cross-correlation clear_fft_cache: bool, optional if true, and device = 'gpu', clears the cached fft plan at the end of function calls @@ -911,6 +952,9 @@ def reconstruct( xy_center = (self._xy_inds - xp.median(self._xy_inds, axis=0)).astype("float") + if max_batch_size is None: + max_batch_size = self._num_bf_images + # Loop over all binning values for a0 in range(bin_vals.shape[0]): G_ref = xp.fft.fft2(self._recon_BF) @@ -981,31 +1025,33 @@ def reconstruct( shifts_update = xy_shifts_fit - self._xy_shifts # apply shifts - Gs = xp.fft.fft2(self._stack_BF_shifted) + for start, end in generate_batches( + self._num_bf_images, max_batch=max_batch_size + ): + shifted_BFs = self._stack_BF_shifted[start:end] + stack_mask = self._stack_mask[start:end] - dx = shifts_update[:, 0] - dy = shifts_update[:, 1] - self._xy_shifts[:, 0] += dx - self._xy_shifts[:, 1] += dy + Gs = xp.fft.fft2(shifted_BFs) - shift_op = xp.exp( - self._qx_shift[None] * dx[:, None, None] - + self._qy_shift[None] * dy[:, None, None] - ) + dx = shifts_update[start:end, 0] + dy = shifts_update[start:end, 1] - self._stack_BF_shifted = xp.real(xp.fft.ifft2(Gs * shift_op)) - self._stack_mask = xp.real( - xp.fft.ifft2(xp.fft.fft2(self._stack_mask) * shift_op) - ) + shift_op = xp.exp( + self._qx_shift[None] * dx[:, None, None] + + self._qy_shift[None] * dy[:, None, None] + ) - self._stack_BF_shifted = xp.asarray( - self._stack_BF_shifted, dtype=xp.float32 - ) # numpy fft upcasts? - self._stack_mask = xp.asarray( - self._stack_mask, dtype=xp.float32 - ) # numpy fft upcasts? + stack_BF_shifted = xp.real(xp.fft.ifft2(Gs * shift_op)) + stack_mask = xp.real(xp.fft.ifft2(xp.fft.fft2(stack_mask) * shift_op)) + + self._stack_BF_shifted[start:end] = xp.asarray( + stack_BF_shifted, dtype=xp.float32 + ) + self._stack_mask[start:end] = xp.asarray(stack_mask, dtype=xp.float32) + self._xy_shifts[start:end, 0] += dx + self._xy_shifts[start:end, 1] += dy - del Gs + del Gs # Center the shifts xy_shifts_median = xp.round(xp.median(self._xy_shifts, axis=0)).astype(int) @@ -1016,12 +1062,12 @@ def reconstruct( self._stack_mask = xp.roll(self._stack_mask, -xy_shifts_median, axis=(1, 2)) # Generate new estimate - self._recon_mask = xp.sum(self._stack_mask, axis=0) + self._recon_mask = xp.mean(self._stack_mask, axis=0) mask_inv = 1 - np.clip(self._recon_mask, 0, 1) self._recon_BF = ( self._stack_mean * mask_inv - + xp.sum(self._stack_BF_shifted * self._stack_mask, axis=0) + + xp.mean(self._stack_BF_shifted * self._stack_mask, axis=0) ) / (self._recon_mask + mask_inv) self._recon_error = ( @@ -1032,6 +1078,7 @@ def reconstruct( ) ) / self._mask_sum + / self._stack_mean ) self.error_iterations.append(float(self._recon_error)) From cd1cf4413edf48419740118f0b5724e928287250 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 10 Aug 2024 13:54:41 -0700 Subject: [PATCH 70/74] switching to arrows overlay comparison --- py4DSTEM/process/phase/parallax.py | 126 ++++++++++++++--------------- 1 file changed, 59 insertions(+), 67 deletions(-) diff --git a/py4DSTEM/process/phase/parallax.py b/py4DSTEM/process/phase/parallax.py index 019ab6e6a..241748ebc 100644 --- a/py4DSTEM/process/phase/parallax.py +++ b/py4DSTEM/process/phase/parallax.py @@ -2249,19 +2249,6 @@ def score_CTF(coefs): # Plot the measured/fitted shifts comparison if plot_BF_shifts_comparison: - measured_shifts_sx = xp.zeros( - self._region_of_interest_shape, dtype=xp.float32 - ) - measured_shifts_sx[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( - self._xy_shifts_Ang[:, 0] - ) - - measured_shifts_sy = xp.zeros( - self._region_of_interest_shape, dtype=xp.float32 - ) - measured_shifts_sy[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( - self._xy_shifts_Ang[:, 1] - ) fitted_shifts = ( xp.tensordot(gradients, xp.array(self._aberrations_coefs), axes=1) @@ -2269,53 +2256,28 @@ def score_CTF(coefs): .T ) - fitted_shifts_sx = xp.zeros( - self._region_of_interest_shape, dtype=xp.float32 - ) - fitted_shifts_sx[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( - fitted_shifts[:, 0] - ) + scale_arrows = kwargs.pop("scale_arrows", 1) + plot_arrow_freq = kwargs.pop("plot_arrow_freq", 1) + figsize = kwargs.pop("figsize", (4, 4)) - fitted_shifts_sy = xp.zeros( - self._region_of_interest_shape, dtype=xp.float32 - ) - fitted_shifts_sy[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( - fitted_shifts[:, 1] - ) + fig, ax = plt.subplots(figsize=figsize) - max_shift = xp.max( - xp.array( - [ - xp.abs(measured_shifts_sx).max(), - xp.abs(measured_shifts_sy).max(), - xp.abs(fitted_shifts_sx).max(), - xp.abs(fitted_shifts_sy).max(), - ] - ) + self.show_shifts( + shifts_ang=self._xy_shifts_Ang, + plot_rotated_shifts=False, + plot_arrow_freq=plot_arrow_freq, + scale_arrows=scale_arrows, + color=(1, 0, 0, 0.5), + figax=(fig, ax), ) - axsize = kwargs.pop("axsize", (4, 4)) - cmap = kwargs.pop("cmap", "PiYG") - vmin = kwargs.pop("vmin", -max_shift) - vmax = kwargs.pop("vmax", max_shift) - - show( - [ - [asnumpy(measured_shifts_sx), asnumpy(fitted_shifts_sx)], - [asnumpy(measured_shifts_sy), asnumpy(fitted_shifts_sy)], - ], - cmap=cmap, - vmin=vmin, - vmax=vmax, - intensity_range="absolute", - axsize=axsize, - ticks=False, - title=[ - "Measured Vertical Shifts", - "Fitted Vertical Shifts", - "Measured Horizontal Shifts", - "Fitted Horizontal Shifts", - ], + self.show_shifts( + shifts_ang=fitted_shifts, + plot_rotated_shifts=False, + plot_arrow_freq=plot_arrow_freq, + scale_arrows=scale_arrows, + color=(0, 0, 1, 0.5), + figax=(fig, ax), ) # Plot the CTF comparison between experiment and fit @@ -2873,9 +2835,11 @@ def _visualize_figax( def show_shifts( self, + shifts_ang=None, scale_arrows=1, plot_arrow_freq=1, plot_rotated_shifts=True, + figax=None, **kwargs, ): """ @@ -2883,31 +2847,58 @@ def show_shifts( Parameters ---------- + shifts_ang: np.ndarray, optional + If None, self._xy_shifts is used scale_arrows: float, optional Scale to multiply shifts by plot_arrow_freq: int, optional Frequency of shifts to plot in quiver plot + plot_rotated_shifts: bool, optional + If True, shifts are plotted with the relative rotation decomposed + figax: optional + Tuple of figure, axes to plot against """ xp = self._xp asnumpy = self._asnumpy color = kwargs.pop("color", (1, 0, 0, 1)) + + if shifts_ang is None: + shifts_px = self._xy_shifts + else: + shifts_px = shifts_ang / xp.array(self._scan_sampling) + + shifts = shifts_px * scale_arrows * xp.array(self._reciprocal_sampling) + if plot_rotated_shifts and hasattr(self, "rotation_Q_to_R_rads"): - figsize = kwargs.pop("figsize", (8, 4)) - fig, ax = plt.subplots(1, 2, figsize=figsize) - scaling_factor = ( - xp.array(self._reciprocal_sampling) - / xp.array(self._scan_sampling) - * scale_arrows + + if figax is None: + figsize = kwargs.pop("figsize", (8, 4)) + fig, ax = plt.subplots(1, 2, figsize=figsize) + else: + fig, ax = figax + + rotated_color = kwargs.pop("rotated_color", (0, 0, 0, 1)) + + if shifts_ang is None: + rotated_shifts_px = self._xy_shifts.copy() + else: + rotated_shifts_px = shifts_ang / xp.array(self._scan_sampling) + + if self.transpose: + rotated_shifts_px = xp.flip(rotated_shifts_px, axis=1) + + rotated_shifts = ( + rotated_shifts_px * scale_arrows * xp.array(self._reciprocal_sampling) ) - rotated_shifts = self._xy_shifts_Ang * scaling_factor else: - figsize = kwargs.pop("figsize", (4, 4)) - fig, ax = plt.subplots(figsize=figsize) - - shifts = self._xy_shifts * scale_arrows * self._reciprocal_sampling[0] + if figax is None: + figsize = kwargs.pop("figsize", (4, 4)) + fig, ax = plt.subplots(figsize=figsize) + else: + fig, ax = figax dp_mask_ind = xp.nonzero(self._dp_mask) yy, xx = xp.meshgrid( @@ -2950,6 +2941,7 @@ def show_shifts( angles="xy", scale_units="xy", scale=1, + color=rotated_color, **kwargs, ) From e85ac13f5b16d6e2476a580ba51e88c62cd9f8a6 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sun, 11 Aug 2024 12:08:55 -0700 Subject: [PATCH 71/74] adding common aberrations function, easy to widgetify --- py4DSTEM/process/phase/parallax.py | 224 ++++++++++++++++++++--------- py4DSTEM/process/phase/utils.py | 79 ++++++++++ 2 files changed, 232 insertions(+), 71 deletions(-) diff --git a/py4DSTEM/process/phase/parallax.py b/py4DSTEM/process/phase/parallax.py index 241748ebc..20b103e7e 100644 --- a/py4DSTEM/process/phase/parallax.py +++ b/py4DSTEM/process/phase/parallax.py @@ -19,6 +19,7 @@ AffineTransform, bilinear_kernel_density_estimate, bilinearly_interpolate_array, + calculate_aberration_gradient_basis, generate_batches, lanczos_interpolate_array, lanczos_kernel_density_estimate, @@ -26,7 +27,7 @@ ) from py4DSTEM.process.utils.cross_correlate import align_images_fourier from py4DSTEM.process.utils.utils import electron_wavelength_angstrom -from py4DSTEM.visualize import return_scaled_histogram_ordering, show +from py4DSTEM.visualize import return_scaled_histogram_ordering from scipy.linalg import polar from scipy.ndimage import distance_transform_edt from scipy.optimize import minimize @@ -287,7 +288,6 @@ def preprocess( If True, bright images normalized to have a mean of 1 normalize_order: integer, optional Polynomial order for normalization. 0 means constant, 1 means linear, etc. - Higher orders not yet implemented. defocus_guess: float, optional Initial guess of defocus value (defocus dF) in A If None, first iteration is assumed to be in-focus @@ -775,6 +775,142 @@ def preprocess( return self + def guess_common_aberrations_and_rotation( + self, + rotation_angle_deg=0, + defocus=0, + astigmatism=0, + astigmatism_angle_deg=0, + coma=0, + coma_angle_deg=0, + spherical_aberration=0, + max_batch_size=None, + plot_arrow_freq=1, + scale_arrows=1, + **kwargs, + ): + """ """ + xp = self._xp + asnumpy = self._asnumpy + + if not hasattr(self, "_recon_BF"): + raise ValueError( + ( + "Aberration guessing is meant to be ran after preprocessing. " + "Please run the `preprocess()` function first." + ) + ) + + # aberrations_coefs + aberrations_mn = [ + [1, 0, 0], + [1, 2, 0], + [1, 2, 1], + [2, 1, 0], + [2, 1, 1], + [3, 0, 0], + ] + astigmatism_x = astigmatism * np.cos(np.deg2rad(astigmatism_angle_deg) * 2) + astigmatism_y = astigmatism * np.sin(np.deg2rad(astigmatism_angle_deg) * 2) + coma_x = coma * np.cos(np.deg2rad(coma_angle_deg) * 1) + coma_y = coma * np.sin(np.deg2rad(coma_angle_deg) * 1) + aberrations_coefs = xp.array( + [ + -defocus, + astigmatism_x, + astigmatism_y, + coma_x, + coma_y, + spherical_aberration, + ] + ) + + # aberrations_basis + sampling = 1 / ( + np.array(self._reciprocal_sampling) * self._region_of_interest_shape + ) + aberrations_basis, aberrations_basis_du, aberrations_basis_dv = ( + calculate_aberration_gradient_basis( + aberrations_mn, + sampling, + self._region_of_interest_shape, + self._wavelength, + rotation_angle=np.deg2rad(rotation_angle_deg), + xp=xp, + ) + ) + + # shifts + corner_indices = self._xy_inds - xp.array(self._region_of_interest_shape // 2) + raveled_indices = xp.ravel_multi_index( + corner_indices.T, self._region_of_interest_shape, mode="wrap" + ) + gradients = xp.array( + ( + aberrations_basis_du[raveled_indices, :], + aberrations_basis_dv[raveled_indices, :], + ) + ) + shifts_ang = xp.tensordot(gradients, aberrations_coefs, axes=1).T + shifts_px = shifts_ang / xp.array(self._scan_sampling) + + # shifted stack + aligned_stack = xp.zeros_like(self._stack_BF_shifted_initial[0]) + if max_batch_size is None: + max_batch_size = self._num_bf_images + + for start, end in generate_batches( + self._num_bf_images, max_batch=max_batch_size + ): + shifted_BFs = self._stack_BF_shifted_initial[start:end] + + Gs = xp.fft.fft2(shifted_BFs) + + dx = shifts_px[start:end, 0] + dy = shifts_px[start:end, 1] + + shift_op = xp.exp( + self._qx_shift[None] * dx[:, None, None] + + self._qy_shift[None] * dy[:, None, None] + ) + stack_BF_shifted = xp.real(xp.fft.ifft2(Gs * shift_op)) + aligned_stack += stack_BF_shifted.sum(0) + + cropped_stack = asnumpy( + self._crop_padded_object(aligned_stack, upsampled=False) + ) + + figsize = kwargs.pop("figsize", (8, 4)) + color = kwargs.pop("color", (1, 0, 0, 1)) + cmap = kwargs.pop("cmap", "magma") + + fig, axs = plt.subplots(1, 2, figsize=figsize) + + self.show_shifts( + shifts_ang=shifts_ang, + plot_arrow_freq=plot_arrow_freq, + scale_arrows=scale_arrows, + plot_rotated_shifts=False, + color=color, + figax=(fig, axs[0]), + ) + + axs[0].set_title("Predicted BF Shifts") + + extent = [ + 0, + self._scan_sampling[1] * cropped_stack.shape[1], + self._scan_sampling[0] * cropped_stack.shape[0], + 0, + ] + + axs[1].imshow(cropped_stack, cmap=cmap, extent=extent, **kwargs) + axs[1].set_ylabel("x [A]") + axs[1].set_xlabel("y [A]") + axs[1].set_title("Predicted Aligned BF Image") + + fig.tight_layout() + def reconstruct( self, max_alignment_bin: int = None, @@ -2092,75 +2228,21 @@ def calculate_CTF_FFT(alpha_shape, *coefs): # Direct Shifts Fitting if fit_BF_shifts: - # FFT coordinates - sx = 1 / (self._reciprocal_sampling[0] * self._region_of_interest_shape[0]) - sy = 1 / (self._reciprocal_sampling[1] * self._region_of_interest_shape[1]) - qx = xp.fft.fftfreq(self._region_of_interest_shape[0], sx) - qy = xp.fft.fftfreq(self._region_of_interest_shape[1], sy) - qx, qy = np.meshgrid(qx, qy, indexing="ij") - - # passive rotation basis by -theta - rotation_angle = -self.rotation_Q_to_R_rads - qx, qy = qx * np.cos(rotation_angle) + qy * np.sin( - rotation_angle - ), -qx * np.sin(rotation_angle) + qy * np.cos(rotation_angle) - - qr2 = qx**2 + qy**2 - u = qx * self._wavelength - v = qy * self._wavelength - alpha = xp.sqrt(qr2) * self._wavelength - theta = xp.arctan2(qy, qx) - - # Aberration basis - self._aberrations_basis = xp.zeros((alpha.size, self._aberrations_num)) - self._aberrations_basis_du = xp.zeros((alpha.size, self._aberrations_num)) - self._aberrations_basis_dv = xp.zeros((alpha.size, self._aberrations_num)) - for a0 in range(self._aberrations_num): - m, n, a = self._aberrations_mn[a0] - - if n == 0: - # Radially symmetric basis - self._aberrations_basis[:, a0] = ( - alpha ** (m + 1) / (m + 1) - ).ravel() - self._aberrations_basis_du[:, a0] = (u * alpha ** (m - 1)).ravel() - self._aberrations_basis_dv[:, a0] = (v * alpha ** (m - 1)).ravel() - - elif a == 0: - # cos coef - self._aberrations_basis[:, a0] = ( - alpha ** (m + 1) * xp.cos(n * theta) / (m + 1) - ).ravel() - self._aberrations_basis_du[:, a0] = ( - alpha ** (m - 1) - * ((m + 1) * u * xp.cos(n * theta) + n * v * xp.sin(n * theta)) - / (m + 1) - ).ravel() - self._aberrations_basis_dv[:, a0] = ( - alpha ** (m - 1) - * ((m + 1) * v * xp.cos(n * theta) - n * u * xp.sin(n * theta)) - / (m + 1) - ).ravel() - - else: - # sin coef - self._aberrations_basis[:, a0] = ( - alpha ** (m + 1) * xp.sin(n * theta) / (m + 1) - ).ravel() - self._aberrations_basis_du[:, a0] = ( - alpha ** (m - 1) - * ((m + 1) * u * xp.sin(n * theta) - n * v * xp.cos(n * theta)) - / (m + 1) - ).ravel() - self._aberrations_basis_dv[:, a0] = ( - alpha ** (m - 1) - * ((m + 1) * v * xp.sin(n * theta) + n * u * xp.cos(n * theta)) - / (m + 1) - ).ravel() - - # global scaling - self._aberrations_basis *= 2 * np.pi / self._wavelength - self._aberrations_surface_shape = alpha.shape + sampling = 1 / ( + np.array(self._reciprocal_sampling) * self._region_of_interest_shape + ) + ( + self._aberrations_babis, + self._aberrations_basis_du, + self._aberrations_basis_dv, + ) = calculate_aberration_gradient_basis( + self._aberrations_mn, + sampling, + self._region_of_interest_shape, + self._wavelength, + rotation_angle=self.rotation_Q_to_R_rads, + xp=xp, + ) # CTF function def calculate_CTF(alpha_shape, *coefs): diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index f932b40b5..bb960da62 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -1504,6 +1504,85 @@ def step_model(radius, sig_0, rad_0, width): return probe_corr, polar_int, polar_int_corr, coefs_all +def calculate_aberration_gradient_basis( + aberrations_mn, + sampling, + gpts, + wavelength, + rotation_angle=0, + xp=np, +): + """ """ + sx, sy = sampling + nx, ny = gpts + qx = xp.fft.fftfreq(nx, sx) + qy = xp.fft.fftfreq(ny, sy) + qx, qy = xp.meshgrid(qx, qy, indexing="ij") + + # passive rotation + qx, qy = qx * xp.cos(-rotation_angle) + qy * xp.sin(-rotation_angle), -qx * xp.sin( + -rotation_angle + ) + qy * xp.cos(-rotation_angle) + + # coordinate system + qr2 = qx**2 + qy**2 + u = qx * wavelength + v = qy * wavelength + alpha = xp.sqrt(qr2) * wavelength + theta = xp.arctan2(qy, qx) + + _aberrations_n = len(aberrations_mn) + _aberrations_basis = xp.zeros((alpha.size, _aberrations_n)) + _aberrations_basis_du = xp.zeros((alpha.size, _aberrations_n)) + _aberrations_basis_dv = xp.zeros((alpha.size, _aberrations_n)) + + for a0 in range(_aberrations_n): + m, n, a = aberrations_mn[a0] + + if n == 0: + # Radially symmetric basis + _aberrations_basis[:, a0] = (alpha ** (m + 1) / (m + 1)).ravel() + _aberrations_basis_du[:, a0] = (u * alpha ** (m - 1)).ravel() + _aberrations_basis_dv[:, a0] = (v * alpha ** (m - 1)).ravel() + + elif a == 0: + # cos coef + _aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.cos(n * theta) / (m + 1) + ).ravel() + _aberrations_basis_du[:, a0] = ( + alpha ** (m - 1) + * ((m + 1) * u * xp.cos(n * theta) + n * v * xp.sin(n * theta)) + / (m + 1) + ).ravel() + _aberrations_basis_dv[:, a0] = ( + alpha ** (m - 1) + * ((m + 1) * v * xp.cos(n * theta) - n * u * xp.sin(n * theta)) + / (m + 1) + ).ravel() + + else: + # sin coef + _aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.sin(n * theta) / (m + 1) + ).ravel() + _aberrations_basis_du[:, a0] = ( + alpha ** (m - 1) + * ((m + 1) * u * xp.sin(n * theta) - n * v * xp.cos(n * theta)) + / (m + 1) + ).ravel() + _aberrations_basis_dv[:, a0] = ( + alpha ** (m - 1) + * ((m + 1) * v * xp.sin(n * theta) + n * u * xp.cos(n * theta)) + / (m + 1) + ).ravel() + + # global scaling + _aberrations_basis *= 2 * np.pi / wavelength + + return _aberrations_basis, _aberrations_basis_du, _aberrations_basis_dv + + def aberrations_basis_function( probe_size, probe_sampling, From 9288d1ac50419e803350a0a6ae7f17d4e7388c88 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sun, 11 Aug 2024 14:27:06 -0700 Subject: [PATCH 72/74] added transpose and upsampling options --- py4DSTEM/process/phase/parallax.py | 97 +++++++++++++++++++++++------- 1 file changed, 74 insertions(+), 23 deletions(-) diff --git a/py4DSTEM/process/phase/parallax.py b/py4DSTEM/process/phase/parallax.py index 20b103e7e..14a873539 100644 --- a/py4DSTEM/process/phase/parallax.py +++ b/py4DSTEM/process/phase/parallax.py @@ -775,9 +775,14 @@ def preprocess( return self - def guess_common_aberrations_and_rotation( + def guess_common_aberrations( self, rotation_angle_deg=0, + transpose=False, + kde_upsample_factor=None, + kde_sigma_px=0.125, + kde_lowpass_filter=False, + lanczos_interpolation_order=None, defocus=0, astigmatism=0, astigmatism_angle_deg=0, @@ -825,6 +830,10 @@ def guess_common_aberrations_and_rotation( ] ) + # transpose rotation matrix + if transpose: + rotation_angle_deg *= -1 + # aberrations_basis sampling = 1 / ( np.array(self._reciprocal_sampling) * self._region_of_interest_shape @@ -852,33 +861,75 @@ def guess_common_aberrations_and_rotation( ) ) shifts_ang = xp.tensordot(gradients, aberrations_coefs, axes=1).T + + # transpose predicted shifts + if transpose: + shifts_ang = xp.flip(shifts_ang, axis=1) + shifts_px = shifts_ang / xp.array(self._scan_sampling) - # shifted stack - aligned_stack = xp.zeros_like(self._stack_BF_shifted_initial[0]) - if max_batch_size is None: - max_batch_size = self._num_bf_images + # upsampled stack + if kde_upsample_factor is not None: + BF_size = np.array(self._stack_BF_unshifted.shape[-2:]) + pixel_output_shape = np.round(BF_size * kde_upsample_factor).astype("int") - for start, end in generate_batches( - self._num_bf_images, max_batch=max_batch_size - ): - shifted_BFs = self._stack_BF_shifted_initial[start:end] + x = xp.arange(BF_size[0], dtype=xp.float32) + y = xp.arange(BF_size[1], dtype=xp.float32) + xa_init, ya_init = xp.meshgrid(x, y, indexing="ij") - Gs = xp.fft.fft2(shifted_BFs) + # kernel density output the upsampled BF image + xa = (xa_init + shifts_px[:, 0, None, None]) * kde_upsample_factor + ya = (ya_init + shifts_px[:, 1, None, None]) * kde_upsample_factor - dx = shifts_px[start:end, 0] - dy = shifts_px[start:end, 1] + pix_output = self._kernel_density_estimate( + xa, + ya, + self._stack_BF_unshifted, + pixel_output_shape, + kde_sigma_px * kde_upsample_factor, + lanczos_alpha=lanczos_interpolation_order, + lowpass_filter=kde_lowpass_filter, + ) - shift_op = xp.exp( - self._qx_shift[None] * dx[:, None, None] - + self._qy_shift[None] * dy[:, None, None] + # hack since cropping requires "_kde_upsample_factor" + old_upsample_factor = getattr(self, "_kde_upsample_factor", None) + self._kde_upsample_factor = kde_upsample_factor + cropped_image = asnumpy( + self._crop_padded_object(pix_output, upsampled=True) ) - stack_BF_shifted = xp.real(xp.fft.ifft2(Gs * shift_op)) - aligned_stack += stack_BF_shifted.sum(0) + if old_upsample_factor is not None: + self._kde_upsample_factor = old_upsample_factor + else: + del self._kde_upsample_factor - cropped_stack = asnumpy( - self._crop_padded_object(aligned_stack, upsampled=False) - ) + # shifted stack + else: + kde_upsample_factor = 1 + aligned_stack = xp.zeros_like(self._stack_BF_shifted_initial[0]) + + if max_batch_size is None: + max_batch_size = self._num_bf_images + + for start, end in generate_batches( + self._num_bf_images, max_batch=max_batch_size + ): + shifted_BFs = self._stack_BF_shifted_initial[start:end] + + Gs = xp.fft.fft2(shifted_BFs) + + dx = shifts_px[start:end, 0] + dy = shifts_px[start:end, 1] + + shift_op = xp.exp( + self._qx_shift[None] * dx[:, None, None] + + self._qy_shift[None] * dy[:, None, None] + ) + stack_BF_shifted = xp.real(xp.fft.ifft2(Gs * shift_op)) + aligned_stack += stack_BF_shifted.sum(0) + + cropped_image = asnumpy( + self._crop_padded_object(aligned_stack, upsampled=False) + ) figsize = kwargs.pop("figsize", (8, 4)) color = kwargs.pop("color", (1, 0, 0, 1)) @@ -899,12 +950,12 @@ def guess_common_aberrations_and_rotation( extent = [ 0, - self._scan_sampling[1] * cropped_stack.shape[1], - self._scan_sampling[0] * cropped_stack.shape[0], + self._scan_sampling[1] * cropped_image.shape[1] / kde_upsample_factor, + self._scan_sampling[0] * cropped_image.shape[0] / kde_upsample_factor, 0, ] - axs[1].imshow(cropped_stack, cmap=cmap, extent=extent, **kwargs) + axs[1].imshow(cropped_image, cmap=cmap, extent=extent, **kwargs) axs[1].set_ylabel("x [A]") axs[1].set_xlabel("y [A]") axs[1].set_title("Predicted Aligned BF Image") From e9fff8d677c48ad9c880f8db1982638a992f0b30 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sun, 11 Aug 2024 14:43:48 -0700 Subject: [PATCH 73/74] added docstrings --- py4DSTEM/process/phase/parallax.py | 48 +++++++++++++++++++++++++++--- 1 file changed, 44 insertions(+), 4 deletions(-) diff --git a/py4DSTEM/process/phase/parallax.py b/py4DSTEM/process/phase/parallax.py index 14a873539..9c22dff0c 100644 --- a/py4DSTEM/process/phase/parallax.py +++ b/py4DSTEM/process/phase/parallax.py @@ -291,12 +291,15 @@ def preprocess( defocus_guess: float, optional Initial guess of defocus value (defocus dF) in A If None, first iteration is assumed to be in-focus + aligned_bf_image_guess: np.ndarray, optional + Guess for the reference BF image to cross-correlate against during the first iteration + If None, the incoherent BF image is used instead. + rotation_guess: float, optional + Initial guess of rotation value in degrees + If None, first iteration assumed to be 0 descan_correction_fit_function: str, optional If not None, descan correction will be performed using fit function. One of "constant", "plane", "parabola", or "bezier_two". - rotation_guess: float, optional - Initial guess of defocus value in degrees - If None, first iteration assumed to be 0 plot_average_bf: bool, optional If True, plots the average bright field image, using defocus_guess realspace_mask: np.array, optional @@ -794,7 +797,44 @@ def guess_common_aberrations( scale_arrows=1, **kwargs, ): - """ """ + """ + Generates analytical BF shifts and uses them to align the virtual BF stack, + based on the experimental geometry (rotation, transpose), and common aberrations. + + Parameters + ---------- + rotation_angle_deg: float, optional + Relative rotation between the scan and the diffraction space coordinate systems + transpose: bool, optional + Whether the diffraction intensities are transposed + kde_upsample_factor: int, optional + Real-space upsampling factor + kde_sigma_px: float, optional + KDE gaussian kernel bandwidth in non-upsampled pixels + kde_lowpass_filter: bool, optional + If True, the resulting KDE upsampled image is lowpass-filtered using a sinc-function + lanczos_interpolation_order: int, optional + If not None, Lanczos interpolation with the specified order is used instead of bilinear + defocus: float, optional + Defocus value to use in computing analytical BF shifts + astigmatism: float, optional + Astigmatism value to use in computing analytical BF shifts + astigmatism_angle_deg: float, optional + Astigmatism angle to use in computing analytical BF shifts + coma: float, optional + Coma value to use in computing analytical BF shifts + coma_angle_deg: float, optional + Coma angle to use in computing analytical BF shifts + spherical_aberration: float, optional + Spherical aberration value to use in computing analytical BF shifts + max_batch_size: int, optional + Max number of virtual BF images to use at once in computing cross-correlation + plot_arrow_freq: int, optional + Frequency of shifts to plot in quiver plot + scale_arrows: float, optional + Scale to multiply shifts by + + """ xp = self._xp asnumpy = self._asnumpy From ac9e9bf732420f25f26cb61879bb8cf757a23a80 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 12 Aug 2024 16:48:07 -0700 Subject: [PATCH 74/74] adding plot, return flags --- py4DSTEM/process/phase/parallax.py | 58 +++++++++++++++++------------- 1 file changed, 34 insertions(+), 24 deletions(-) diff --git a/py4DSTEM/process/phase/parallax.py b/py4DSTEM/process/phase/parallax.py index 9c22dff0c..e5768f3cc 100644 --- a/py4DSTEM/process/phase/parallax.py +++ b/py4DSTEM/process/phase/parallax.py @@ -793,6 +793,8 @@ def guess_common_aberrations( coma_angle_deg=0, spherical_aberration=0, max_batch_size=None, + plot_shifts_and_aligned_bf=True, + return_shifts_and_aligned_bf=False, plot_arrow_freq=1, scale_arrows=1, **kwargs, @@ -829,6 +831,10 @@ def guess_common_aberrations( Spherical aberration value to use in computing analytical BF shifts max_batch_size: int, optional Max number of virtual BF images to use at once in computing cross-correlation + plot_shifts_and_aligned_bf: bool, optional + If True, the analytical shifts and the aligned virtual VF image are plotted + return_shifts_and_aligned_bf: bool, optional + If True, the analytical shifts and the aligned virtual VF image are returned plot_arrow_freq: int, optional Frequency of shifts to plot in quiver plot scale_arrows: float, optional @@ -971,36 +977,40 @@ def guess_common_aberrations( self._crop_padded_object(aligned_stack, upsampled=False) ) - figsize = kwargs.pop("figsize", (8, 4)) - color = kwargs.pop("color", (1, 0, 0, 1)) - cmap = kwargs.pop("cmap", "magma") + if plot_shifts_and_aligned_bf: + figsize = kwargs.pop("figsize", (8, 4)) + color = kwargs.pop("color", (1, 0, 0, 1)) + cmap = kwargs.pop("cmap", "magma") - fig, axs = plt.subplots(1, 2, figsize=figsize) + fig, axs = plt.subplots(1, 2, figsize=figsize) - self.show_shifts( - shifts_ang=shifts_ang, - plot_arrow_freq=plot_arrow_freq, - scale_arrows=scale_arrows, - plot_rotated_shifts=False, - color=color, - figax=(fig, axs[0]), - ) + self.show_shifts( + shifts_ang=shifts_ang, + plot_arrow_freq=plot_arrow_freq, + scale_arrows=scale_arrows, + plot_rotated_shifts=False, + color=color, + figax=(fig, axs[0]), + ) - axs[0].set_title("Predicted BF Shifts") + axs[0].set_title("Predicted BF Shifts") - extent = [ - 0, - self._scan_sampling[1] * cropped_image.shape[1] / kde_upsample_factor, - self._scan_sampling[0] * cropped_image.shape[0] / kde_upsample_factor, - 0, - ] + extent = [ + 0, + self._scan_sampling[1] * cropped_image.shape[1] / kde_upsample_factor, + self._scan_sampling[0] * cropped_image.shape[0] / kde_upsample_factor, + 0, + ] - axs[1].imshow(cropped_image, cmap=cmap, extent=extent, **kwargs) - axs[1].set_ylabel("x [A]") - axs[1].set_xlabel("y [A]") - axs[1].set_title("Predicted Aligned BF Image") + axs[1].imshow(cropped_image, cmap=cmap, extent=extent, **kwargs) + axs[1].set_ylabel("x [A]") + axs[1].set_xlabel("y [A]") + axs[1].set_title("Predicted Aligned BF Image") - fig.tight_layout() + fig.tight_layout() + + if return_shifts_and_aligned_bf: + return shifts_ang, cropped_image def reconstruct( self,