From a15b5d949e46dde97a2ce79e5313db1e8a98284a Mon Sep 17 00:00:00 2001 From: cophus Date: Tue, 14 Nov 2023 10:15:09 -0800 Subject: [PATCH 01/64] initial commit for projected potential --- py4DSTEM/process/diffraction/crystal.py | 115 ++++++++++++++++++++++++ 1 file changed, 115 insertions(+) diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index fb2911992..2d666c3c3 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -876,6 +876,121 @@ def generate_ring_pattern( if return_calc is True: return radii_unique, intensity_unique + + + + def generate_projected_potential( + self, + im_size = (256,256), + pixel_size_Ang = 0.1, + orientation: Optional[Orientation] = None, + ind_orientation: Optional[int] = 0, + orientation_matrix: Optional[np.ndarray] = None, + zone_axis_lattice: Optional[np.ndarray] = None, + proj_x_lattice: Optional[np.ndarray] = None, + zone_axis_cartesian: Optional[np.ndarray] = None, + proj_x_cartesian: Optional[np.ndarray] = None, + ): + """ + Generate a single diffraction pattern, return all peaks as a pointlist. + + Parameters + ---------- + edge_blend: int, optional + Pixels to blend image at the border + + Returns + -------- + orientation (Orientation): an Orientation class object + + Returns: + im_potential: (np.array) + + """ + + # Determine image size in Angstroms + im_size = np.array(im_size) + im_size_Ang = im_size * pixel_size_Ang + + # Parse orientation inputs + if orientation is not None: + if ind_orientation is None: + orientation_matrix = orientation.matrix[0] + else: + orientation_matrix = orientation.matrix[ind_orientation] + elif orientation_matrix is None: + orientation_matrix = self.parse_orientation( + zone_axis_lattice, proj_x_lattice, zone_axis_cartesian, proj_x_cartesian + ) + # projection directions of potential image + proj_x = orientation_matrix[:,0] \ + / np.linalg.norm(orientation_matrix[:,0]) + proj_y = orientation_matrix[:,1] \ + / np.linalg.norm(orientation_matrix[:,1]) + proj_z = orientation_matrix[:,2] \ + / np.linalg.norm(orientation_matrix[:,2]) + + # Determine unit cell axes to tile over + uvw = self.lat_real / \ + np.linalg.norm(self.lat_real, axis = 1) + test = np.abs(uvw @ proj_z) + inds_tile = np.argsort(test)[:2] + m_tile = self.lat_real[inds_tile,:] + + # Determine tiling range + p = np.array([ + [-im_size_Ang[0]*0.5,-im_size_Ang[1]*0.5, 0.0], + [ im_size_Ang[0]*0.5,-im_size_Ang[1]*0.5, 0.0], + [ im_size_Ang[0]*0.5, im_size_Ang[1]*0.5, 0.0], + [-im_size_Ang[0]*0.5, im_size_Ang[1]*0.5, 0.0], + ]) + + ab = np.floor(np.linalg.lstsq( + m_tile.T, + p.T, + rcond=None)[0]) + a_range = np.array((np.min(ab[0]),np.max(ab[0]))) + b_range = np.array((np.min(ab[1]),np.max(ab[1]))) + + # Tile unit cell + a_ind, b_ind, atoms_ind = np.meshgrid( + np.arange(a_range[0],a_range[1]), + np.arange(b_range[0],b_range[1]), + np.arange(self.positions.shape[0]), + ) + abc_atoms = self.positions[atoms_ind.ravel(),:] + abc_atoms[:,inds_tile[0]] += a_ind.ravel() + abc_atoms[:,inds_tile[1]] += b_ind.ravel() + xyz_atoms_ang = abc_atoms @ self.lat_real.T + # xyz_atoms_pixels = xyz_atoms_ang / pixel_size_Ang \ + # + im_size/2.0 + + # Project into projected potential image plane + x = proj_x + + # # Lookup table for atomic projected potentials + + # # initialize + im_potential = np.zeros(im_size) + # # im_potential[10:20,10:20] = 1 + + + # Add atoms to potential + + + # test plotting + fig,ax = plt.subplots(figsize = (6,6)) + ax.imshow( + im_potential, + cmap = 'gray', + ) + ax.set_axis_off() + + return im_potential + + + + # Vector conversions and other utilities for Crystal classes def cartesian_to_lattice(self, vec_cartesian): vec_lattice = self.lat_inv @ vec_cartesian From 5b9080d55a11fde3a353979042d48d12c3ff4242 Mon Sep 17 00:00:00 2001 From: cophus Date: Tue, 14 Nov 2023 10:32:33 -0800 Subject: [PATCH 02/64] Atom coordinates mostly working Some tiling issue for image corners --- py4DSTEM/process/diffraction/crystal.py | 35 +++++++++++++++++-------- 1 file changed, 24 insertions(+), 11 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index 2d666c3c3..3a661583e 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -938,19 +938,20 @@ def generate_projected_potential( m_tile = self.lat_real[inds_tile,:] # Determine tiling range - p = np.array([ + p_corners = np.array([ [-im_size_Ang[0]*0.5,-im_size_Ang[1]*0.5, 0.0], [ im_size_Ang[0]*0.5,-im_size_Ang[1]*0.5, 0.0], [ im_size_Ang[0]*0.5, im_size_Ang[1]*0.5, 0.0], [-im_size_Ang[0]*0.5, im_size_Ang[1]*0.5, 0.0], ]) - - ab = np.floor(np.linalg.lstsq( + p_corners_proj = p_corners @ \ + np.linalg.inv(np.vstack((proj_x, proj_y, proj_z))) + ab = np.round(np.linalg.lstsq( m_tile.T, - p.T, + p_corners_proj.T, rcond=None)[0]) - a_range = np.array((np.min(ab[0]),np.max(ab[0]))) - b_range = np.array((np.min(ab[1]),np.max(ab[1]))) + a_range = np.array((np.min(ab[0])-1,np.max(ab[0])+1)) + b_range = np.array((np.min(ab[1])-1,np.max(ab[1])+1)) # Tile unit cell a_ind, b_ind, atoms_ind = np.meshgrid( @@ -962,15 +963,25 @@ def generate_projected_potential( abc_atoms[:,inds_tile[0]] += a_ind.ravel() abc_atoms[:,inds_tile[1]] += b_ind.ravel() xyz_atoms_ang = abc_atoms @ self.lat_real.T - # xyz_atoms_pixels = xyz_atoms_ang / pixel_size_Ang \ - # + im_size/2.0 # Project into projected potential image plane - x = proj_x + x = (xyz_atoms_ang @ proj_x) / pixel_size_Ang + im_size[0]/2.0 + y = (xyz_atoms_ang @ proj_y) / pixel_size_Ang + im_size[1]/2.0 + atoms_del = np.logical_or.reduce(( + x < 0, + y < 0, + x > im_size[0], + y > im_size[1], + )) + x = np.delete(x, atoms_del) + y = np.delete(y, atoms_del) + + # Lookup table for atomic projected potentials + + - # # Lookup table for atomic projected potentials - # # initialize + # initialize potential im_potential = np.zeros(im_size) # # im_potential[10:20,10:20] = 1 @@ -984,7 +995,9 @@ def generate_projected_potential( im_potential, cmap = 'gray', ) + ax.scatter(y,x) ax.set_axis_off() + ax.set_aspect('equal') return im_potential From 36ddbcc656c9d1fc006bcc1759ec7290d2c27389 Mon Sep 17 00:00:00 2001 From: cophus Date: Tue, 14 Nov 2023 11:12:11 -0800 Subject: [PATCH 03/64] Projected potential now working with bugs projection algebra definitely has a bug --- py4DSTEM/process/diffraction/crystal.py | 73 ++++++++++++++----- py4DSTEM/process/utils/single_atom_scatter.py | 31 +++++++- 2 files changed, 84 insertions(+), 20 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index 3a661583e..eb8642b25 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -1,6 +1,7 @@ # Functions for calculating diffraction patterns, matching them to experiments, and creating orientation and phase maps. import numpy as np +from scipy.ndimage import gaussian_filter import matplotlib.pyplot as plt from matplotlib.patches import Circle from fractions import Fraction @@ -883,6 +884,9 @@ def generate_projected_potential( self, im_size = (256,256), pixel_size_Ang = 0.1, + potential_radius_Ang = 3.0, + sigma_image_blur_Ang = 0.1, + plot_result = False, orientation: Optional[Orientation] = None, ind_orientation: Optional[int] = 0, orientation_matrix: Optional[np.ndarray] = None, @@ -963,6 +967,7 @@ def generate_projected_potential( abc_atoms[:,inds_tile[0]] += a_ind.ravel() abc_atoms[:,inds_tile[1]] += b_ind.ravel() xyz_atoms_ang = abc_atoms @ self.lat_real.T + atoms_ID_all = self.numbers[atoms_ind.ravel()] # Project into projected potential image plane x = (xyz_atoms_ang @ proj_x) / pixel_size_Ang + im_size[0]/2.0 @@ -975,35 +980,65 @@ def generate_projected_potential( )) x = np.delete(x, atoms_del) y = np.delete(y, atoms_del) + atoms_ID_all = np.delete(atoms_ID_all, atoms_del) - # Lookup table for atomic projected potentials - - + # Coordinate system for atomic projected potentials + potential_radius = np.ceil(potential_radius_Ang / pixel_size_Ang) + R = np.arange(0.5-potential_radius,potential_radius+0.5) + R_ind = R.astype('int') + R_2D = np.sqrt(R[:,None]**2 + R[None,:]**2) + # Lookup table for atomic projected potentials + atoms_ID = np.unique(self.numbers) + atoms_lookup = np.zeros(( + atoms_ID.shape[0], + R_2D.shape[0], + R_2D.shape[1], + )) + for a0 in range(atoms_ID.shape[0]): + atom_sf = single_atom_scatter([atoms_ID[a0]]) + atoms_lookup[a0,:,:] = atom_sf.projected_potential(atoms_ID[a0], R_2D) # initialize potential im_potential = np.zeros(im_size) - # # im_potential[10:20,10:20] = 1 - - - # Add atoms to potential + # Add atoms to potential image + for a0 in range(atoms_ID_all.shape[0]): + ind = np.argmin(np.abs(atoms_ID - atoms_ID_all[a0])) + + x_ind = np.clip( + np.round(x[a0]).astype('int') + R_ind, + 0, + im_size[0]-1) + y_ind = np.clip( + np.round(y[a0]).astype('int') + R_ind, + 0, + im_size[1]-1) + + im_potential[x_ind[None,:],y_ind[:,None]] += atoms_lookup[ind] + + # if needed, apply gaussian blurring + if sigma_image_blur_Ang > 0: + sigma_image_blur = sigma_image_blur_Ang / pixel_size_Ang + im_potential = gaussian_filter( + im_potential, + sigma_image_blur, + mode = 'nearest', + ) - # test plotting - fig,ax = plt.subplots(figsize = (6,6)) - ax.imshow( - im_potential, - cmap = 'gray', - ) - ax.scatter(y,x) - ax.set_axis_off() - ax.set_aspect('equal') + if plot_result: + # test plotting + fig,ax = plt.subplots(figsize = (6,6)) + ax.imshow( + im_potential, + cmap = 'gray', + ) + # ax.scatter(y,x) + ax.set_axis_off() + ax.set_aspect('equal') return im_potential - - - # Vector conversions and other utilities for Crystal classes def cartesian_to_lattice(self, vec_cartesian): vec_lattice = self.lat_inv @ vec_cartesian diff --git a/py4DSTEM/process/utils/single_atom_scatter.py b/py4DSTEM/process/utils/single_atom_scatter.py index 8d6e2a891..9085afd93 100644 --- a/py4DSTEM/process/utils/single_atom_scatter.py +++ b/py4DSTEM/process/utils/single_atom_scatter.py @@ -1,6 +1,6 @@ import numpy as np import os - +from scipy.special import kn class single_atom_scatter(object): """ @@ -46,6 +46,35 @@ def electron_scattering_factor(self, Z, gsq, units="A"): elif units == "A": return fe + def projected_potential(self, Z, R): + ai = self.e_scattering_factors[Z - 1, 0:10:2] + bi = self.e_scattering_factors[Z - 1, 1:10:2] + + # Planck's constant in Js + h = 6.62607004e-34 + # Electron rest mass in kg + me = 9.10938356e-31 + # Electron charge in Coulomb + qe = 1.60217662e-19 + # Permittivity of vacuum + eps_0 = 8.85418782e-12 + # Bohr's constant + a_0 = 5.29177210903e-11 + + fe = np.zeros_like(R) + for i in range(5): + # fe += ai[i] * (2 + bi[i] * gsq) / (1 + bi[i] * gsq) ** 2 + pre = 2*np.pi/bi[i]**0.5 + fe += (ai[i] / bi[i]**1.5) * \ + (kn(0, pre * R) + R * kn(1, pre * R)) + + # kappa = (4*np.pi*eps_0) / (2*np.pi*a_0*me) + return fe * 2 * np.pi**2# / kappa + # if units == "VA": + # return h**2 / (2 * np.pi * me * qe) * 1e18 * fe + # elif units == "A": + # return fe * 2 * np.pi**2 / kappa + def get_scattering_factor( self, elements=None, composition=None, q_coords=None, units=None ): From a0d0f15e60618c962a0667801cdf1ccdc20312d4 Mon Sep 17 00:00:00 2001 From: cophus Date: Tue, 14 Nov 2023 11:14:18 -0800 Subject: [PATCH 04/64] minor tweak --- py4DSTEM/process/diffraction/crystal.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index eb8642b25..603e1a6cb 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -966,7 +966,8 @@ def generate_projected_potential( abc_atoms = self.positions[atoms_ind.ravel(),:] abc_atoms[:,inds_tile[0]] += a_ind.ravel() abc_atoms[:,inds_tile[1]] += b_ind.ravel() - xyz_atoms_ang = abc_atoms @ self.lat_real.T + # NOTE - should this be self.lat_real.T? + xyz_atoms_ang = abc_atoms @ self.lat_real atoms_ID_all = self.numbers[atoms_ind.ravel()] # Project into projected potential image plane From 9b726e55bf8d0581a2116cf083997a3fe4ad0282 Mon Sep 17 00:00:00 2001 From: Colin Date: Fri, 8 Dec 2023 13:49:03 -0800 Subject: [PATCH 05/64] Projected potentials fixed? --- py4DSTEM/process/diffraction/crystal.py | 52 +++++++++++-------------- 1 file changed, 23 insertions(+), 29 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index 603e1a6cb..d1cb04401 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -926,36 +926,31 @@ def generate_projected_potential( orientation_matrix = self.parse_orientation( zone_axis_lattice, proj_x_lattice, zone_axis_cartesian, proj_x_cartesian ) - # projection directions of potential image - proj_x = orientation_matrix[:,0] \ - / np.linalg.norm(orientation_matrix[:,0]) - proj_y = orientation_matrix[:,1] \ - / np.linalg.norm(orientation_matrix[:,1]) - proj_z = orientation_matrix[:,2] \ - / np.linalg.norm(orientation_matrix[:,2]) - - # Determine unit cell axes to tile over - uvw = self.lat_real / \ - np.linalg.norm(self.lat_real, axis = 1) - test = np.abs(uvw @ proj_z) - inds_tile = np.argsort(test)[:2] - m_tile = self.lat_real[inds_tile,:] + + # Rotate unit cell into projection direction + lat_real = self.lat_real.copy() @ orientation_matrix + + # Determine unit cell axes to tile over, by selecting 2/3 with largest in-plane component + inds_tile = np.argsort( + np.linalg.norm(lat_real[:,0:2],axis=1) + )[1:3] + m_tile = lat_real[inds_tile,:] # Determine tiling range p_corners = np.array([ [-im_size_Ang[0]*0.5,-im_size_Ang[1]*0.5, 0.0], [ im_size_Ang[0]*0.5,-im_size_Ang[1]*0.5, 0.0], - [ im_size_Ang[0]*0.5, im_size_Ang[1]*0.5, 0.0], [-im_size_Ang[0]*0.5, im_size_Ang[1]*0.5, 0.0], + [ im_size_Ang[0]*0.5, im_size_Ang[1]*0.5, 0.0], ]) - p_corners_proj = p_corners @ \ - np.linalg.inv(np.vstack((proj_x, proj_y, proj_z))) - ab = np.round(np.linalg.lstsq( - m_tile.T, - p_corners_proj.T, - rcond=None)[0]) - a_range = np.array((np.min(ab[0])-1,np.max(ab[0])+1)) - b_range = np.array((np.min(ab[1])-1,np.max(ab[1])+1)) + ab = np.linalg.lstsq( + m_tile[:,:2].T, + p_corners[:,:2].T, + rcond=None + )[0] + ab = np.floor(ab) + a_range = np.array((np.min(ab[0])-1,np.max(ab[0])+2)) + b_range = np.array((np.min(ab[1])-1,np.max(ab[1])+2)) # Tile unit cell a_ind, b_ind, atoms_ind = np.meshgrid( @@ -966,13 +961,12 @@ def generate_projected_potential( abc_atoms = self.positions[atoms_ind.ravel(),:] abc_atoms[:,inds_tile[0]] += a_ind.ravel() abc_atoms[:,inds_tile[1]] += b_ind.ravel() - # NOTE - should this be self.lat_real.T? - xyz_atoms_ang = abc_atoms @ self.lat_real + xyz_atoms_ang = abc_atoms @ lat_real atoms_ID_all = self.numbers[atoms_ind.ravel()] - # Project into projected potential image plane - x = (xyz_atoms_ang @ proj_x) / pixel_size_Ang + im_size[0]/2.0 - y = (xyz_atoms_ang @ proj_y) / pixel_size_Ang + im_size[1]/2.0 + # Center atoms on image plane + x = xyz_atoms_ang[:,0] / pixel_size_Ang + im_size[0]/2.0 + y = xyz_atoms_ang[:,1] / pixel_size_Ang + im_size[1]/2.0 atoms_del = np.logical_or.reduce(( x < 0, y < 0, @@ -1032,7 +1026,7 @@ def generate_projected_potential( fig,ax = plt.subplots(figsize = (6,6)) ax.imshow( im_potential, - cmap = 'gray', + cmap = 'turbo', ) # ax.scatter(y,x) ax.set_axis_off() From a9b05edc7bd344757479ac662db8cd6866d2638f Mon Sep 17 00:00:00 2001 From: Colin Date: Fri, 2 Feb 2024 10:43:53 -0800 Subject: [PATCH 06/64] adding figsize --- py4DSTEM/process/diffraction/crystal.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index d1cb04401..7d2606576 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -887,6 +887,7 @@ def generate_projected_potential( potential_radius_Ang = 3.0, sigma_image_blur_Ang = 0.1, plot_result = False, + figsize = (6,6), orientation: Optional[Orientation] = None, ind_orientation: Optional[int] = 0, orientation_matrix: Optional[np.ndarray] = None, @@ -1023,7 +1024,7 @@ def generate_projected_potential( if plot_result: # test plotting - fig,ax = plt.subplots(figsize = (6,6)) + fig,ax = plt.subplots(figsize = figsize) ax.imshow( im_potential, cmap = 'turbo', From 43a396cf7fd5ef7211e3c9431259b77ebc42c429 Mon Sep 17 00:00:00 2001 From: Colin Date: Mon, 4 Mar 2024 11:33:34 -0800 Subject: [PATCH 07/64] update docstring --- py4DSTEM/process/diffraction/crystal.py | 39 ++++++++++++++++++++----- 1 file changed, 32 insertions(+), 7 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index 50ddc3087..6f10d8ab1 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -988,19 +988,44 @@ def generate_projected_potential( proj_x_cartesian: Optional[np.ndarray] = None, ): """ - Generate a single diffraction pattern, return all peaks as a pointlist. + Generate an image of the projected potential of crystal in real space, + using cell tiling, and a lookup table of the atomic potentials. Parameters ---------- - edge_blend: int, optional - Pixels to blend image at the border + im_size: tuple, list, np.array + (2,) vector specifying the output size in pixels. + pixel_size_Ang: float + Pixel size in Angstroms. + potential_radius_Ang: float + Radius in Angstroms for how far to integrate the atomic potentials + sigma_image_blur_Ang: float + Image blurring in Angstroms. + plot_result: bool + Plot the projected potential image. + figsize: + (2,) vector giving the size of the output. + + orientation: Orientation + An Orientation class object + ind_orientation: int + If input is an Orientation class object with multiple orientations, + this input can be used to select a specific orientation. + orientation_matrix: array + (3,3) orientation matrix, where columns represent projection directions. + zone_axis_lattice: array + (3,) projection direction in lattice indices + proj_x_lattice: array) + (3,) x-axis direction in lattice indices + zone_axis_cartesian: array + (3,) cartesian projection direction + proj_x_cartesian: array + (3,) cartesian projection direction Returns -------- - orientation (Orientation): an Orientation class object - - Returns: - im_potential: (np.array) + im_potential: (np.array) + Output image of the projected potential. """ From 16c727f7a84e74bada2c9fc118d4de49aea7c6a0 Mon Sep 17 00:00:00 2001 From: Colin Date: Mon, 4 Mar 2024 13:16:53 -0800 Subject: [PATCH 08/64] Adding thickness projection --- py4DSTEM/process/diffraction/crystal.py | 56 +++++++++++++++---------- 1 file changed, 35 insertions(+), 21 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index 6f10d8ab1..de82f443d 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -974,9 +974,10 @@ def generate_ring_pattern( def generate_projected_potential( self, im_size = (256,256), - pixel_size_Ang = 0.1, - potential_radius_Ang = 3.0, - sigma_image_blur_Ang = 0.1, + pixel_size_angstroms = 0.1, + potential_radius_angstroms = 3.0, + sigma_image_blur_angstroms = 0.1, + thickness_angstroms = 100, plot_result = False, figsize = (6,6), orientation: Optional[Orientation] = None, @@ -995,12 +996,14 @@ def generate_projected_potential( ---------- im_size: tuple, list, np.array (2,) vector specifying the output size in pixels. - pixel_size_Ang: float + pixel_size_angstroms: float Pixel size in Angstroms. - potential_radius_Ang: float + potential_radius_angstroms: float Radius in Angstroms for how far to integrate the atomic potentials - sigma_image_blur_Ang: float + sigma_image_blur_angstroms: float Image blurring in Angstroms. + thickness_angstroms: float + Thickness of the sample in Angstroms. plot_result: bool Plot the projected potential image. figsize: @@ -1031,7 +1034,7 @@ def generate_projected_potential( # Determine image size in Angstroms im_size = np.array(im_size) - im_size_Ang = im_size * pixel_size_Ang + im_size_Ang = im_size * pixel_size_angstroms # Parse orientation inputs if orientation is not None: @@ -1052,6 +1055,8 @@ def generate_projected_potential( np.linalg.norm(lat_real[:,0:2],axis=1) )[1:3] m_tile = lat_real[inds_tile,:] + # Vector projected along zone axis + m_proj = np.squeeze(np.delete(lat_real,inds_tile,axis=0)) # Determine tiling range p_corners = np.array([ @@ -1082,8 +1087,8 @@ def generate_projected_potential( atoms_ID_all = self.numbers[atoms_ind.ravel()] # Center atoms on image plane - x = xyz_atoms_ang[:,0] / pixel_size_Ang + im_size[0]/2.0 - y = xyz_atoms_ang[:,1] / pixel_size_Ang + im_size[1]/2.0 + x = xyz_atoms_ang[:,0] / pixel_size_angstroms + im_size[0]/2.0 + y = xyz_atoms_ang[:,1] / pixel_size_angstroms + im_size[1]/2.0 atoms_del = np.logical_or.reduce(( x < 0, y < 0, @@ -1095,7 +1100,7 @@ def generate_projected_potential( atoms_ID_all = np.delete(atoms_ID_all, atoms_del) # Coordinate system for atomic projected potentials - potential_radius = np.ceil(potential_radius_Ang / pixel_size_Ang) + potential_radius = np.ceil(potential_radius_angstroms / pixel_size_angstroms) R = np.arange(0.5-potential_radius,potential_radius+0.5) R_ind = R.astype('int') R_2D = np.sqrt(R[:,None]**2 + R[None,:]**2) @@ -1111,6 +1116,13 @@ def generate_projected_potential( atom_sf = single_atom_scatter([atoms_ID[a0]]) atoms_lookup[a0,:,:] = atom_sf.projected_potential(atoms_ID[a0], R_2D) + # Projected thickness + num_proj = thickness_angstroms / m_proj[2] + vec_proj = num_proj / pixel_size_angstroms * m_proj[:2] + print(m_proj) + print(vec_proj) + + # initialize potential im_potential = np.zeros(im_size) @@ -1118,20 +1130,22 @@ def generate_projected_potential( for a0 in range(atoms_ID_all.shape[0]): ind = np.argmin(np.abs(atoms_ID - atoms_ID_all[a0])) - x_ind = np.clip( - np.round(x[a0]).astype('int') + R_ind, - 0, - im_size[0]-1) - y_ind = np.clip( - np.round(y[a0]).astype('int') + R_ind, - 0, - im_size[1]-1) + x_ind = np.round(x[a0]).astype('int') + R_ind + y_ind = np.round(y[a0]).astype('int') + R_ind + x_sub = np.logical_and( + x_ind >= 0, + x_ind < im_size[0], + ) + y_sub = np.logical_and( + y_ind >= 0, + y_ind < im_size[1], + ) - im_potential[x_ind[None,:],y_ind[:,None]] += atoms_lookup[ind] + im_potential[x_ind[x_sub][:,None],y_ind[y_sub][None,:]] += atoms_lookup[ind][x_sub,:][:,y_sub] # if needed, apply gaussian blurring - if sigma_image_blur_Ang > 0: - sigma_image_blur = sigma_image_blur_Ang / pixel_size_Ang + if sigma_image_blur_angstroms > 0: + sigma_image_blur = sigma_image_blur_angstroms / pixel_size_angstroms im_potential = gaussian_filter( im_potential, sigma_image_blur, From a1c83c35044bc55a86f8746594f5b1226ec4e9d6 Mon Sep 17 00:00:00 2001 From: Colin Date: Mon, 4 Mar 2024 14:24:32 -0800 Subject: [PATCH 09/64] Fourier method makes boundary conditions difficult --- py4DSTEM/process/diffraction/crystal.py | 63 +++++++++++++++++++++---- 1 file changed, 55 insertions(+), 8 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index de82f443d..ed09203ab 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -1055,7 +1055,7 @@ def generate_projected_potential( np.linalg.norm(lat_real[:,0:2],axis=1) )[1:3] m_tile = lat_real[inds_tile,:] - # Vector projected along zone axis + # Vector projected along optic axis m_proj = np.squeeze(np.delete(lat_real,inds_tile,axis=0)) # Determine tiling range @@ -1116,13 +1116,6 @@ def generate_projected_potential( atom_sf = single_atom_scatter([atoms_ID[a0]]) atoms_lookup[a0,:,:] = atom_sf.projected_potential(atoms_ID[a0], R_2D) - # Projected thickness - num_proj = thickness_angstroms / m_proj[2] - vec_proj = num_proj / pixel_size_angstroms * m_proj[:2] - print(m_proj) - print(vec_proj) - - # initialize potential im_potential = np.zeros(im_size) @@ -1142,6 +1135,59 @@ def generate_projected_potential( ) im_potential[x_ind[x_sub][:,None],y_ind[y_sub][None,:]] += atoms_lookup[ind][x_sub,:][:,y_sub] + + # Projected thickness + num_proj = thickness_angstroms / m_proj[2] + vec_proj = num_proj / pixel_size_angstroms * m_proj[:2] + num_points = (np.ceil(np.linalg.norm(vec_proj))*2+1).astype('int') + x = np.linspace(-0.5,0.5,num_points)*vec_proj[0] + im_size[0]/2 + y = np.linspace(-0.5,0.5,num_points)*vec_proj[1] + im_size[1]/2 + xF = np.floor(x).astype('int') + yF = np.floor(y).astype('int') + dx = x - xF + dy = y - yF + im_thickness = np.reshape( + np.bincount( + np.ravel_multi_index((xF ,yF ),im_size,mode='wrap'), + (1-dx)*(1-dy) / num_points, + np.prod(im_size), + ) + \ + np.bincount( + np.ravel_multi_index((xF+1,yF ),im_size,mode='wrap'), + ( dx)*(1-dy) / num_points, + np.prod(im_size), + ) + \ + np.bincount( + np.ravel_multi_index((xF ,yF+1),im_size,mode='wrap'), + (1-dx)*( dy) / num_points, + np.prod(im_size), + ) + \ + np.bincount( + np.ravel_multi_index((xF+1,yF+1),im_size,mode='wrap'), + ( dx)*( dy) / num_points, + np.prod(im_size), + ), + im_size, + ) + # pad images and perform correlation in Fourier space + im_potential = np.real( + np.fft.ifft2( + np.fft.fft2(np.pad(im_potential,((0,im_size[0]),(0,im_size[1])))) \ + * np.fft.fft2(np.pad(im_thickness,((0,im_size[0]),(0,im_size[1])))), + ), + )[im_size[0]//2:im_size[0]+im_size[0]//2,im_size[1]//2:im_size[1]+im_size[1]//2] + + # fig,ax = plt.subplots(figsize=(12,12)) + # ax.imshow( + # np.pad(im_thickness,((0,im_size[0]),(0,im_size[1]))), + # ) + # im_thickness = np.zeros(im_size) + # for a0 in range(num_points): + # x = + + + + # if needed, apply gaussian blurring if sigma_image_blur_angstroms > 0: @@ -1152,6 +1198,7 @@ def generate_projected_potential( mode = 'nearest', ) + if plot_result: # test plotting fig,ax = plt.subplots(figsize = figsize) From c8d661ac49e8c3ba5bfc36a681fe4e3aa34ee5d8 Mon Sep 17 00:00:00 2001 From: Colin Date: Mon, 4 Mar 2024 15:19:37 -0800 Subject: [PATCH 10/64] Updated plotting --- py4DSTEM/process/diffraction/crystal.py | 123 +++++++++++------------- 1 file changed, 54 insertions(+), 69 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index ed09203ab..8b59c3c92 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -978,6 +978,7 @@ def generate_projected_potential( potential_radius_angstroms = 3.0, sigma_image_blur_angstroms = 0.1, thickness_angstroms = 100, + power_scale = 1.0, plot_result = False, figsize = (6,6), orientation: Optional[Orientation] = None, @@ -1003,7 +1004,10 @@ def generate_projected_potential( sigma_image_blur_angstroms: float Image blurring in Angstroms. thickness_angstroms: float - Thickness of the sample in Angstroms. + Thickness of the sample in Angstroms. + Set thickness_thickness_angstroms = 0 to skip thickness projection. + power_scale: float + Power law scaling of potentials. Set to 2.0 to approximate Z^2 images. plot_result: bool Plot the projected potential image. figsize: @@ -1090,10 +1094,10 @@ def generate_projected_potential( x = xyz_atoms_ang[:,0] / pixel_size_angstroms + im_size[0]/2.0 y = xyz_atoms_ang[:,1] / pixel_size_angstroms + im_size[1]/2.0 atoms_del = np.logical_or.reduce(( - x < 0, - y < 0, - x > im_size[0], - y > im_size[1], + x <= -potential_radius_angstroms/2, + y <= -potential_radius_angstroms/2, + x >= im_size[0] + potential_radius_angstroms/2, + y >= im_size[1] + potential_radius_angstroms/2, )) x = np.delete(x, atoms_del) y = np.delete(y, atoms_del) @@ -1115,6 +1119,15 @@ def generate_projected_potential( for a0 in range(atoms_ID.shape[0]): atom_sf = single_atom_scatter([atoms_ID[a0]]) atoms_lookup[a0,:,:] = atom_sf.projected_potential(atoms_ID[a0], R_2D) + atoms_lookup **= power_scale + + # Thickness + if thickness_angstroms > 0: + thickness_proj = thickness_angstroms / m_proj[2] + vec_proj = thickness_proj / pixel_size_angstroms * m_proj[:2] + num_proj = (np.ceil(np.linalg.norm(vec_proj))+1).astype('int') + x_proj = np.linspace(-0.5,0.5,num_proj)*vec_proj[0] + y_proj = np.linspace(-0.5,0.5,num_proj)*vec_proj[1] # initialize potential im_potential = np.zeros(im_size) @@ -1123,71 +1136,37 @@ def generate_projected_potential( for a0 in range(atoms_ID_all.shape[0]): ind = np.argmin(np.abs(atoms_ID - atoms_ID_all[a0])) - x_ind = np.round(x[a0]).astype('int') + R_ind - y_ind = np.round(y[a0]).astype('int') + R_ind - x_sub = np.logical_and( - x_ind >= 0, - x_ind < im_size[0], - ) - y_sub = np.logical_and( - y_ind >= 0, - y_ind < im_size[1], - ) + if thickness_angstroms > 0: + for a1 in range(num_proj): + x_ind = np.round(x[a0]+x_proj[a1]).astype('int') + R_ind + y_ind = np.round(y[a0]+y_proj[a1]).astype('int') + R_ind + x_sub = np.logical_and( + x_ind >= 0, + x_ind < im_size[0], + ) + y_sub = np.logical_and( + y_ind >= 0, + y_ind < im_size[1], + ) - im_potential[x_ind[x_sub][:,None],y_ind[y_sub][None,:]] += atoms_lookup[ind][x_sub,:][:,y_sub] + im_potential[x_ind[x_sub][:,None],y_ind[y_sub][None,:]] += atoms_lookup[ind][x_sub,:][:,y_sub] - # Projected thickness - num_proj = thickness_angstroms / m_proj[2] - vec_proj = num_proj / pixel_size_angstroms * m_proj[:2] - num_points = (np.ceil(np.linalg.norm(vec_proj))*2+1).astype('int') - x = np.linspace(-0.5,0.5,num_points)*vec_proj[0] + im_size[0]/2 - y = np.linspace(-0.5,0.5,num_points)*vec_proj[1] + im_size[1]/2 - xF = np.floor(x).astype('int') - yF = np.floor(y).astype('int') - dx = x - xF - dy = y - yF - im_thickness = np.reshape( - np.bincount( - np.ravel_multi_index((xF ,yF ),im_size,mode='wrap'), - (1-dx)*(1-dy) / num_points, - np.prod(im_size), - ) + \ - np.bincount( - np.ravel_multi_index((xF+1,yF ),im_size,mode='wrap'), - ( dx)*(1-dy) / num_points, - np.prod(im_size), - ) + \ - np.bincount( - np.ravel_multi_index((xF ,yF+1),im_size,mode='wrap'), - (1-dx)*( dy) / num_points, - np.prod(im_size), - ) + \ - np.bincount( - np.ravel_multi_index((xF+1,yF+1),im_size,mode='wrap'), - ( dx)*( dy) / num_points, - np.prod(im_size), - ), - im_size, - ) - # pad images and perform correlation in Fourier space - im_potential = np.real( - np.fft.ifft2( - np.fft.fft2(np.pad(im_potential,((0,im_size[0]),(0,im_size[1])))) \ - * np.fft.fft2(np.pad(im_thickness,((0,im_size[0]),(0,im_size[1])))), - ), - )[im_size[0]//2:im_size[0]+im_size[0]//2,im_size[1]//2:im_size[1]+im_size[1]//2] - - # fig,ax = plt.subplots(figsize=(12,12)) - # ax.imshow( - # np.pad(im_thickness,((0,im_size[0]),(0,im_size[1]))), - # ) - # im_thickness = np.zeros(im_size) - # for a0 in range(num_points): - # x = - - + else: + x_ind = np.round(x[a0]).astype('int') + R_ind + y_ind = np.round(y[a0]).astype('int') + R_ind + x_sub = np.logical_and( + x_ind >= 0, + x_ind < im_size[0], + ) + y_sub = np.logical_and( + y_ind >= 0, + y_ind < im_size[1], + ) + im_potential[x_ind[x_sub][:,None],y_ind[y_sub][None,:]] += atoms_lookup[ind][x_sub,:][:,y_sub] + if thickness_angstroms > 0: + im_potential /= num_proj # if needed, apply gaussian blurring if sigma_image_blur_angstroms > 0: @@ -1198,15 +1177,21 @@ def generate_projected_potential( mode = 'nearest', ) - if plot_result: - # test plotting + # quick plotting of the result + int_vals = np.sort(im_potential.ravel()) + int_range = np.array(( + int_vals[np.round(0.02*int_vals.size).astype('int')], + int_vals[np.round(0.98*int_vals.size).astype('int')], + )) + fig,ax = plt.subplots(figsize = figsize) ax.imshow( im_potential, cmap = 'turbo', + vmin = int_range[0], + vmax = int_range[1], ) - # ax.scatter(y,x) ax.set_axis_off() ax.set_aspect('equal') From 3905e8d12c633290e9d552116ca24880cf1175c2 Mon Sep 17 00:00:00 2001 From: Colin Date: Thu, 7 Mar 2024 21:37:10 -0800 Subject: [PATCH 11/64] Adding robust fitting to ACOM strain mapping --- py4DSTEM/process/diffraction/crystal_ACOM.py | 102 ++++++++++++++----- 1 file changed, 78 insertions(+), 24 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index 83dff29e1..f4a309aa8 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -2025,6 +2025,9 @@ def calculate_strain( tol_intensity: float = 1e-4, k_max: Optional[float] = None, min_num_peaks=5, + intensity_weighting = False, + robust = True, + robust_thresh = 3.0, rotation_range=None, mask_from_corr=True, corr_range=(0, 2), @@ -2039,24 +2042,46 @@ def calculate_strain( TODO: add robust fitting? - Args: - bragg_peaks_array (PointListArray): All Bragg peaks - orientation_map (OrientationMap): Orientation map generated from ACOM - corr_kernel_size (float): Correlation kernel size - if user does - not specify, uses self.corr_kernel_size. - sigma_excitation_error (float): sigma value for envelope applied to s_g (excitation errors) in units of inverse Angstroms - tol_excitation_error_mult (float): tolerance in units of sigma for s_g inclusion - tol_intensity (np float): tolerance in intensity units for inclusion of diffraction spots - k_max (float): Maximum scattering vector - min_num_peaks (int): Minimum number of peaks required. - rotation_range (float): Maximum rotation range in radians (for symmetry reduction). - progress_bar (bool): Show progress bar - mask_from_corr (bool): Use ACOM correlation signal for mask - corr_range (np.ndarray): Range of correlation signals for mask - corr_normalize (bool): Normalize correlation signal before masking + Parameters + ---------- + bragg_peaks_array (PointListArray): + All Bragg peaks + orientation_map (OrientationMap): + Orientation map generated from ACOM + corr_kernel_size (float): + Correlation kernel size - if user does + not specify, uses self.corr_kernel_size. + sigma_excitation_error (float): + sigma value for envelope applied to s_g (excitation errors) in units of inverse Angstroms + tol_excitation_error_mult (float): + tolerance in units of sigma for s_g inclusion + tol_intensity (np float): + tolerance in intensity units for inclusion of diffraction spots + k_max (float): + Maximum scattering vector + min_num_peaks (int): + Minimum number of peaks required. + intensity_weighting: bool + Set to True to weight least squares by experimental peak intensity. + robust_fitting: bool + Set to True to use robust fitting, which performs outlier rejection. + robust_thresh: float + Threshold for robust fitting weights. + rotation_range (float): + Maximum rotation range in radians (for symmetry reduction). + progress_bar (bool): + Show progress bar + mask_from_corr (bool): + Use ACOM correlation signal for mask + corr_range (np.ndarray): + Range of correlation signals for mask + corr_normalize (bool): + Normalize correlation signal before masking - Returns: - strain_map (RealSlice): strain tensor + Returns + -------- + strain_map (RealSlice): + strain tensor """ @@ -2098,6 +2123,8 @@ def calculate_strain( # Loop over all probe positions for rx, ry in tqdmnd( + # range(220,221), + # range(40,41), *bragg_peaks_array.shape, desc="Calculating strains", unit=" PointList", @@ -2143,16 +2170,43 @@ def calculate_strain( (p_ref.data["qx"][inds_match[keep]], p_ref.data["qy"][inds_match[keep]]) ).T - # Apply intensity weighting from experimental measurements - qxy *= p.data["intensity"][keep, None] - qxy_ref *= p.data["intensity"][keep, None] # Fit transformation matrix # Note - not sure about transpose here # (though it might not matter if rotation isn't included) - m = lstsq(qxy_ref, qxy, rcond=None)[0].T - - # Get the infinitesimal strain matrix + if intensity_weighting: + weights = np.sqrt(p.data["intensity"][keep, None])*0+1 + m = lstsq( + qxy_ref * weights, + qxy * weights, + rcond=None, + )[0].T + else: + m = lstsq( + qxy_ref, + qxy, + rcond=None, + )[0].T + + # Robust fitting + if robust: + for a0 in range(5): + # calculate new weights + qxy_fit = qxy_ref @ m + diff2 = np.sum((qxy_fit - qxy)**2,axis=1) + + weights = np.exp(diff2 / ((-2*robust_thresh**2)*np.median(diff2)))[:,None] + if intensity_weighting: + weights *= np.sqrt(p.data["intensity"][keep, None]) + + # calculate new fits + m = lstsq( + qxy_ref * weights, + qxy * weights, + rcond=None, + )[0].T + + # Set values into the infinitesimal strain matrix strain_map.get_slice("e_xx").data[rx, ry] = 1 - m[0, 0] strain_map.get_slice("e_yy").data[rx, ry] = 1 - m[1, 1] strain_map.get_slice("e_xy").data[rx, ry] = -(m[0, 1] + m[1, 0]) / 2.0 @@ -2160,7 +2214,7 @@ def calculate_strain( # Add finite rotation from ACOM orientation map. # I am not sure about the relative signs here. - # Also, I need to add in the mirror operator. + # Also, maybe I need to add in the mirror operator? if orientation_map.mirror[rx, ry, 0]: strain_map.get_slice("theta").data[rx, ry] += ( orientation_map.angles[rx, ry, 0, 0] From 771ab1f5626c2c8cdba8573050f7379d8c3a06af Mon Sep 17 00:00:00 2001 From: Colin Date: Wed, 13 Mar 2024 13:12:48 -0700 Subject: [PATCH 12/64] Updating matching --- py4DSTEM/process/diffraction/crystal_calibrate.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_calibrate.py b/py4DSTEM/process/diffraction/crystal_calibrate.py index c068bf79e..b15015c62 100644 --- a/py4DSTEM/process/diffraction/crystal_calibrate.py +++ b/py4DSTEM/process/diffraction/crystal_calibrate.py @@ -21,7 +21,7 @@ def calibrate_pixel_size( k_max=None, k_step=0.002, k_broadening=0.002, - fit_all_intensities=True, + fit_all_intensities=False, set_calibration_in_place=False, verbose=True, plot_result=False, @@ -50,7 +50,7 @@ def calibrate_pixel_size( k_broadening (float): Initial guess for Gaussian broadening of simulated pattern (Å^-1) fit_all_intensities (bool): Set to true to allow all peak intensities to - change independently False forces a single intensity scaling. + change independently. False forces a single intensity scaling for all peaks. set_calibration (bool): if True, set the fit pixel size to the calibration metadata, and calibrate bragg_peaks verbose (bool): Output the calibrated pixel size. @@ -138,7 +138,7 @@ def fit_profile(k, *coefs): if returnfig: fig, ax = self.plot_scattering_intensity( - bragg_peaks=bragg_peaks, + bragg_peaks=bragg_peaks_cali, figsize=figsize, k_broadening=k_broadening, int_power_scale=1.0, @@ -151,7 +151,7 @@ def fit_profile(k, *coefs): ) else: self.plot_scattering_intensity( - bragg_peaks=bragg_peaks, + bragg_peaks=bragg_peaks_cali, figsize=figsize, k_broadening=k_broadening, int_power_scale=1.0, From 938c965535c044cb2adc9e111753137a58417d68 Mon Sep 17 00:00:00 2001 From: Colin Date: Wed, 20 Mar 2024 15:07:40 -0700 Subject: [PATCH 13/64] Fixing thickness projection in 2D potentials --- py4DSTEM/process/diffraction/crystal.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index 8b59c3c92..85f8160b7 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -1036,6 +1036,7 @@ def generate_projected_potential( """ + # Determine image size in Angstroms im_size = np.array(im_size) im_size_Ang = im_size * pixel_size_angstroms @@ -1123,11 +1124,15 @@ def generate_projected_potential( # Thickness if thickness_angstroms > 0: - thickness_proj = thickness_angstroms / m_proj[2] - vec_proj = thickness_proj / pixel_size_angstroms * m_proj[:2] - num_proj = (np.ceil(np.linalg.norm(vec_proj))+1).astype('int') - x_proj = np.linspace(-0.5,0.5,num_proj)*vec_proj[0] - y_proj = np.linspace(-0.5,0.5,num_proj)*vec_proj[1] + num_proj = np.round(thickness_angstroms / np.abs(m_proj[2])).astype('int') + if num_proj > 1: + vec_proj = m_proj[:2] / pixel_size_angstroms + shifts = np.arange(num_proj).astype('float') + shifts -= np.mean(shifts) + x_proj = shifts * vec_proj[0] + y_proj = shifts * vec_proj[1] + else: + num_proj = 1 # initialize potential im_potential = np.zeros(im_size) @@ -1136,7 +1141,7 @@ def generate_projected_potential( for a0 in range(atoms_ID_all.shape[0]): ind = np.argmin(np.abs(atoms_ID - atoms_ID_all[a0])) - if thickness_angstroms > 0: + if num_proj > 1: for a1 in range(num_proj): x_ind = np.round(x[a0]+x_proj[a1]).astype('int') + R_ind y_ind = np.round(y[a0]+y_proj[a1]).astype('int') + R_ind From 6938fe375a60f4e8c99ddee550120673cba433fb Mon Sep 17 00:00:00 2001 From: Colin Date: Wed, 20 Mar 2024 15:08:23 -0700 Subject: [PATCH 14/64] black --- py4DSTEM/process/diffraction/crystal.py | 168 +++++++++--------- py4DSTEM/process/diffraction/crystal_ACOM.py | 55 +++--- py4DSTEM/process/utils/single_atom_scatter.py | 10 +- 3 files changed, 118 insertions(+), 115 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index 85f8160b7..dc50be154 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -968,19 +968,16 @@ def generate_ring_pattern( if return_calc is True: return radii_unique, intensity_unique - - - def generate_projected_potential( self, - im_size = (256,256), - pixel_size_angstroms = 0.1, - potential_radius_angstroms = 3.0, - sigma_image_blur_angstroms = 0.1, - thickness_angstroms = 100, - power_scale = 1.0, - plot_result = False, - figsize = (6,6), + im_size=(256, 256), + pixel_size_angstroms=0.1, + potential_radius_angstroms=3.0, + sigma_image_blur_angstroms=0.1, + thickness_angstroms=100, + power_scale=1.0, + plot_result=False, + figsize=(6, 6), orientation: Optional[Orientation] = None, ind_orientation: Optional[int] = 0, orientation_matrix: Optional[np.ndarray] = None, @@ -1004,7 +1001,7 @@ def generate_projected_potential( sigma_image_blur_angstroms: float Image blurring in Angstroms. thickness_angstroms: float - Thickness of the sample in Angstroms. + Thickness of the sample in Angstroms. Set thickness_thickness_angstroms = 0 to skip thickness projection. power_scale: float Power law scaling of potentials. Set to 2.0 to approximate Z^2 images. @@ -1012,31 +1009,30 @@ def generate_projected_potential( Plot the projected potential image. figsize: (2,) vector giving the size of the output. - - orientation: Orientation + + orientation: Orientation An Orientation class object ind_orientation: int If input is an Orientation class object with multiple orientations, this input can be used to select a specific orientation. - orientation_matrix: array + orientation_matrix: array (3,3) orientation matrix, where columns represent projection directions. - zone_axis_lattice: array + zone_axis_lattice: array (3,) projection direction in lattice indices - proj_x_lattice: array) + proj_x_lattice: array) (3,) x-axis direction in lattice indices - zone_axis_cartesian: array + zone_axis_cartesian: array (3,) cartesian projection direction - proj_x_cartesian: array + proj_x_cartesian: array (3,) cartesian projection direction Returns -------- - im_potential: (np.array) + im_potential: (np.array) Output image of the projected potential. """ - # Determine image size in Angstroms im_size = np.array(im_size) im_size_Ang = im_size * pixel_size_angstroms @@ -1056,78 +1052,78 @@ def generate_projected_potential( lat_real = self.lat_real.copy() @ orientation_matrix # Determine unit cell axes to tile over, by selecting 2/3 with largest in-plane component - inds_tile = np.argsort( - np.linalg.norm(lat_real[:,0:2],axis=1) - )[1:3] - m_tile = lat_real[inds_tile,:] + inds_tile = np.argsort(np.linalg.norm(lat_real[:, 0:2], axis=1))[1:3] + m_tile = lat_real[inds_tile, :] # Vector projected along optic axis - m_proj = np.squeeze(np.delete(lat_real,inds_tile,axis=0)) + m_proj = np.squeeze(np.delete(lat_real, inds_tile, axis=0)) # Determine tiling range - p_corners = np.array([ - [-im_size_Ang[0]*0.5,-im_size_Ang[1]*0.5, 0.0], - [ im_size_Ang[0]*0.5,-im_size_Ang[1]*0.5, 0.0], - [-im_size_Ang[0]*0.5, im_size_Ang[1]*0.5, 0.0], - [ im_size_Ang[0]*0.5, im_size_Ang[1]*0.5, 0.0], - ]) - ab = np.linalg.lstsq( - m_tile[:,:2].T, - p_corners[:,:2].T, - rcond=None - )[0] + p_corners = np.array( + [ + [-im_size_Ang[0] * 0.5, -im_size_Ang[1] * 0.5, 0.0], + [im_size_Ang[0] * 0.5, -im_size_Ang[1] * 0.5, 0.0], + [-im_size_Ang[0] * 0.5, im_size_Ang[1] * 0.5, 0.0], + [im_size_Ang[0] * 0.5, im_size_Ang[1] * 0.5, 0.0], + ] + ) + ab = np.linalg.lstsq(m_tile[:, :2].T, p_corners[:, :2].T, rcond=None)[0] ab = np.floor(ab) - a_range = np.array((np.min(ab[0])-1,np.max(ab[0])+2)) - b_range = np.array((np.min(ab[1])-1,np.max(ab[1])+2)) + a_range = np.array((np.min(ab[0]) - 1, np.max(ab[0]) + 2)) + b_range = np.array((np.min(ab[1]) - 1, np.max(ab[1]) + 2)) # Tile unit cell a_ind, b_ind, atoms_ind = np.meshgrid( - np.arange(a_range[0],a_range[1]), - np.arange(b_range[0],b_range[1]), + np.arange(a_range[0], a_range[1]), + np.arange(b_range[0], b_range[1]), np.arange(self.positions.shape[0]), ) - abc_atoms = self.positions[atoms_ind.ravel(),:] - abc_atoms[:,inds_tile[0]] += a_ind.ravel() - abc_atoms[:,inds_tile[1]] += b_ind.ravel() + abc_atoms = self.positions[atoms_ind.ravel(), :] + abc_atoms[:, inds_tile[0]] += a_ind.ravel() + abc_atoms[:, inds_tile[1]] += b_ind.ravel() xyz_atoms_ang = abc_atoms @ lat_real atoms_ID_all = self.numbers[atoms_ind.ravel()] # Center atoms on image plane - x = xyz_atoms_ang[:,0] / pixel_size_angstroms + im_size[0]/2.0 - y = xyz_atoms_ang[:,1] / pixel_size_angstroms + im_size[1]/2.0 - atoms_del = np.logical_or.reduce(( - x <= -potential_radius_angstroms/2, - y <= -potential_radius_angstroms/2, - x >= im_size[0] + potential_radius_angstroms/2, - y >= im_size[1] + potential_radius_angstroms/2, - )) + x = xyz_atoms_ang[:, 0] / pixel_size_angstroms + im_size[0] / 2.0 + y = xyz_atoms_ang[:, 1] / pixel_size_angstroms + im_size[1] / 2.0 + atoms_del = np.logical_or.reduce( + ( + x <= -potential_radius_angstroms / 2, + y <= -potential_radius_angstroms / 2, + x >= im_size[0] + potential_radius_angstroms / 2, + y >= im_size[1] + potential_radius_angstroms / 2, + ) + ) x = np.delete(x, atoms_del) y = np.delete(y, atoms_del) atoms_ID_all = np.delete(atoms_ID_all, atoms_del) # Coordinate system for atomic projected potentials potential_radius = np.ceil(potential_radius_angstroms / pixel_size_angstroms) - R = np.arange(0.5-potential_radius,potential_radius+0.5) - R_ind = R.astype('int') - R_2D = np.sqrt(R[:,None]**2 + R[None,:]**2) + R = np.arange(0.5 - potential_radius, potential_radius + 0.5) + R_ind = R.astype("int") + R_2D = np.sqrt(R[:, None] ** 2 + R[None, :] ** 2) # Lookup table for atomic projected potentials atoms_ID = np.unique(self.numbers) - atoms_lookup = np.zeros(( - atoms_ID.shape[0], - R_2D.shape[0], - R_2D.shape[1], - )) + atoms_lookup = np.zeros( + ( + atoms_ID.shape[0], + R_2D.shape[0], + R_2D.shape[1], + ) + ) for a0 in range(atoms_ID.shape[0]): atom_sf = single_atom_scatter([atoms_ID[a0]]) - atoms_lookup[a0,:,:] = atom_sf.projected_potential(atoms_ID[a0], R_2D) + atoms_lookup[a0, :, :] = atom_sf.projected_potential(atoms_ID[a0], R_2D) atoms_lookup **= power_scale - # Thickness + # Thickness if thickness_angstroms > 0: - num_proj = np.round(thickness_angstroms / np.abs(m_proj[2])).astype('int') + num_proj = np.round(thickness_angstroms / np.abs(m_proj[2])).astype("int") if num_proj > 1: vec_proj = m_proj[:2] / pixel_size_angstroms - shifts = np.arange(num_proj).astype('float') + shifts = np.arange(num_proj).astype("float") shifts -= np.mean(shifts) x_proj = shifts * vec_proj[0] y_proj = shifts * vec_proj[1] @@ -1143,8 +1139,8 @@ def generate_projected_potential( if num_proj > 1: for a1 in range(num_proj): - x_ind = np.round(x[a0]+x_proj[a1]).astype('int') + R_ind - y_ind = np.round(y[a0]+y_proj[a1]).astype('int') + R_ind + x_ind = np.round(x[a0] + x_proj[a1]).astype("int") + R_ind + y_ind = np.round(y[a0] + y_proj[a1]).astype("int") + R_ind x_sub = np.logical_and( x_ind >= 0, x_ind < im_size[0], @@ -1154,11 +1150,13 @@ def generate_projected_potential( y_ind < im_size[1], ) - im_potential[x_ind[x_sub][:,None],y_ind[y_sub][None,:]] += atoms_lookup[ind][x_sub,:][:,y_sub] - + im_potential[ + x_ind[x_sub][:, None], y_ind[y_sub][None, :] + ] += atoms_lookup[ind][x_sub, :][:, y_sub] + else: - x_ind = np.round(x[a0]).astype('int') + R_ind - y_ind = np.round(y[a0]).astype('int') + R_ind + x_ind = np.round(x[a0]).astype("int") + R_ind + y_ind = np.round(y[a0]).astype("int") + R_ind x_sub = np.logical_and( x_ind >= 0, x_ind < im_size[0], @@ -1168,7 +1166,9 @@ def generate_projected_potential( y_ind < im_size[1], ) - im_potential[x_ind[x_sub][:,None],y_ind[y_sub][None,:]] += atoms_lookup[ind][x_sub,:][:,y_sub] + im_potential[ + x_ind[x_sub][:, None], y_ind[y_sub][None, :] + ] += atoms_lookup[ind][x_sub, :][:, y_sub] if thickness_angstroms > 0: im_potential /= num_proj @@ -1179,26 +1179,28 @@ def generate_projected_potential( im_potential = gaussian_filter( im_potential, sigma_image_blur, - mode = 'nearest', - ) + mode="nearest", + ) if plot_result: # quick plotting of the result int_vals = np.sort(im_potential.ravel()) - int_range = np.array(( - int_vals[np.round(0.02*int_vals.size).astype('int')], - int_vals[np.round(0.98*int_vals.size).astype('int')], - )) + int_range = np.array( + ( + int_vals[np.round(0.02 * int_vals.size).astype("int")], + int_vals[np.round(0.98 * int_vals.size).astype("int")], + ) + ) - fig,ax = plt.subplots(figsize = figsize) + fig, ax = plt.subplots(figsize=figsize) ax.imshow( im_potential, - cmap = 'turbo', - vmin = int_range[0], - vmax = int_range[1], - ) + cmap="turbo", + vmin=int_range[0], + vmax=int_range[1], + ) ax.set_axis_off() - ax.set_aspect('equal') + ax.set_aspect("equal") return im_potential diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index f4a309aa8..a0f8468ed 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -2025,9 +2025,9 @@ def calculate_strain( tol_intensity: float = 1e-4, k_max: Optional[float] = None, min_num_peaks=5, - intensity_weighting = False, - robust = True, - robust_thresh = 3.0, + intensity_weighting=False, + robust=True, + robust_thresh=3.0, rotation_range=None, mask_from_corr=True, corr_range=(0, 2), @@ -2044,22 +2044,22 @@ def calculate_strain( Parameters ---------- - bragg_peaks_array (PointListArray): + bragg_peaks_array (PointListArray): All Bragg peaks - orientation_map (OrientationMap): + orientation_map (OrientationMap): Orientation map generated from ACOM - corr_kernel_size (float): + corr_kernel_size (float): Correlation kernel size - if user does not specify, uses self.corr_kernel_size. - sigma_excitation_error (float): + sigma_excitation_error (float): sigma value for envelope applied to s_g (excitation errors) in units of inverse Angstroms - tol_excitation_error_mult (float): + tol_excitation_error_mult (float): tolerance in units of sigma for s_g inclusion - tol_intensity (np float): + tol_intensity (np float): tolerance in intensity units for inclusion of diffraction spots - k_max (float): + k_max (float): Maximum scattering vector - min_num_peaks (int): + min_num_peaks (int): Minimum number of peaks required. intensity_weighting: bool Set to True to weight least squares by experimental peak intensity. @@ -2067,20 +2067,20 @@ def calculate_strain( Set to True to use robust fitting, which performs outlier rejection. robust_thresh: float Threshold for robust fitting weights. - rotation_range (float): + rotation_range (float): Maximum rotation range in radians (for symmetry reduction). - progress_bar (bool): + progress_bar (bool): Show progress bar - mask_from_corr (bool): + mask_from_corr (bool): Use ACOM correlation signal for mask - corr_range (np.ndarray): + corr_range (np.ndarray): Range of correlation signals for mask - corr_normalize (bool): + corr_normalize (bool): Normalize correlation signal before masking Returns -------- - strain_map (RealSlice): + strain_map (RealSlice): strain tensor """ @@ -2170,21 +2170,20 @@ def calculate_strain( (p_ref.data["qx"][inds_match[keep]], p_ref.data["qy"][inds_match[keep]]) ).T - # Fit transformation matrix # Note - not sure about transpose here # (though it might not matter if rotation isn't included) if intensity_weighting: - weights = np.sqrt(p.data["intensity"][keep, None])*0+1 + weights = np.sqrt(p.data["intensity"][keep, None]) * 0 + 1 m = lstsq( - qxy_ref * weights, - qxy * weights, + qxy_ref * weights, + qxy * weights, rcond=None, )[0].T else: m = lstsq( - qxy_ref, - qxy, + qxy_ref, + qxy, rcond=None, )[0].T @@ -2193,16 +2192,18 @@ def calculate_strain( for a0 in range(5): # calculate new weights qxy_fit = qxy_ref @ m - diff2 = np.sum((qxy_fit - qxy)**2,axis=1) + diff2 = np.sum((qxy_fit - qxy) ** 2, axis=1) - weights = np.exp(diff2 / ((-2*robust_thresh**2)*np.median(diff2)))[:,None] + weights = np.exp( + diff2 / ((-2 * robust_thresh**2) * np.median(diff2)) + )[:, None] if intensity_weighting: weights *= np.sqrt(p.data["intensity"][keep, None]) # calculate new fits m = lstsq( - qxy_ref * weights, - qxy * weights, + qxy_ref * weights, + qxy * weights, rcond=None, )[0].T diff --git a/py4DSTEM/process/utils/single_atom_scatter.py b/py4DSTEM/process/utils/single_atom_scatter.py index 9085afd93..90397560f 100644 --- a/py4DSTEM/process/utils/single_atom_scatter.py +++ b/py4DSTEM/process/utils/single_atom_scatter.py @@ -2,6 +2,7 @@ import os from scipy.special import kn + class single_atom_scatter(object): """ This class calculates the composition averaged single atom scattering factor for a @@ -58,18 +59,17 @@ def projected_potential(self, Z, R): qe = 1.60217662e-19 # Permittivity of vacuum eps_0 = 8.85418782e-12 - # Bohr's constant + # Bohr's constant a_0 = 5.29177210903e-11 fe = np.zeros_like(R) for i in range(5): # fe += ai[i] * (2 + bi[i] * gsq) / (1 + bi[i] * gsq) ** 2 - pre = 2*np.pi/bi[i]**0.5 - fe += (ai[i] / bi[i]**1.5) * \ - (kn(0, pre * R) + R * kn(1, pre * R)) + pre = 2 * np.pi / bi[i] ** 0.5 + fe += (ai[i] / bi[i] ** 1.5) * (kn(0, pre * R) + R * kn(1, pre * R)) # kappa = (4*np.pi*eps_0) / (2*np.pi*a_0*me) - return fe * 2 * np.pi**2# / kappa + return fe * 2 * np.pi**2 # / kappa # if units == "VA": # return h**2 / (2 * np.pi * me * qe) * 1e18 * fe # elif units == "A": From 26508318ebc2b39610bab24f56db8767f7167203 Mon Sep 17 00:00:00 2001 From: Colin Date: Thu, 21 Mar 2024 11:27:45 -0700 Subject: [PATCH 15/64] Remove testing lines. --- py4DSTEM/process/diffraction/crystal_ACOM.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index a0f8468ed..c2bac6424 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -2123,8 +2123,6 @@ def calculate_strain( # Loop over all probe positions for rx, ry in tqdmnd( - # range(220,221), - # range(40,41), *bragg_peaks_array.shape, desc="Calculating strains", unit=" PointList", From 1e700803684514b5bb8589b32600016812b8d0af Mon Sep 17 00:00:00 2001 From: Colin Date: Thu, 21 Mar 2024 13:53:44 -0700 Subject: [PATCH 16/64] Trying (and failing) to figure out the potential units --- py4DSTEM/process/diffraction/crystal.py | 3 +++ py4DSTEM/process/utils/single_atom_scatter.py | 21 ++++++++++++------- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index dc50be154..97d596ec4 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -989,6 +989,9 @@ def generate_projected_potential( """ Generate an image of the projected potential of crystal in real space, using cell tiling, and a lookup table of the atomic potentials. + Note that we round atomic positions to the nearest pixel for speed. + + TODO - fix scattering prefactor so that output units are sensible. Parameters ---------- diff --git a/py4DSTEM/process/utils/single_atom_scatter.py b/py4DSTEM/process/utils/single_atom_scatter.py index 90397560f..54443b68f 100644 --- a/py4DSTEM/process/utils/single_atom_scatter.py +++ b/py4DSTEM/process/utils/single_atom_scatter.py @@ -57,6 +57,8 @@ def projected_potential(self, Z, R): me = 9.10938356e-31 # Electron charge in Coulomb qe = 1.60217662e-19 + # Electron charge in V-Angstroms + # qe = 14.4 # Permittivity of vacuum eps_0 = 8.85418782e-12 # Bohr's constant @@ -64,16 +66,21 @@ def projected_potential(self, Z, R): fe = np.zeros_like(R) for i in range(5): - # fe += ai[i] * (2 + bi[i] * gsq) / (1 + bi[i] * gsq) ** 2 pre = 2 * np.pi / bi[i] ** 0.5 fe += (ai[i] / bi[i] ** 1.5) * (kn(0, pre * R) + R * kn(1, pre * R)) - # kappa = (4*np.pi*eps_0) / (2*np.pi*a_0*me) - return fe * 2 * np.pi**2 # / kappa - # if units == "VA": - # return h**2 / (2 * np.pi * me * qe) * 1e18 * fe - # elif units == "A": - # return fe * 2 * np.pi**2 / kappa + # Scale output units + # kappa = (4*np.pi*eps_0) / (2*np.pi*a_0*qe) + # fe *= 2*np.pi**2 / kappa + # # # kappa = (4*np.pi*eps_0) / (2*np.pi*a_0*me) + + # # kappa = (4*np.pi*eps_0) / (2*np.pi*a_0*me) + # # return fe * 2 * np.pi**2 # / kappa + # # if units == "VA": + # return h**2 / (2 * np.pi * me * qe) * 1e18 * fe + # # elif units == "A": + # # return fe * 2 * np.pi**2 / kappa + return fe def get_scattering_factor( self, elements=None, composition=None, q_coords=None, units=None From c8a81f557d794f0083d9c09bcd8d3d0def52c74f Mon Sep 17 00:00:00 2001 From: Colin Date: Thu, 21 Mar 2024 13:54:16 -0700 Subject: [PATCH 17/64] Black formatting --- .../diffraction/WK_scattering_factors.py | 4 ++- py4DSTEM/process/diffraction/crystal_viz.py | 3 +- py4DSTEM/process/fit/fit.py | 3 +- .../magnetic_ptychographic_tomography.py | 28 +++++++++---------- .../process/phase/magnetic_ptychography.py | 28 +++++++++---------- py4DSTEM/process/phase/parallax.py | 24 ++++++++-------- .../process/phase/ptychographic_methods.py | 4 ++- .../process/phase/ptychographic_tomography.py | 28 +++++++++---------- py4DSTEM/process/phase/utils.py | 4 ++- py4DSTEM/process/utils/utils.py | 18 ++++++++++-- 10 files changed, 82 insertions(+), 62 deletions(-) diff --git a/py4DSTEM/process/diffraction/WK_scattering_factors.py b/py4DSTEM/process/diffraction/WK_scattering_factors.py index eb964de96..70110a977 100644 --- a/py4DSTEM/process/diffraction/WK_scattering_factors.py +++ b/py4DSTEM/process/diffraction/WK_scattering_factors.py @@ -221,7 +221,9 @@ def RI1(BI, BJ, G): ri1[sub] = np.pi * (BI * np.log((BI + BJ) / BI) + BJ * np.log((BI + BJ) / BJ)) sub = np.logical_and(eps <= 0.1, G > 0.0) - temp = 0.5 * BI**2 * np.log(BI / (BI + BJ)) + 0.5 * BJ**2 * np.log(BJ / (BI + BJ)) + temp = 0.5 * BI**2 * np.log(BI / (BI + BJ)) + 0.5 * BJ**2 * np.log( + BJ / (BI + BJ) + ) temp += 0.75 * (BI**2 + BJ**2) - 0.25 * (BI + BJ) ** 2 temp -= 0.5 * (BI - BJ) ** 2 ri1[sub] += np.pi * G[sub] ** 2 * temp diff --git a/py4DSTEM/process/diffraction/crystal_viz.py b/py4DSTEM/process/diffraction/crystal_viz.py index 47df2e6ca..da016b3ed 100644 --- a/py4DSTEM/process/diffraction/crystal_viz.py +++ b/py4DSTEM/process/diffraction/crystal_viz.py @@ -454,7 +454,8 @@ def plot_scattering_intensity( int_sf_plot = calc_1D_profile( k, self.g_vec_leng, - (self.struct_factors_int**int_power_scale) * (self.g_vec_leng**k_power_scale), + (self.struct_factors_int**int_power_scale) + * (self.g_vec_leng**k_power_scale), remove_origin=True, k_broadening=k_broadening, int_scale=int_scale, diff --git a/py4DSTEM/process/fit/fit.py b/py4DSTEM/process/fit/fit.py index 5c2d56a3c..9973ff79f 100644 --- a/py4DSTEM/process/fit/fit.py +++ b/py4DSTEM/process/fit/fit.py @@ -169,7 +169,8 @@ def polar_gaussian_2D( # t2 = np.min(np.vstack([t,1-t])) t2 = np.square(t - mu_t) return ( - I0 * np.exp(-(t2 / (2 * sigma_t**2) + (q - mu_q) ** 2 / (2 * sigma_q**2))) + C + I0 * np.exp(-(t2 / (2 * sigma_t**2) + (q - mu_q) ** 2 / (2 * sigma_q**2))) + + C ) diff --git a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py index c9efae806..0a58a514b 100644 --- a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py +++ b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py @@ -1188,20 +1188,20 @@ def reconstruct( # position correction if not fix_positions and a0 > 0: - self._positions_px_all[batch_indices] = ( - self._position_correction( - object_sliced, - vectorized_patch_indices_row, - vectorized_patch_indices_col, - shifted_probes, - overlap, - amplitudes_device, - positions_px, - positions_px_initial, - positions_step_size, - max_position_update_distance, - max_position_total_distance, - ) + self._positions_px_all[ + batch_indices + ] = self._position_correction( + object_sliced, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + overlap, + amplitudes_device, + positions_px, + positions_px_initial, + positions_step_size, + max_position_update_distance, + max_position_total_distance, ) measurement_error += batch_error diff --git a/py4DSTEM/process/phase/magnetic_ptychography.py b/py4DSTEM/process/phase/magnetic_ptychography.py index 2e887739f..67e315234 100644 --- a/py4DSTEM/process/phase/magnetic_ptychography.py +++ b/py4DSTEM/process/phase/magnetic_ptychography.py @@ -1490,20 +1490,20 @@ def reconstruct( # position correction if not fix_positions and a0 > 0: - self._positions_px_all[batch_indices] = ( - self._position_correction( - self._object, - vectorized_patch_indices_row, - vectorized_patch_indices_col, - shifted_probes, - overlap, - amplitudes_device, - positions_px, - positions_px_initial, - positions_step_size, - max_position_update_distance, - max_position_total_distance, - ) + self._positions_px_all[ + batch_indices + ] = self._position_correction( + self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + overlap, + amplitudes_device, + positions_px, + positions_px_initial, + positions_step_size, + max_position_update_distance, + max_position_total_distance, ) measurement_error += batch_error diff --git a/py4DSTEM/process/phase/parallax.py b/py4DSTEM/process/phase/parallax.py index aa9323900..79392f6e6 100644 --- a/py4DSTEM/process/phase/parallax.py +++ b/py4DSTEM/process/phase/parallax.py @@ -2195,16 +2195,16 @@ def score_CTF(coefs): measured_shifts_sx = xp.zeros( self._region_of_interest_shape, dtype=xp.float32 ) - measured_shifts_sx[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( - self._xy_shifts_Ang[:, 0] - ) + measured_shifts_sx[ + self._xy_inds[:, 0], self._xy_inds[:, 1] + ] = self._xy_shifts_Ang[:, 0] measured_shifts_sy = xp.zeros( self._region_of_interest_shape, dtype=xp.float32 ) - measured_shifts_sy[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( - self._xy_shifts_Ang[:, 1] - ) + measured_shifts_sy[ + self._xy_inds[:, 0], self._xy_inds[:, 1] + ] = self._xy_shifts_Ang[:, 1] fitted_shifts = ( xp.tensordot(gradients, xp.array(self._aberrations_coefs), axes=1) @@ -2215,16 +2215,16 @@ def score_CTF(coefs): fitted_shifts_sx = xp.zeros( self._region_of_interest_shape, dtype=xp.float32 ) - fitted_shifts_sx[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( - fitted_shifts[:, 0] - ) + fitted_shifts_sx[ + self._xy_inds[:, 0], self._xy_inds[:, 1] + ] = fitted_shifts[:, 0] fitted_shifts_sy = xp.zeros( self._region_of_interest_shape, dtype=xp.float32 ) - fitted_shifts_sy[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( - fitted_shifts[:, 1] - ) + fitted_shifts_sy[ + self._xy_inds[:, 0], self._xy_inds[:, 1] + ] = fitted_shifts[:, 1] max_shift = xp.max( xp.array( diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py index fa0b1db9f..cb109073e 100644 --- a/py4DSTEM/process/phase/ptychographic_methods.py +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -351,7 +351,9 @@ def _precompute_propagator_arrays( propagators[i] = xp.exp( 1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz) ) - propagators[i] *= xp.exp(1.0j * (-(ky**2)[None] * np.pi * wavelength * dz)) + propagators[i] *= xp.exp( + 1.0j * (-(ky**2)[None] * np.pi * wavelength * dz) + ) if theta_x is not None: propagators[i] *= xp.exp( diff --git a/py4DSTEM/process/phase/ptychographic_tomography.py b/py4DSTEM/process/phase/ptychographic_tomography.py index f3b2991ab..8b4d4d984 100644 --- a/py4DSTEM/process/phase/ptychographic_tomography.py +++ b/py4DSTEM/process/phase/ptychographic_tomography.py @@ -1080,20 +1080,20 @@ def reconstruct( # position correction if not fix_positions: - self._positions_px_all[batch_indices] = ( - self._position_correction( - object_sliced, - vectorized_patch_indices_row, - vectorized_patch_indices_col, - shifted_probes, - overlap, - amplitudes_device, - positions_px, - positions_px_initial, - positions_step_size, - max_position_update_distance, - max_position_total_distance, - ) + self._positions_px_all[ + batch_indices + ] = self._position_correction( + object_sliced, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + overlap, + amplitudes_device, + positions_px, + positions_px_initial, + positions_step_size, + max_position_update_distance, + max_position_total_distance, ) measurement_error += batch_error diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index a5a541795..f53637184 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -203,7 +203,9 @@ def evaluate_gaussian_envelope( self, alpha: Union[float, np.ndarray] ) -> Union[float, np.ndarray]: xp = self._xp - return xp.exp(-0.5 * self._gaussian_spread**2 * alpha**2 / self._wavelength**2) + return xp.exp( + -0.5 * self._gaussian_spread**2 * alpha**2 / self._wavelength**2 + ) def evaluate_spatial_envelope( self, alpha: Union[float, np.ndarray], phi: Union[float, np.ndarray] diff --git a/py4DSTEM/process/utils/utils.py b/py4DSTEM/process/utils/utils.py index ddeeb2c36..60da616d1 100644 --- a/py4DSTEM/process/utils/utils.py +++ b/py4DSTEM/process/utils/utils.py @@ -93,7 +93,12 @@ def electron_wavelength_angstrom(E_eV): c = 299792458 h = 6.62607 * 10**-34 - lam = h / ma.sqrt(2 * m * e * E_eV) / ma.sqrt(1 + e * E_eV / 2 / m / c**2) * 10**10 + lam = ( + h + / ma.sqrt(2 * m * e * E_eV) + / ma.sqrt(1 + e * E_eV / 2 / m / c**2) + * 10**10 + ) return lam @@ -102,8 +107,15 @@ def electron_interaction_parameter(E_eV): e = 1.602177 * 10**-19 c = 299792458 h = 6.62607 * 10**-34 - lam = h / ma.sqrt(2 * m * e * E_eV) / ma.sqrt(1 + e * E_eV / 2 / m / c**2) * 10**10 - sigma = (2 * np.pi / lam / E_eV) * (m * c**2 + e * E_eV) / (2 * m * c**2 + e * E_eV) + lam = ( + h + / ma.sqrt(2 * m * e * E_eV) + / ma.sqrt(1 + e * E_eV / 2 / m / c**2) + * 10**10 + ) + sigma = ( + (2 * np.pi / lam / E_eV) * (m * c**2 + e * E_eV) / (2 * m * c**2 + e * E_eV) + ) return sigma From e362b6b197cfd75c12da30a70d27e719704a4c0e Mon Sep 17 00:00:00 2001 From: Colin Date: Thu, 21 Mar 2024 14:03:20 -0700 Subject: [PATCH 18/64] Black again --- .../diffraction/WK_scattering_factors.py | 4 +-- py4DSTEM/process/diffraction/crystal_viz.py | 3 +- py4DSTEM/process/fit/fit.py | 3 +- .../magnetic_ptychographic_tomography.py | 28 +++++++++---------- .../process/phase/magnetic_ptychography.py | 28 +++++++++---------- py4DSTEM/process/phase/parallax.py | 24 ++++++++-------- .../process/phase/ptychographic_methods.py | 4 +-- .../process/phase/ptychographic_tomography.py | 28 +++++++++---------- py4DSTEM/process/phase/utils.py | 4 +-- py4DSTEM/process/utils/utils.py | 18 ++---------- 10 files changed, 62 insertions(+), 82 deletions(-) diff --git a/py4DSTEM/process/diffraction/WK_scattering_factors.py b/py4DSTEM/process/diffraction/WK_scattering_factors.py index 70110a977..eb964de96 100644 --- a/py4DSTEM/process/diffraction/WK_scattering_factors.py +++ b/py4DSTEM/process/diffraction/WK_scattering_factors.py @@ -221,9 +221,7 @@ def RI1(BI, BJ, G): ri1[sub] = np.pi * (BI * np.log((BI + BJ) / BI) + BJ * np.log((BI + BJ) / BJ)) sub = np.logical_and(eps <= 0.1, G > 0.0) - temp = 0.5 * BI**2 * np.log(BI / (BI + BJ)) + 0.5 * BJ**2 * np.log( - BJ / (BI + BJ) - ) + temp = 0.5 * BI**2 * np.log(BI / (BI + BJ)) + 0.5 * BJ**2 * np.log(BJ / (BI + BJ)) temp += 0.75 * (BI**2 + BJ**2) - 0.25 * (BI + BJ) ** 2 temp -= 0.5 * (BI - BJ) ** 2 ri1[sub] += np.pi * G[sub] ** 2 * temp diff --git a/py4DSTEM/process/diffraction/crystal_viz.py b/py4DSTEM/process/diffraction/crystal_viz.py index da016b3ed..47df2e6ca 100644 --- a/py4DSTEM/process/diffraction/crystal_viz.py +++ b/py4DSTEM/process/diffraction/crystal_viz.py @@ -454,8 +454,7 @@ def plot_scattering_intensity( int_sf_plot = calc_1D_profile( k, self.g_vec_leng, - (self.struct_factors_int**int_power_scale) - * (self.g_vec_leng**k_power_scale), + (self.struct_factors_int**int_power_scale) * (self.g_vec_leng**k_power_scale), remove_origin=True, k_broadening=k_broadening, int_scale=int_scale, diff --git a/py4DSTEM/process/fit/fit.py b/py4DSTEM/process/fit/fit.py index 9973ff79f..5c2d56a3c 100644 --- a/py4DSTEM/process/fit/fit.py +++ b/py4DSTEM/process/fit/fit.py @@ -169,8 +169,7 @@ def polar_gaussian_2D( # t2 = np.min(np.vstack([t,1-t])) t2 = np.square(t - mu_t) return ( - I0 * np.exp(-(t2 / (2 * sigma_t**2) + (q - mu_q) ** 2 / (2 * sigma_q**2))) - + C + I0 * np.exp(-(t2 / (2 * sigma_t**2) + (q - mu_q) ** 2 / (2 * sigma_q**2))) + C ) diff --git a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py index 0a58a514b..c9efae806 100644 --- a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py +++ b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py @@ -1188,20 +1188,20 @@ def reconstruct( # position correction if not fix_positions and a0 > 0: - self._positions_px_all[ - batch_indices - ] = self._position_correction( - object_sliced, - vectorized_patch_indices_row, - vectorized_patch_indices_col, - shifted_probes, - overlap, - amplitudes_device, - positions_px, - positions_px_initial, - positions_step_size, - max_position_update_distance, - max_position_total_distance, + self._positions_px_all[batch_indices] = ( + self._position_correction( + object_sliced, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + overlap, + amplitudes_device, + positions_px, + positions_px_initial, + positions_step_size, + max_position_update_distance, + max_position_total_distance, + ) ) measurement_error += batch_error diff --git a/py4DSTEM/process/phase/magnetic_ptychography.py b/py4DSTEM/process/phase/magnetic_ptychography.py index 67e315234..2e887739f 100644 --- a/py4DSTEM/process/phase/magnetic_ptychography.py +++ b/py4DSTEM/process/phase/magnetic_ptychography.py @@ -1490,20 +1490,20 @@ def reconstruct( # position correction if not fix_positions and a0 > 0: - self._positions_px_all[ - batch_indices - ] = self._position_correction( - self._object, - vectorized_patch_indices_row, - vectorized_patch_indices_col, - shifted_probes, - overlap, - amplitudes_device, - positions_px, - positions_px_initial, - positions_step_size, - max_position_update_distance, - max_position_total_distance, + self._positions_px_all[batch_indices] = ( + self._position_correction( + self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + overlap, + amplitudes_device, + positions_px, + positions_px_initial, + positions_step_size, + max_position_update_distance, + max_position_total_distance, + ) ) measurement_error += batch_error diff --git a/py4DSTEM/process/phase/parallax.py b/py4DSTEM/process/phase/parallax.py index d366a01df..060a151aa 100644 --- a/py4DSTEM/process/phase/parallax.py +++ b/py4DSTEM/process/phase/parallax.py @@ -2195,16 +2195,16 @@ def score_CTF(coefs): measured_shifts_sx = xp.zeros( self._region_of_interest_shape, dtype=xp.float32 ) - measured_shifts_sx[ - self._xy_inds[:, 0], self._xy_inds[:, 1] - ] = self._xy_shifts_Ang[:, 0] + measured_shifts_sx[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( + self._xy_shifts_Ang[:, 0] + ) measured_shifts_sy = xp.zeros( self._region_of_interest_shape, dtype=xp.float32 ) - measured_shifts_sy[ - self._xy_inds[:, 0], self._xy_inds[:, 1] - ] = self._xy_shifts_Ang[:, 1] + measured_shifts_sy[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( + self._xy_shifts_Ang[:, 1] + ) fitted_shifts = ( xp.tensordot(gradients, xp.array(self._aberrations_coefs), axes=1) @@ -2215,16 +2215,16 @@ def score_CTF(coefs): fitted_shifts_sx = xp.zeros( self._region_of_interest_shape, dtype=xp.float32 ) - fitted_shifts_sx[ - self._xy_inds[:, 0], self._xy_inds[:, 1] - ] = fitted_shifts[:, 0] + fitted_shifts_sx[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( + fitted_shifts[:, 0] + ) fitted_shifts_sy = xp.zeros( self._region_of_interest_shape, dtype=xp.float32 ) - fitted_shifts_sy[ - self._xy_inds[:, 0], self._xy_inds[:, 1] - ] = fitted_shifts[:, 1] + fitted_shifts_sy[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( + fitted_shifts[:, 1] + ) max_shift = xp.max( xp.array( diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py index cb109073e..fa0b1db9f 100644 --- a/py4DSTEM/process/phase/ptychographic_methods.py +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -351,9 +351,7 @@ def _precompute_propagator_arrays( propagators[i] = xp.exp( 1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz) ) - propagators[i] *= xp.exp( - 1.0j * (-(ky**2)[None] * np.pi * wavelength * dz) - ) + propagators[i] *= xp.exp(1.0j * (-(ky**2)[None] * np.pi * wavelength * dz)) if theta_x is not None: propagators[i] *= xp.exp( diff --git a/py4DSTEM/process/phase/ptychographic_tomography.py b/py4DSTEM/process/phase/ptychographic_tomography.py index 8b4d4d984..f3b2991ab 100644 --- a/py4DSTEM/process/phase/ptychographic_tomography.py +++ b/py4DSTEM/process/phase/ptychographic_tomography.py @@ -1080,20 +1080,20 @@ def reconstruct( # position correction if not fix_positions: - self._positions_px_all[ - batch_indices - ] = self._position_correction( - object_sliced, - vectorized_patch_indices_row, - vectorized_patch_indices_col, - shifted_probes, - overlap, - amplitudes_device, - positions_px, - positions_px_initial, - positions_step_size, - max_position_update_distance, - max_position_total_distance, + self._positions_px_all[batch_indices] = ( + self._position_correction( + object_sliced, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + overlap, + amplitudes_device, + positions_px, + positions_px_initial, + positions_step_size, + max_position_update_distance, + max_position_total_distance, + ) ) measurement_error += batch_error diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index f53637184..a5a541795 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -203,9 +203,7 @@ def evaluate_gaussian_envelope( self, alpha: Union[float, np.ndarray] ) -> Union[float, np.ndarray]: xp = self._xp - return xp.exp( - -0.5 * self._gaussian_spread**2 * alpha**2 / self._wavelength**2 - ) + return xp.exp(-0.5 * self._gaussian_spread**2 * alpha**2 / self._wavelength**2) def evaluate_spatial_envelope( self, alpha: Union[float, np.ndarray], phi: Union[float, np.ndarray] diff --git a/py4DSTEM/process/utils/utils.py b/py4DSTEM/process/utils/utils.py index 60da616d1..ddeeb2c36 100644 --- a/py4DSTEM/process/utils/utils.py +++ b/py4DSTEM/process/utils/utils.py @@ -93,12 +93,7 @@ def electron_wavelength_angstrom(E_eV): c = 299792458 h = 6.62607 * 10**-34 - lam = ( - h - / ma.sqrt(2 * m * e * E_eV) - / ma.sqrt(1 + e * E_eV / 2 / m / c**2) - * 10**10 - ) + lam = h / ma.sqrt(2 * m * e * E_eV) / ma.sqrt(1 + e * E_eV / 2 / m / c**2) * 10**10 return lam @@ -107,15 +102,8 @@ def electron_interaction_parameter(E_eV): e = 1.602177 * 10**-19 c = 299792458 h = 6.62607 * 10**-34 - lam = ( - h - / ma.sqrt(2 * m * e * E_eV) - / ma.sqrt(1 + e * E_eV / 2 / m / c**2) - * 10**10 - ) - sigma = ( - (2 * np.pi / lam / E_eV) * (m * c**2 + e * E_eV) / (2 * m * c**2 + e * E_eV) - ) + lam = h / ma.sqrt(2 * m * e * E_eV) / ma.sqrt(1 + e * E_eV / 2 / m / c**2) * 10**10 + sigma = (2 * np.pi / lam / E_eV) * (m * c**2 + e * E_eV) / (2 * m * c**2 + e * E_eV) return sigma From 0d772ad91a5fbfbe5ecc4e512b1a3e829288ceaf Mon Sep 17 00:00:00 2001 From: Colin Date: Mon, 8 Apr 2024 21:21:28 -0700 Subject: [PATCH 19/64] updating phase mapping --- py4DSTEM/process/diffraction/crystal_phase.py | 1353 ++++++++++++++--- 1 file changed, 1122 insertions(+), 231 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_phase.py b/py4DSTEM/process/diffraction/crystal_phase.py index d28616aa9..508a01c79 100644 --- a/py4DSTEM/process/diffraction/crystal_phase.py +++ b/py4DSTEM/process/diffraction/crystal_phase.py @@ -3,12 +3,14 @@ from scipy.optimize import nnls import matplotlib as mpl import matplotlib.pyplot as plt +from scipy.ndimage import gaussian_filter from emdfile import tqdmnd, PointListArray from py4DSTEM.visualize import show, show_image_grid from py4DSTEM.process.diffraction.crystal_viz import plot_diffraction_pattern + class Crystal_Phase: """ A class storing multiple crystal structures, and associated diffraction data. @@ -19,8 +21,9 @@ class Crystal_Phase: def __init__( self, crystals, - orientation_maps, - name, + crystal_names = None, + orientation_maps = None, + name = None, ): """ Args: @@ -33,76 +36,964 @@ def __init__( self.num_crystals = len(crystals) else: raise TypeError("crystals must be a list of crystal instances.") - if isinstance(orientation_maps, list): + + # List of orientation maps + if orientation_maps is None: + self.orientation_maps = [crystals[ind].orientation_map for ind in range(self.num_crystals)] + else: if len(self.crystals) != len(orientation_maps): raise ValueError( "Orientation maps must have the same number of entries as crystals." ) self.orientation_maps = orientation_maps + + # Names of all crystal phases + if crystal_names is None: + self.crystal_names = ['crystal' + str(ind) for ind in range(self.num_crystals)] else: - raise TypeError("orientation_maps must be a list of orientation maps.") - self.name = name - return + self.crystal_names = crystal_names + + # Name of the phase map + if name is None: + self.name = 'phase map' + else: + self.name = name + + # Get some attributes from crystals + self.k_max = np.zeros(self.num_crystals) + self.num_matches = np.zeros(self.num_crystals, dtype='int') + self.crystal_identity = np.zeros((0,2), dtype='int') + for a0 in range(self.num_crystals): + self.k_max[a0] = self.crystals[a0].k_max + self.num_matches[a0] = self.crystals[a0].orientation_map.num_matches + for a1 in range(self.num_matches[a0]): + self.crystal_identity = np.append(self.crystal_identity,np.array((a0,a1),dtype='int')[None,:], axis=0) - def plot_all_phase_maps(self, map_scale_values=None, index=0): + self.num_fits = np.sum(self.num_matches) + + + + def quantify_single_pattern( + self, + pointlistarray: PointListArray, + xy_position = (0,0), + corr_kernel_size = 0.04, + allow_strain = True, + # corr_distance_scale = 1.0, + include_false_positives = True, + weight_false_positives = 1.0, + max_number_phases = 1, + sigma_excitation_error = 0.04, + power_experiment = 0.25, + power_calculated = 0.25, + plot_result = True, + plot_only_nonzero_phases = True, + plot_unmatched_peaks = False, + plot_correlation_radius = False, + scale_markers_experiment = 40, + scale_markers_calculated = 500, + crystal_inds_plot = None, + phase_colors = None, + figsize = (10,7), + verbose = True, + returnfig = False, + ): + """ + Quantify the phase for a single diffraction pattern. """ - Visualize phase maps of dataset. - Args: - map_scale_values (float): Value to scale correlations by + # tolerance + tol2 = 4e-4 + + # calibrations + center = pointlistarray.calstate['center'] + ellipse = pointlistarray.calstate['ellipse'] + pixel = pointlistarray.calstate['pixel'] + rotate = pointlistarray.calstate['rotate'] + if center is False: + raise ValueError('Bragg peaks must be center calibration') + if pixel is False: + raise ValueError('Bragg peaks must have pixel size calibration') + # TODO - potentially warn the user if ellipse / rotate calibration not available + + if phase_colors is None: + phase_colors = np.array(( + (1.0,0.0,0.0,1.0), + (0.0,0.8,1.0,1.0), + (0.0,0.6,0.0,1.0), + (1.0,0.0,1.0,1.0), + (0.0,0.2,1.0,1.0), + (1.0,0.8,0.0,1.0), + )) + + # Experimental values + bragg_peaks = pointlistarray.get_vectors( + xy_position[0], + xy_position[1], + center = center, + ellipse = ellipse, + pixel = pixel, + rotate = rotate) + # bragg_peaks = pointlistarray.get_pointlist(xy_position[0],xy_position[1]).copy() + keep = bragg_peaks.data["qx"]**2 + bragg_peaks.data["qy"]**2 > tol2 + # ind_center_beam = np.argmin( + # bragg_peaks.data["qx"]**2 + bragg_peaks.data["qy"]**2) + # mask = np.ones_like(bragg_peaks.data["qx"], dtype='bool') + # mask[ind_center_beam] = False + # bragg_peaks.remove(ind_center_beam) + qx = bragg_peaks.data["qx"][keep] + qy = bragg_peaks.data["qy"][keep] + qx0 = bragg_peaks.data["qx"][np.logical_not(keep)] + qy0 = bragg_peaks.data["qy"][np.logical_not(keep)] + if power_experiment == 0: + intensity = np.ones_like(qx) + intensity0 = np.ones_like(qx0) + else: + intensity = bragg_peaks.data["intensity"][keep]**power_experiment + intensity0 = bragg_peaks.data["intensity"][np.logical_not(keep)]**power_experiment + int_total = np.sum(intensity) + + # init basis array + if include_false_positives: + basis = np.zeros((intensity.shape[0], self.num_fits)) + unpaired_peaks = [] + else: + basis = np.zeros((intensity.shape[0], self.num_fits)) + if allow_strain: + m_strains = np.zeros((self.num_fits,2,2)) + + # kernel radius squared + radius_max_2 = corr_kernel_size**2 + + # init for plotting + if plot_result: + library_peaks = [] + library_int = [] + library_matches = [] + + # Generate point list data, match to experimental peaks + for a0 in range(self.num_fits): + c = self.crystal_identity[a0,0] + m = self.crystal_identity[a0,1] + # for c in range(self.num_crystals): + # for m in range(self.num_matches[c]): + # ind_match += 1 + + # Generate simulated peaks + bragg_peaks_fit = self.crystals[c].generate_diffraction_pattern( + self.crystals[c].orientation_map.get_orientation( + xy_position[0], xy_position[1] + ), + ind_orientation = m, + sigma_excitation_error = sigma_excitation_error, + ) + del_peak = bragg_peaks_fit.data["qx"]**2 \ + + bragg_peaks_fit.data["qy"]**2 < tol2 + bragg_peaks_fit.remove(del_peak) + + # peak intensities + if power_calculated == 0: + int_fit = np.ones_like(bragg_peaks_fit.data["qx"]) + else: + int_fit = bragg_peaks_fit.data['intensity']**power_calculated + + # Pair peaks to experiment + if plot_result: + matches = np.zeros((bragg_peaks_fit.data.shape[0]),dtype='bool') + + if allow_strain: + # Initial peak pairing to find best-fit strain distortion + pair_sub = np.zeros(bragg_peaks_fit.data.shape[0],dtype='bool') + pair_inds = np.zeros(bragg_peaks_fit.data.shape[0],dtype='int') + for a1 in range(bragg_peaks_fit.data.shape[0]): + dist2 = (bragg_peaks_fit.data['qx'][a1] - qx)**2 \ + + (bragg_peaks_fit.data['qy'][a1] - qy)**2 + ind_min = np.argmin(dist2) + val_min = dist2[ind_min] + + if val_min < radius_max_2: + pair_sub[a1] = True + pair_inds[a1] = ind_min + + # calculate best-fit strain tensor, weighted by the intensities. + # requires at least 4 peak pairs + if np.sum(pair_sub) >= 4: + pair_basis = np.vstack(( + qx[pair_inds[pair_sub]], + qy[pair_inds[pair_sub]], + )).T + pair_obs = np.vstack(( + bragg_peaks_fit.data['qx'][pair_sub], + bragg_peaks_fit.data['qy'][pair_sub], + )).T + + # weights + dists = np.sqrt( + (bragg_peaks_fit.data['qx'][pair_sub] - qx[pair_inds[pair_sub]])**2 + \ + (bragg_peaks_fit.data['qx'][pair_sub] - qx[pair_inds[pair_sub]])**2) + weights = np.sqrt( + int_fit[pair_sub] * intensity[pair_inds[pair_sub]] + ) * (1 - dists / corr_kernel_size) + + # wtrain tensor + m_strain = np.linalg.lstsq( + pair_basis * weights[:,None], + pair_obs * weights[:,None], + rcond = None, + )[0] + m_strains[a0] = m_strain + + # Transformed peak positions + qx_copy = qx.copy() + qy_copy = qy.copy() + qx = qx_copy*m_strain[0,0] + qy_copy*m_strain[1,0] + qy = qx_copy*m_strain[0,1] + qy_copy*m_strain[1,1] + + # dist = np.mean(np.sqrt((bragg_peaks_fit.data['qx'][pair_sub] - qx[pair_inds[pair_sub]])**2 + \ + # (bragg_peaks_fit.data['qx'][pair_sub] - qx[pair_inds[pair_sub]])**2)) + # print(np.round(dist,4)) + + # Loop over all peaks, pair experiment to library + for a1 in range(bragg_peaks_fit.data.shape[0]): + dist2 = (bragg_peaks_fit.data['qx'][a1] - qx)**2 \ + + (bragg_peaks_fit.data['qy'][a1] - qy)**2 + ind_min = np.argmin(dist2) + val_min = dist2[ind_min] + + if val_min < radius_max_2: + # weight = 1 - np.sqrt(dist2[ind_min]) / corr_kernel_size + # weight = 1 + corr_distance_scale * \ + # np.sqrt(dist2[ind_min]) / corr_kernel_size + # basis[ind_min,a0] = weight * int_fit[a1] + basis[ind_min,a0] = int_fit[a1] + if plot_result: + matches[a1] = True + elif include_false_positives: + # unpaired_peaks.append([a0,int_fit[a1]*(1 + corr_distance_scale)]) + unpaired_peaks.append([a0,int_fit[a1]]) + + if plot_result: + library_peaks.append(bragg_peaks_fit) + library_int.append(int_fit) + library_matches.append(matches) + + # If needed, augment basis and observations with false positives + if include_false_positives: + basis_aug = np.zeros((len(unpaired_peaks),self.num_fits)) + for a0 in range(len(unpaired_peaks)): + basis_aug[a0,unpaired_peaks[a0][0]] = unpaired_peaks[a0][1] + + basis = np.vstack((basis, basis_aug * weight_false_positives)) + obs = np.hstack((intensity, np.zeros(len(unpaired_peaks)))) + + else: + obs = intensity + + # Solve for phase coefficients + try: + phase_weights = np.zeros(self.num_fits) + inds_solve = np.ones(self.num_fits,dtype='bool') + + search = True + while search is True: + phase_weights_cand, phase_residual_cand = nnls( + basis[:,inds_solve], + obs, + ) + + if np.count_nonzero(phase_weights_cand > 0.0) <= max_number_phases: + phase_weights[inds_solve] = phase_weights_cand + phase_residual = phase_residual_cand + search = False + else: + inds = np.where(inds_solve)[0] + inds_solve[inds[np.argmin(phase_weights_cand)]] = False + + except: + phase_weights = np.zeros(self.num_fits) + phase_residual = np.sqrt(np.sum(intensity**2)) + + if verbose: + ind_max = np.argmax(phase_weights) + # print() + print('\033[1m' + 'phase_weight or_ind name' + '\033[0m') + # print() + for a0 in range(self.num_fits): + c = self.crystal_identity[a0,0] + m = self.crystal_identity[a0,1] + line = '{:>12} {:>8} {:<12}'.format( + np.round(phase_weights[a0],decimals=2), + m, + self.crystal_names[c] + ) + if a0 == ind_max: + print('\033[1m' + line + '\033[0m') + else: + print(line) + # print() + + # Plotting + if plot_result: + # fig, ax = plt.subplots(figsize=figsize) + fig = plt.figure(figsize=figsize) + # if plot_layout == 0: + # ax_x = fig.add_axes( + # [0.0+figbound[0], 0.0, 0.4-2*+figbound[0], 1.0]) + ax = fig.add_axes([0.0, 0.0, 0.66, 1.0]) + ax_leg = fig.add_axes([0.68, 0.0, 0.3, 1.0]) + + if plot_correlation_radius: + # plot the experimental radii + t = np.linspace(0,2*np.pi,91,endpoint=True) + ct = np.cos(t) * corr_kernel_size + st = np.sin(t) * corr_kernel_size + for a0 in range(qx.shape[0]): + ax.plot( + qy[a0] + st, + qx[a0] + ct, + color = 'k', + linewidth = 1, + ) + + # plot the experimental peaks + ax.scatter( + qy0, + qx0, + s = scale_markers_experiment * intensity0, + marker = "o", + facecolor = [0.7, 0.7, 0.7], + ) + ax.scatter( + qy, + qx, + s = scale_markers_experiment * intensity, + marker = "o", + facecolor = [0.7, 0.7, 0.7], + ) + # legend + k_max = np.max(self.k_max) + dx_leg = -0.05*k_max + dy_leg = 0.04*k_max + text_params = { + "va": "center", + "ha": "left", + "family": "sans-serif", + "fontweight": "normal", + "color": "k", + "size": 14, + } + if plot_correlation_radius: + ax_leg.plot( + 0 + st*0.5, + -dx_leg + ct*0.5, + color = 'k', + linewidth = 1, + ) + ax_leg.scatter( + 0, + 0, + s = 200, + marker = "o", + facecolor = [0.7, 0.7, 0.7], + ) + ax_leg.text( + dy_leg, + 0, + 'Experimental peaks', + **text_params) + if plot_correlation_radius: + ax_leg.text( + dy_leg, + -dx_leg, + 'Correlation radius', + **text_params) + + + # plot calculated diffraction patterns + uvals = phase_colors.copy() + uvals[:,3] = 0.3 + # uvals = np.array(( + # (1.0,0.0,0.0,0.2), + # (0.0,0.8,1.0,0.2), + # (0.0,0.6,0.0,0.2), + # (1.0,0.0,1.0,0.2), + # (0.0,0.2,1.0,0.2), + # (1.0,0.8,0.0,0.2), + # )) + mvals = ['v','^','<','>','d','s',] + + for a0 in range(self.num_fits): + c = self.crystal_identity[a0,0] + m = self.crystal_identity[a0,1] + + if crystal_inds_plot == None or np.min(np.abs(c - crystal_inds_plot)) == 0: + + qx_fit = library_peaks[a0].data['qx'] + qy_fit = library_peaks[a0].data['qy'] + + # if allow_strain: + # m_strain = m_strains[a0] + # # Transformed peak positions + # qx_copy = qx_fit.copy() + # qy_copy = qy_fit.copy() + # qx_fit = qx_copy*m_strain[0,0] + qy_copy*m_strain[1,0] + # qy_fit = qx_copy*m_strain[0,1] + qy_copy*m_strain[1,1] + + int_fit = library_int[a0] + matches_fit = library_matches[a0] + + if plot_only_nonzero_phases is False or phase_weights[a0] > 0: + + if np.mod(m,2) == 0: + ax.scatter( + qy_fit[matches_fit], + qx_fit[matches_fit], + s = scale_markers_calculated * int_fit[matches_fit], + marker = mvals[c], + facecolor = phase_colors[c,:], + ) + if plot_unmatched_peaks: + ax.scatter( + qy_fit[np.logical_not(matches_fit)], + qx_fit[np.logical_not(matches_fit)], + s = scale_markers_calculated * int_fit[np.logical_not(matches_fit)], + marker = mvals[c], + facecolor = phase_colors[c,:], + ) + + # legend + ax_leg.scatter( + 0, + dx_leg*(a0+1), + s = 200, + marker = mvals[c], + facecolor = phase_colors[c,:], + ) + else: + ax.scatter( + qy_fit[matches_fit], + qx_fit[matches_fit], + s = scale_markers_calculated * int_fit[matches_fit], + marker = mvals[c], + edgecolors = uvals[c,:], + facecolors = (uvals[c,0],uvals[c,1],uvals[c,2],0.3), + # facecolors = (1,1,1,0.5), + linewidth = 2, + ) + if plot_unmatched_peaks: + ax.scatter( + qy_fit[np.logical_not(matches_fit)], + qx_fit[np.logical_not(matches_fit)], + s = scale_markers_calculated * int_fit[np.logical_not(matches_fit)], + marker = mvals[c], + edgecolors = uvals[c,:], + facecolors = (1,1,1,0.5), + linewidth = 2, + ) + + # legend + ax_leg.scatter( + 0, + dx_leg*(a0+1), + s = 200, + marker = mvals[c], + edgecolors = uvals[c,:], + facecolors = (uvals[c,0],uvals[c,1],uvals[c,2],0.3), + # facecolors = (1,1,1,0.5), + ) + + # legend text + ax_leg.text( + dy_leg, + (a0+1)*dx_leg, + self.crystal_names[c], + **text_params) + + + # appearance + ax.set_xlim((-k_max, k_max)) + ax.set_ylim((-k_max, k_max)) + + ax_leg.set_xlim((-0.1*k_max, 0.4*k_max)) + ax_leg.set_ylim((-0.5*k_max, 0.5*k_max)) + ax_leg.set_axis_off() + + if returnfig: + return phase_weights, phase_residual, int_total, fig, ax + else: + return phase_weights, phase_residual, int_total + + def quantify_phase( + self, + pointlistarray: PointListArray, + corr_kernel_size = 0.04, + # corr_distance_scale = 1.0, + allow_strain = True, + include_false_positives = True, + weight_false_positives = 1.0, + max_number_phases = 2, + sigma_excitation_error = 0.02, + power_experiment = 0.25, + power_calculated = 0.25, + progress_bar = True, + ): """ - phase_maps = [] - if map_scale_values is None: - map_scale_values = [1] * len(self.orientation_maps) - corr_sum = np.sum( - [ - (self.orientation_maps[m].corr[:, :, index] * map_scale_values[m]) - for m in range(len(self.orientation_maps)) - ] - ) - for m in range(len(self.orientation_maps)): - phase_maps.append(self.orientation_maps[m].corr[:, :, index] / corr_sum) - show_image_grid(lambda i: phase_maps[i], 1, len(phase_maps), cmap="inferno") - return - - def plot_phase_map(self, index=0, cmap=None): - corr_array = np.dstack( - [maps.corr[:, :, index] for maps in self.orientation_maps] + Quantify phase of all diffraction patterns. + """ + + # init results arrays + self.phase_weights = np.zeros(( + pointlistarray.shape[0], + pointlistarray.shape[1], + self.num_fits, + )) + self.phase_residuals = np.zeros(( + pointlistarray.shape[0], + pointlistarray.shape[1], + )) + self.int_total = np.zeros(( + pointlistarray.shape[0], + pointlistarray.shape[1], + )) + + for rx, ry in tqdmnd( + *pointlistarray.shape, + desc="Matching Orientations", + unit=" PointList", + disable=not progress_bar, + ): + # calculate phase weights + phase_weights, phase_residual, int_peaks = self.quantify_single_pattern( + pointlistarray = pointlistarray, + xy_position = (rx,ry), + corr_kernel_size = corr_kernel_size, + allow_strain = allow_strain, + # corr_distance_scale = corr_distance_scale, + include_false_positives = include_false_positives, + weight_false_positives = weight_false_positives, + max_number_phases = max_number_phases, + sigma_excitation_error = sigma_excitation_error, + power_experiment = power_experiment, + power_calculated = power_calculated, + plot_result = False, + verbose = False, + returnfig = False, + ) + self.phase_weights[rx,ry] = phase_weights + self.phase_residuals[rx,ry] = phase_residual + self.int_total[rx,ry] = int_peaks + + + def plot_phase_weights( + self, + weight_range = (0.0,1.0), + weight_normalize = False, + total_intensity_normalize = True, + cmap = 'gray', + show_ticks = False, + show_axes = True, + layout = 0, + figsize = (6,6), + returnfig = False, + ): + """ + Plot the individual phase weight maps and residuals. + """ + + # Normalization if required to total DF peak intensity + phase_weights = self.phase_weights.copy() + phase_residuals = self.phase_residuals.copy() + if total_intensity_normalize: + sub = self.int_total > 0.0 + for a0 in range(self.num_fits): + phase_weights[:,:,a0][sub] /= self.int_total[sub] + phase_residuals[sub] /= self.int_total[sub] + + # intensity range for plotting + if weight_normalize: + scale = np.median(np.max(phase_weights,axis=2)) + else: + scale = 1 + weight_range = np.array(weight_range) * scale + + # plotting + if layout == 0: + fig,ax = plt.subplots( + 1, + self.num_crystals + 1, + figsize=(figsize[0],(self.num_fits+1)*figsize[1])) + elif layout == 1: + fig,ax = plt.subplots( + self.num_crystals + 1, + 1, + figsize=(figsize[0],(self.num_fits+1)*figsize[1])) + + for a0 in range(self.num_crystals): + sub = self.crystal_identity[:,0] == a0 + im = np.sum(phase_weights[:,:,sub],axis=2) + im = np.clip( + (im - weight_range[0]) / (weight_range[1] - weight_range[0]), + 0,1) + ax[a0].imshow( + im, + vmin = 0, + vmax = 1, + cmap = cmap, + ) + ax[a0].set_title( + self.crystal_names[a0], + fontsize = 16, + ) + if not show_ticks: + ax[a0].set_xticks([]) + ax[a0].set_yticks([]) + if not show_axes: + ax[a0].set_axis_off() + + # plot residuals + im = np.clip( + (phase_residuals - weight_range[0]) \ + / (weight_range[1] - weight_range[0]), + 0,1) + ax[self.num_crystals].imshow( + im, + vmin = 0, + vmax = 1, + cmap = cmap, ) - best_corr_score = np.max(corr_array, axis=2) - best_match_phase = [ - np.where(corr_array[:, :, p] == best_corr_score, True, False) - for p in range(len(self.orientation_maps)) - ] - - if cmap is None: - cm = plt.get_cmap("rainbow") - cmap = [ - cm(1.0 * i / len(self.orientation_maps)) - for i in range(len(self.orientation_maps)) - ] - - fig, (ax) = plt.subplots(figsize=(6, 6)) - ax.matshow( - np.zeros((self.orientation_maps[0].num_x, self.orientation_maps[0].num_y)), - cmap="gray", + ax[self.num_crystals].set_title( + 'Residuals', + fontsize = 16, ) - ax.axis("off") - - for m in range(len(self.orientation_maps)): - c0, c1 = (cmap[m][0] * 0.35, cmap[m][1] * 0.35, cmap[m][2] * 0.35, 1), cmap[ - m - ] - cm = mpl.colors.LinearSegmentedColormap.from_list("cmap", [c0, c1], N=10) - ax.matshow( - np.ma.array( - self.orientation_maps[m].corr[:, :, index], mask=best_match_phase[m] - ), - cmap=cm, + if not show_ticks: + ax[self.num_crystals].set_xticks([]) + ax[self.num_crystals].set_yticks([]) + if not show_axes: + ax[self.num_crystals].set_axis_off() + + if returnfig: + return fig, ax + + + def plot_phase_maps( + self, + weight_threshold = 0.5, + weight_normalize = True, + total_intensity_normalize = True, + plot_combine = False, + crystal_inds_plot = None, + phase_colors = None, + show_ticks = False, + show_axes = True, + layout = 0, + figsize = (6,6), + return_phase_estimate = False, + return_rgb_images = False, + returnfig = False, + ): + """ + Plot the individual phase weight maps and residuals. + """ + + if phase_colors is None: + phase_colors = np.array(( + (1.0,0.0,0.0), + (0.0,0.8,1.0), + (0.0,0.8,0.0), + (1.0,0.0,1.0), + (0.0,0.4,1.0), + (1.0,0.8,0.0), + )) + + phase_weights = self.phase_weights.copy() + if total_intensity_normalize: + sub = self.int_total > 0.0 + for a0 in range(self.num_fits): + phase_weights[:,:,a0][sub] /= self.int_total[sub] + + # intensity range for plotting + if weight_normalize: + scale = np.median(np.max(phase_weights,axis=2)) + else: + scale = 1 + weight_threshold = weight_threshold * scale + + # init + im_all = np.zeros(( + self.num_crystals, + self.phase_weights.shape[0], + self.phase_weights.shape[1])) + im_rgb_all = np.zeros(( + self.num_crystals, + self.phase_weights.shape[0], + self.phase_weights.shape[1], + 3)) + + # phase weights over threshold + for a0 in range(self.num_crystals): + sub = self.crystal_identity[:,0] == a0 + im = np.sum(phase_weights[:,:,sub],axis=2) + im_all[a0] = np.maximum(im - weight_threshold, 0) + + # estimate compositions + im_sum = np.sum(im_all, axis = 0) + sub = im_sum > 0.0 + for a0 in range(self.num_crystals): + im_all[a0][sub] /= im_sum[sub] + + for a1 in range(3): + im_rgb_all[a0,:,:,a1] = im_all[a0] * phase_colors[a0,a1] + + if plot_combine: + if crystal_inds_plot is None: + im_rgb = np.sum(im_rgb_all, axis = 0) + else: + im_rgb = np.sum(im_rgb_all[np.array(crystal_inds_plot)], axis = 0) + + im_rgb = np.clip(im_rgb,0,1) + + fig,ax = plt.subplots(1,1,figsize=figsize) + ax.imshow( + im_rgb, + ) + ax.set_title( + 'Phase Maps', + fontsize = 16, ) - plt.show() + if not show_ticks: + ax.set_xticks([]) + ax.set_yticks([]) + if not show_axes: + ax.set_axis_off() + + else: + # plotting + if layout == 0: + fig,ax = plt.subplots( + 1, + self.num_crystals, + figsize=(figsize[0],(self.num_fits+1)*figsize[1])) + elif layout == 1: + fig,ax = plt.subplots( + self.num_crystals, + 1, + figsize=(figsize[0],(self.num_fits+1)*figsize[1])) + + for a0 in range(self.num_crystals): + + ax[a0].imshow( + im_rgb_all[a0], + ) + ax[a0].set_title( + self.crystal_names[a0], + fontsize = 16, + ) + if not show_ticks: + ax[a0].set_xticks([]) + ax[a0].set_yticks([]) + if not show_axes: + ax[a0].set_axis_off() + + # All possible returns + if return_phase_estimate: + if returnfig: + return im_all, fig, ax + else: + return im_all + elif return_rgb_images: + if plot_combine: + if returnfig: + return im_rgb, fig, ax + else: + return im_rgb + else: + if returnfig: + return im_rgb_all, fig, ax + else: + return im_rgb_all + else: + if returnfig: + return fig, ax + + + def plot_dominant_phase( + self, + rel_range = (0.0,1.0), + sigma = 0.0, + phase_colors = None, + figsize = (6,6), + ): + """ + Plot a combined figure showing the primary phase at each probe position. + Mask by the reliability index (best match minus 2nd best match). + """ + + if phase_colors is None: + phase_colors = np.array([ + [1.0,0.9,0.6], + [1,0,0], + [0,0.7,0], + [0,0.7,1], + [1,0,1], + ]) + + + # init arrays + scan_shape = self.phase_weights.shape[:2] + phase_map = np.zeros(scan_shape) + phase_corr = np.zeros(scan_shape) + phase_corr_2nd = np.zeros(scan_shape) + phase_sig = np.zeros((self.num_crystals,scan_shape[0],scan_shape[1])) + + # sum up phase weights by crystal type + for a0 in range(self.num_fits): + ind = self.crystal_identity[a0,0] + phase_sig[ind] += self.phase_weights[:,:,a0] + + # smoothing of the outputs + if sigma > 0.0: + for a0 in range(self.num_crystals): + phase_sig[a0] = gaussian_filter( + phase_sig[a0], + sigma = sigma, + mode = 'nearest', + ) + + + # find highest correlation score for each crystal and match index + for a0 in range(self.num_crystals): + sub = phase_sig[a0] > phase_corr + phase_map[sub] = a0 + phase_corr[sub] = phase_sig[a0][sub] + + # find the second correlation score for each crystal and match index + for a0 in range(self.num_crystals): + corr = phase_sig[a0].copy() + corr[phase_map==a0] = 0.0 + sub = corr > phase_corr_2nd + phase_corr_2nd[sub] = corr[sub] + + + + + + # for a1 in range(crystals[a0].orientation_map.corr.shape[2]): + # sub = crystals[a0].orientation_map.corr[:,:,a1] > phase_corr + # phase_map[sub] = a0 + # phase_corr[sub] = crystals[a0].orientation_map.corr[:,:,a1][sub] + + # # find the second correlation score for each crystal and match index + # for a0 in range(len(crystals)): + # for a1 in range(crystals[a0].orientation_map.corr.shape[2]): + # corr = crystals[a0].orientation_map.corr[:,:,a1].copy() + # corr[phase_map==a0] = 0.0 + # sub = corr > phase_corr_2nd + # phase_corr_2nd[sub] = corr[sub] + + # Estimate the reliability + phase_rel = phase_corr - phase_corr_2nd + + # # Generate a color plotting image + # c_vals = np.array([ + # [0.7,0.5,0.3], + # [1,0,0], + # [0,0.7,0], + # [0,0.7,1], + # [0,0.7,1], + # ]) + phase_scale = np.clip( + (phase_rel - rel_range[0]) / (rel_range[1] - rel_range[0]), + 0, + 1) + + + phase_rgb = np.zeros((scan_shape[0],scan_shape[1],3)) + for a0 in range(self.num_crystals): + sub = phase_map==a0 + for a1 in range(3): + phase_rgb[:,:,a1][sub] = phase_colors[a0,a1] * phase_scale[sub] + # normalize + # phase_rgb = np.clip( + # (phase_rgb - rel_range[0]) / (rel_range[1] - rel_range[0]), + # 0,1) + + + + fig,ax = plt.subplots(figsize=figsize) + ax.imshow( + phase_rgb, + # vmin = 0, + # vmax = 5, + # phase_rgb, + # phase_corr - phase_corr_2nd, + # cmap = 'turbo', + # vmin = 0, + # vmax = 3, + # cmap = 'gray', + ) + + return fig,ax + + + # def plot_all_phase_maps(self, map_scale_values=None, index=0): + # """ + # Visualize phase maps of dataset. + + # Args: + # map_scale_values (float): Value to scale correlations by + # """ + # phase_maps = [] + # if map_scale_values is None: + # map_scale_values = [1] * len(self.orientation_maps) + # corr_sum = np.sum( + # [ + # (self.orientation_maps[m].corr[:, :, index] * map_scale_values[m]) + # for m in range(len(self.orientation_maps)) + # ] + # ) + # for m in range(len(self.orientation_maps)): + # phase_maps.append(self.orientation_maps[m].corr[:, :, index] / corr_sum) + # show_image_grid(lambda i: phase_maps[i], 1, len(phase_maps), cmap="inferno") + # return + + # def plot_phase_map(self, index=0, cmap=None): + # corr_array = np.dstack( + # [maps.corr[:, :, index] for maps in self.orientation_maps] + # ) + # best_corr_score = np.max(corr_array, axis=2) + # best_match_phase = [ + # np.where(corr_array[:, :, p] == best_corr_score, True, False) + # for p in range(len(self.orientation_maps)) + # ] - return + # if cmap is None: + # cm = plt.get_cmap("rainbow") + # cmap = [ + # cm(1.0 * i / len(self.orientation_maps)) + # for i in range(len(self.orientation_maps)) + # ] + + # fig, (ax) = plt.subplots(figsize=(6, 6)) + # ax.matshow( + # np.zeros((self.orientation_maps[0].num_x, self.orientation_maps[0].num_y)), + # cmap="gray", + # ) + # ax.axis("off") + + # for m in range(len(self.orientation_maps)): + # c0, c1 = (cmap[m][0] * 0.35, cmap[m][1] * 0.35, cmap[m][2] * 0.35, 1), cmap[ + # m + # ] + # cm = mpl.colors.LinearSegmentedColormap.from_list("cmap", [c0, c1], N=10) + # ax.matshow( + # np.ma.array( + # self.orientation_maps[m].corr[:, :, index], mask=best_match_phase[m] + # ), + # cmap=cm, + # ) + # plt.show() + + # return # Potentially introduce a way to check best match out of all orientations in phase plan and plug into model # to quantify phase @@ -124,187 +1015,187 @@ def plot_phase_map(self, index=0, cmap=None): # ): # return - def quantify_phase( - self, - pointlistarray, - tolerance_distance=0.08, - method="nnls", - intensity_power=0, - mask_peaks=None, - ): - """ - Quantification of the phase of a crystal based on the crystal instances and the pointlistarray. + # def quantify_phase( + # self, + # pointlistarray, + # tolerance_distance=0.08, + # method="nnls", + # intensity_power=0, + # mask_peaks=None, + # ): + # """ + # Quantification of the phase of a crystal based on the crystal instances and the pointlistarray. - Args: - pointlisarray (pointlistarray): Pointlistarray to quantify phase of - tolerance_distance (float): Distance allowed between a peak and match - method (str): Numerical method used to quantify phase - intensity_power (float): ... - mask_peaks (list, optional): A pointer of which positions to mask peaks from + # Args: + # pointlisarray (pointlistarray): Pointlistarray to quantify phase of + # tolerance_distance (float): Distance allowed between a peak and match + # method (str): Numerical method used to quantify phase + # intensity_power (float): ... + # mask_peaks (list, optional): A pointer of which positions to mask peaks from - Details: - """ - if isinstance(pointlistarray, PointListArray): - phase_weights = np.zeros( - ( - pointlistarray.shape[0], - pointlistarray.shape[1], - np.sum([map.num_matches for map in self.orientation_maps]), - ) - ) - phase_residuals = np.zeros(pointlistarray.shape) - for Rx, Ry in tqdmnd(pointlistarray.shape[0], pointlistarray.shape[1]): - ( - _, - phase_weight, - phase_residual, - crystal_identity, - ) = self.quantify_phase_pointlist( - pointlistarray, - position=[Rx, Ry], - tolerance_distance=tolerance_distance, - method=method, - intensity_power=intensity_power, - mask_peaks=mask_peaks, - ) - phase_weights[Rx, Ry, :] = phase_weight - phase_residuals[Rx, Ry] = phase_residual - self.phase_weights = phase_weights - self.phase_residuals = phase_residuals - self.crystal_identity = crystal_identity - return - else: - return TypeError("pointlistarray must be of type pointlistarray.") - return + # Details: + # """ + # if isinstance(pointlistarray, PointListArray): + # phase_weights = np.zeros( + # ( + # pointlistarray.shape[0], + # pointlistarray.shape[1], + # np.sum([map.num_matches for map in self.orientation_maps]), + # ) + # ) + # phase_residuals = np.zeros(pointlistarray.shape) + # for Rx, Ry in tqdmnd(pointlistarray.shape[0], pointlistarray.shape[1]): + # ( + # _, + # phase_weight, + # phase_residual, + # crystal_identity, + # ) = self.quantify_phase_pointlist( + # pointlistarray, + # position=[Rx, Ry], + # tolerance_distance=tolerance_distance, + # method=method, + # intensity_power=intensity_power, + # mask_peaks=mask_peaks, + # ) + # phase_weights[Rx, Ry, :] = phase_weight + # phase_residuals[Rx, Ry] = phase_residual + # self.phase_weights = phase_weights + # self.phase_residuals = phase_residuals + # self.crystal_identity = crystal_identity + # return + # else: + # return TypeError("pointlistarray must be of type pointlistarray.") + # return - def quantify_phase_pointlist( - self, - pointlistarray, - position, - method="nnls", - tolerance_distance=0.08, - intensity_power=0, - mask_peaks=None, - ): - """ - Args: - pointlisarray (pointlistarray): Pointlistarray to quantify phase of - position (tuple/list): Position of pointlist in pointlistarray - tolerance_distance (float): Distance allowed between a peak and match - method (str): Numerical method used to quantify phase - intensity_power (float): ... - mask_peaks (list, optional): A pointer of which positions to mask peaks from - - Returns: - pointlist_peak_intensity_matches (np.ndarray): Peak matches in the rows of array and the crystals in the columns - phase_weights (np.ndarray): Weights of each phase - phase_residuals (np.ndarray): Residuals - crystal_identity (list): List of lists, where the each entry represents the position in the - crystal and orientation match that is associated with the phase - weights. for example, if the output was [[0,0], [0,1], [1,0], [0,1]], - the first entry [0,0] in phase weights is associated with the first crystal - the first match within that crystal. [0,1] is the first crystal and the - second match within that crystal. - """ - # Things to add: - # 1. Better cost for distance from peaks in pointlists - # 2. Iterate through multiple tolerance_distance values to find best value. Cost function residuals, or something else? + # def quantify_phase_pointlist( + # self, + # pointlistarray, + # position, + # method="nnls", + # tolerance_distance=0.08, + # intensity_power=0, + # mask_peaks=None, + # ): + # """ + # Args: + # pointlisarray (pointlistarray): Pointlistarray to quantify phase of + # position (tuple/list): Position of pointlist in pointlistarray + # tolerance_distance (float): Distance allowed between a peak and match + # method (str): Numerical method used to quantify phase + # intensity_power (float): ... + # mask_peaks (list, optional): A pointer of which positions to mask peaks from - pointlist = pointlistarray.get_pointlist(position[0], position[1]) - pl_mask = np.where((pointlist["qx"] == 0) & (pointlist["qy"] == 0), 1, 0) - pointlist.remove(pl_mask) - # False Negatives (exp peak with no match in crystal instances) will appear here, already coded in + # Returns: + # pointlist_peak_intensity_matches (np.ndarray): Peak matches in the rows of array and the crystals in the columns + # phase_weights (np.ndarray): Weights of each phase + # phase_residuals (np.ndarray): Residuals + # crystal_identity (list): List of lists, where the each entry represents the position in the + # crystal and orientation match that is associated with the phase + # weights. for example, if the output was [[0,0], [0,1], [1,0], [0,1]], + # the first entry [0,0] in phase weights is associated with the first crystal + # the first match within that crystal. [0,1] is the first crystal and the + # second match within that crystal. + # """ + # # Things to add: + # # 1. Better cost for distance from peaks in pointlists + # # 2. Iterate through multiple tolerance_distance values to find best value. Cost function residuals, or something else? - if intensity_power == 0: - pl_intensities = np.ones(pointlist["intensity"].shape) - else: - pl_intensities = pointlist["intensity"] ** intensity_power - # Prepare matches for modeling - pointlist_peak_matches = [] - crystal_identity = [] - - for c in range(len(self.crystals)): - for m in range(self.orientation_maps[c].num_matches): - crystal_identity.append([c, m]) - phase_peak_match_intensities = np.zeros((pointlist["intensity"].shape)) - bragg_peaks_fit = self.crystals[c].generate_diffraction_pattern( - self.orientation_maps[c].get_orientation(position[0], position[1]), - ind_orientation=m, - ) - # Find the best match peak within tolerance_distance and add value in the right position - for d in range(pointlist["qx"].shape[0]): - distances = [] - for p in range(bragg_peaks_fit["qx"].shape[0]): - distances.append( - np.sqrt( - (pointlist["qx"][d] - bragg_peaks_fit["qx"][p]) ** 2 - + (pointlist["qy"][d] - bragg_peaks_fit["qy"][p]) ** 2 - ) - ) - ind = np.where(distances == np.min(distances))[0][0] - - # Potentially for-loop over multiple values for 'tolerance_distance' to find best tolerance_distance value - if distances[ind] <= tolerance_distance: - ## Somewhere in this if statement is probably where better distances from the peak should be coded in - if ( - intensity_power == 0 - ): # This could potentially be a different intensity_power arg - phase_peak_match_intensities[d] = 1 ** ( - (tolerance_distance - distances[ind]) - / tolerance_distance - ) - else: - phase_peak_match_intensities[d] = bragg_peaks_fit[ - "intensity" - ][ind] ** ( - (tolerance_distance - distances[ind]) - / tolerance_distance - ) - else: - ## This is probably where the false positives (peaks in crystal but not in experiment) should be handled - continue - - pointlist_peak_matches.append(phase_peak_match_intensities) - pointlist_peak_intensity_matches = np.dstack(pointlist_peak_matches) - pointlist_peak_intensity_matches = ( - pointlist_peak_intensity_matches.reshape( - pl_intensities.shape[0], - pointlist_peak_intensity_matches.shape[-1], - ) - ) + # pointlist = pointlistarray.get_pointlist(position[0], position[1]) + # pl_mask = np.where((pointlist["qx"] == 0) & (pointlist["qy"] == 0), 1, 0) + # pointlist.remove(pl_mask) + # # False Negatives (exp peak with no match in crystal instances) will appear here, already coded in - if len(pointlist["qx"]) > 0: - if mask_peaks is not None: - for i in range(len(mask_peaks)): - if mask_peaks[i] == None: # noqa: E711 - continue - inds_mask = np.where( - pointlist_peak_intensity_matches[:, mask_peaks[i]] != 0 - )[0] - for mask in range(len(inds_mask)): - pointlist_peak_intensity_matches[inds_mask[mask], i] = 0 + # if intensity_power == 0: + # pl_intensities = np.ones(pointlist["intensity"].shape) + # else: + # pl_intensities = pointlist["intensity"] ** intensity_power + # # Prepare matches for modeling + # pointlist_peak_matches = [] + # crystal_identity = [] - if method == "nnls": - phase_weights, phase_residuals = nnls( - pointlist_peak_intensity_matches, pl_intensities - ) + # for c in range(len(self.crystals)): + # for m in range(self.orientation_maps[c].num_matches): + # crystal_identity.append([c, m]) + # phase_peak_match_intensities = np.zeros((pointlist["intensity"].shape)) + # bragg_peaks_fit = self.crystals[c].generate_diffraction_pattern( + # self.orientation_maps[c].get_orientation(position[0], position[1]), + # ind_orientation=m, + # ) + # # Find the best match peak within tolerance_distance and add value in the right position + # for d in range(pointlist["qx"].shape[0]): + # distances = [] + # for p in range(bragg_peaks_fit["qx"].shape[0]): + # distances.append( + # np.sqrt( + # (pointlist["qx"][d] - bragg_peaks_fit["qx"][p]) ** 2 + # + (pointlist["qy"][d] - bragg_peaks_fit["qy"][p]) ** 2 + # ) + # ) + # ind = np.where(distances == np.min(distances))[0][0] - elif method == "lstsq": - phase_weights, phase_residuals, rank, singluar_vals = lstsq( - pointlist_peak_intensity_matches, pl_intensities, rcond=-1 - ) - phase_residuals = np.sum(phase_residuals) - else: - raise ValueError(method + " Not yet implemented. Try nnls or lstsq.") - else: - phase_weights = np.zeros((pointlist_peak_intensity_matches.shape[1],)) - phase_residuals = np.NaN - return ( - pointlist_peak_intensity_matches, - phase_weights, - phase_residuals, - crystal_identity, - ) + # # Potentially for-loop over multiple values for 'tolerance_distance' to find best tolerance_distance value + # if distances[ind] <= tolerance_distance: + # ## Somewhere in this if statement is probably where better distances from the peak should be coded in + # if ( + # intensity_power == 0 + # ): # This could potentially be a different intensity_power arg + # phase_peak_match_intensities[d] = 1 ** ( + # (tolerance_distance - distances[ind]) + # / tolerance_distance + # ) + # else: + # phase_peak_match_intensities[d] = bragg_peaks_fit[ + # "intensity" + # ][ind] ** ( + # (tolerance_distance - distances[ind]) + # / tolerance_distance + # ) + # else: + # ## This is probably where the false positives (peaks in crystal but not in experiment) should be handled + # continue + + # pointlist_peak_matches.append(phase_peak_match_intensities) + # pointlist_peak_intensity_matches = np.dstack(pointlist_peak_matches) + # pointlist_peak_intensity_matches = ( + # pointlist_peak_intensity_matches.reshape( + # pl_intensities.shape[0], + # pointlist_peak_intensity_matches.shape[-1], + # ) + # ) + + # if len(pointlist["qx"]) > 0: + # if mask_peaks is not None: + # for i in range(len(mask_peaks)): + # if mask_peaks[i] == None: # noqa: E711 + # continue + # inds_mask = np.where( + # pointlist_peak_intensity_matches[:, mask_peaks[i]] != 0 + # )[0] + # for mask in range(len(inds_mask)): + # pointlist_peak_intensity_matches[inds_mask[mask], i] = 0 + + # if method == "nnls": + # phase_weights, phase_residuals = nnls( + # pointlist_peak_intensity_matches, pl_intensities + # ) + + # elif method == "lstsq": + # phase_weights, phase_residuals, rank, singluar_vals = lstsq( + # pointlist_peak_intensity_matches, pl_intensities, rcond=-1 + # ) + # phase_residuals = np.sum(phase_residuals) + # else: + # raise ValueError(method + " Not yet implemented. Try nnls or lstsq.") + # else: + # phase_weights = np.zeros((pointlist_peak_intensity_matches.shape[1],)) + # phase_residuals = np.NaN + # return ( + # pointlist_peak_intensity_matches, + # phase_weights, + # phase_residuals, + # crystal_identity, + # ) # def plot_peak_matches( # self, From 304e35b8688c45c16bc4edf02aff81f59b3ab3d9 Mon Sep 17 00:00:00 2001 From: Colin Date: Wed, 10 Apr 2024 08:06:21 -0700 Subject: [PATCH 20/64] adding single phase method --- py4DSTEM/process/diffraction/crystal_phase.py | 341 ++++++++++-------- 1 file changed, 200 insertions(+), 141 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_phase.py b/py4DSTEM/process/diffraction/crystal_phase.py index 508a01c79..bc5f10721 100644 --- a/py4DSTEM/process/diffraction/crystal_phase.py +++ b/py4DSTEM/process/diffraction/crystal_phase.py @@ -78,20 +78,22 @@ def quantify_single_pattern( pointlistarray: PointListArray, xy_position = (0,0), corr_kernel_size = 0.04, + sigma_excitation_error = 0.02, + power_experiment = 0.25, + power_calculated = 0.25, + max_number_patterns = 3, + single_phase = False, allow_strain = True, - # corr_distance_scale = 1.0, + strain_iterations = 5, + strain_max = 0.02, include_false_positives = True, weight_false_positives = 1.0, - max_number_phases = 1, - sigma_excitation_error = 0.04, - power_experiment = 0.25, - power_calculated = 0.25, plot_result = True, plot_only_nonzero_phases = True, plot_unmatched_peaks = False, plot_correlation_radius = False, scale_markers_experiment = 40, - scale_markers_calculated = 500, + scale_markers_calculated = 200, crystal_inds_plot = None, phase_colors = None, figsize = (10,7), @@ -161,6 +163,8 @@ def quantify_single_pattern( basis = np.zeros((intensity.shape[0], self.num_fits)) if allow_strain: m_strains = np.zeros((self.num_fits,2,2)) + m_strains[:,0,0] = 1.0 + m_strains[:,1,1] = 1.0 # kernel radius squared radius_max_2 = corr_kernel_size**2 @@ -202,56 +206,71 @@ def quantify_single_pattern( matches = np.zeros((bragg_peaks_fit.data.shape[0]),dtype='bool') if allow_strain: - # Initial peak pairing to find best-fit strain distortion - pair_sub = np.zeros(bragg_peaks_fit.data.shape[0],dtype='bool') - pair_inds = np.zeros(bragg_peaks_fit.data.shape[0],dtype='int') - for a1 in range(bragg_peaks_fit.data.shape[0]): - dist2 = (bragg_peaks_fit.data['qx'][a1] - qx)**2 \ - + (bragg_peaks_fit.data['qy'][a1] - qy)**2 - ind_min = np.argmin(dist2) - val_min = dist2[ind_min] - - if val_min < radius_max_2: - pair_sub[a1] = True - pair_inds[a1] = ind_min - - # calculate best-fit strain tensor, weighted by the intensities. - # requires at least 4 peak pairs - if np.sum(pair_sub) >= 4: - pair_basis = np.vstack(( - qx[pair_inds[pair_sub]], - qy[pair_inds[pair_sub]], - )).T - pair_obs = np.vstack(( - bragg_peaks_fit.data['qx'][pair_sub], - bragg_peaks_fit.data['qy'][pair_sub], - )).T - - # weights - dists = np.sqrt( - (bragg_peaks_fit.data['qx'][pair_sub] - qx[pair_inds[pair_sub]])**2 + \ - (bragg_peaks_fit.data['qx'][pair_sub] - qx[pair_inds[pair_sub]])**2) - weights = np.sqrt( - int_fit[pair_sub] * intensity[pair_inds[pair_sub]] - ) * (1 - dists / corr_kernel_size) - - # wtrain tensor - m_strain = np.linalg.lstsq( - pair_basis * weights[:,None], - pair_obs * weights[:,None], - rcond = None, - )[0] - m_strains[a0] = m_strain - - # Transformed peak positions - qx_copy = qx.copy() - qy_copy = qy.copy() - qx = qx_copy*m_strain[0,0] + qy_copy*m_strain[1,0] - qy = qx_copy*m_strain[0,1] + qy_copy*m_strain[1,1] - - # dist = np.mean(np.sqrt((bragg_peaks_fit.data['qx'][pair_sub] - qx[pair_inds[pair_sub]])**2 + \ - # (bragg_peaks_fit.data['qx'][pair_sub] - qx[pair_inds[pair_sub]])**2)) - # print(np.round(dist,4)) + for a1 in range(strain_iterations): + # Initial peak pairing to find best-fit strain distortion + pair_sub = np.zeros(bragg_peaks_fit.data.shape[0],dtype='bool') + pair_inds = np.zeros(bragg_peaks_fit.data.shape[0],dtype='int') + for a1 in range(bragg_peaks_fit.data.shape[0]): + dist2 = (bragg_peaks_fit.data['qx'][a1] - qx)**2 \ + + (bragg_peaks_fit.data['qy'][a1] - qy)**2 + ind_min = np.argmin(dist2) + val_min = dist2[ind_min] + + if val_min < radius_max_2: + pair_sub[a1] = True + pair_inds[a1] = ind_min + + # calculate best-fit strain tensor, weighted by the intensities. + # requires at least 4 peak pairs + if np.sum(pair_sub) >= 4: + # pair_obs = bragg_peaks_fit.data[['qx','qy']][pair_sub] + pair_basis = np.vstack(( + bragg_peaks_fit.data['qx'][pair_sub], + bragg_peaks_fit.data['qy'][pair_sub], + )).T + pair_obs = np.vstack(( + qx[pair_inds[pair_sub]], + qy[pair_inds[pair_sub]], + )).T + + # weights + dists = np.sqrt( + (bragg_peaks_fit.data['qx'][pair_sub] - qx[pair_inds[pair_sub]])**2 + \ + (bragg_peaks_fit.data['qx'][pair_sub] - qx[pair_inds[pair_sub]])**2) + weights = np.sqrt( + int_fit[pair_sub] * intensity[pair_inds[pair_sub]] + ) * (1 - dists / corr_kernel_size) + # weights = 1 - dists / corr_kernel_size + + # strain tensor + m_strain = np.linalg.lstsq( + pair_basis * weights[:,None], + pair_obs * weights[:,None], + rcond = None, + )[0] + + # Clamp strains to be within the user-specified limit + m_strain = np.clip( + m_strain, + np.eye(2) - strain_max, + np.eye(2) + strain_max, + ) + m_strains[a0] *= m_strain + + # Transformed peak positions + qx_copy = bragg_peaks_fit.data['qx'] + qy_copy = bragg_peaks_fit.data['qy'] + bragg_peaks_fit.data['qx'] = qx_copy*m_strain[0,0] + qy_copy*m_strain[1,0] + bragg_peaks_fit.data['qy'] = qx_copy*m_strain[0,1] + qy_copy*m_strain[1,1] + + # qx_copy = qx.copy() + # qy_copy = qy.copy() + # qx = qx_copy*m_strain[0,0] + qy_copy*m_strain[1,0] + # qy = qx_copy*m_strain[0,1] + qy_copy*m_strain[1,1] + + # dist = np.mean(np.sqrt((bragg_peaks_fit.data['qx'][pair_sub] - qx[pair_inds[pair_sub]])**2 + \ + # (bragg_peaks_fit.data['qx'][pair_sub] - qx[pair_inds[pair_sub]])**2)) + # print(np.round(dist,4)) # Loop over all peaks, pair experiment to library for a1 in range(bragg_peaks_fit.data.shape[0]): @@ -289,25 +308,53 @@ def quantify_single_pattern( else: obs = intensity - # Solve for phase coefficients + # Solve for phase weight coefficients try: phase_weights = np.zeros(self.num_fits) - inds_solve = np.ones(self.num_fits,dtype='bool') - search = True - while search is True: - phase_weights_cand, phase_residual_cand = nnls( - basis[:,inds_solve], - obs, + if single_phase: + # loop through each crystal structure and determine the best fit structure, + # which can contain multiple orienations. + crystal_res = np.zeros(self.num_crystals) + + for a0 in range(self.num_crystals): + sub = self.crystal_identity[:,0] == a0 + + phase_weights_cand, phase_residual_cand = nnls( + basis[:,sub], + obs, + ) + phase_weights[sub] = phase_weights_cand + crystal_res[a0] = phase_residual_cand + + ind_best_fit = np.argmin(crystal_res) + phase_residual = crystal_res[ind_best_fit] + sub = np.logical_not( + self.crystal_identity[:,0] == ind_best_fit ) + phase_weights[sub] = 0.0 - if np.count_nonzero(phase_weights_cand > 0.0) <= max_number_phases: - phase_weights[inds_solve] = phase_weights_cand - phase_residual = phase_residual_cand - search = False - else: - inds = np.where(inds_solve)[0] - inds_solve[inds[np.argmin(phase_weights_cand)]] = False + # Estimate reliability as difference between best fit and 2nd best fit + crystal_res = np.sort(crystal_res) + phase_reliability = crystal_res[1] - crystal_res[0] + + else: + # Allow all crystals and orientation matches in the pattern + inds_solve = np.ones(self.num_fits,dtype='bool') + search = True + while search is True: + phase_weights_cand, phase_residual_cand = nnls( + basis[:,inds_solve], + obs, + ) + + if np.count_nonzero(phase_weights_cand > 0.0) <= max_number_patterns: + phase_weights[inds_solve] = phase_weights_cand + phase_residual = phase_residual_cand + search = False + else: + inds = np.where(inds_solve)[0] + inds_solve[inds[np.argmin(phase_weights_cand)]] = False except: phase_weights = np.zeros(self.num_fits) @@ -322,7 +369,7 @@ def quantify_single_pattern( c = self.crystal_identity[a0,0] m = self.crystal_identity[a0,1] line = '{:>12} {:>8} {:<12}'.format( - np.round(phase_weights[a0],decimals=2), + f'{phase_weights[a0]:.2f}', m, self.crystal_names[c] ) @@ -330,7 +377,17 @@ def quantify_single_pattern( print('\033[1m' + line + '\033[0m') else: print(line) - # print() + print('----------------------------') + line = '{:>12} {:>15}'.format( + f'{sum(phase_weights):.2f}', + 'fit total' + ) + print('\033[1m' + line + '\033[0m') + line = '{:>12} {:>15}'.format( + f'{phase_residual:.2f}', + 'fit residual' + ) + print(line) # Plotting if plot_result: @@ -380,7 +437,7 @@ def quantify_single_pattern( "family": "sans-serif", "fontweight": "normal", "color": "k", - "size": 14, + "size": 12, } if plot_correlation_radius: ax_leg.plot( @@ -422,6 +479,7 @@ def quantify_single_pattern( # )) mvals = ['v','^','<','>','d','s',] + count_leg = 0 for a0 in range(self.num_fits): c = self.crystal_identity[a0,0] m = self.crystal_identity[a0,1] @@ -431,83 +489,84 @@ def quantify_single_pattern( qx_fit = library_peaks[a0].data['qx'] qy_fit = library_peaks[a0].data['qy'] - # if allow_strain: - # m_strain = m_strains[a0] - # # Transformed peak positions - # qx_copy = qx_fit.copy() - # qy_copy = qy_fit.copy() - # qx_fit = qx_copy*m_strain[0,0] + qy_copy*m_strain[1,0] - # qy_fit = qx_copy*m_strain[0,1] + qy_copy*m_strain[1,1] + if allow_strain: + m_strain = m_strains[a0] + # Transformed peak positions + qx_copy = qx_fit.copy() + qy_copy = qy_fit.copy() + qx_fit = qx_copy*m_strain[0,0] + qy_copy*m_strain[1,0] + qy_fit = qx_copy*m_strain[0,1] + qy_copy*m_strain[1,1] int_fit = library_int[a0] matches_fit = library_matches[a0] if plot_only_nonzero_phases is False or phase_weights[a0] > 0: - if np.mod(m,2) == 0: + # if np.mod(m,2) == 0: + ax.scatter( + qy_fit[matches_fit], + qx_fit[matches_fit], + s = scale_markers_calculated * int_fit[matches_fit], + marker = mvals[c], + facecolor = phase_colors[c,:], + ) + if plot_unmatched_peaks: ax.scatter( - qy_fit[matches_fit], - qx_fit[matches_fit], - s = scale_markers_calculated * int_fit[matches_fit], + qy_fit[np.logical_not(matches_fit)], + qx_fit[np.logical_not(matches_fit)], + s = scale_markers_calculated * int_fit[np.logical_not(matches_fit)], marker = mvals[c], facecolor = phase_colors[c,:], ) - if plot_unmatched_peaks: - ax.scatter( - qy_fit[np.logical_not(matches_fit)], - qx_fit[np.logical_not(matches_fit)], - s = scale_markers_calculated * int_fit[np.logical_not(matches_fit)], - marker = mvals[c], - facecolor = phase_colors[c,:], - ) - - # legend - ax_leg.scatter( - 0, - dx_leg*(a0+1), - s = 200, - marker = mvals[c], - facecolor = phase_colors[c,:], - ) - else: - ax.scatter( - qy_fit[matches_fit], - qx_fit[matches_fit], - s = scale_markers_calculated * int_fit[matches_fit], - marker = mvals[c], - edgecolors = uvals[c,:], - facecolors = (uvals[c,0],uvals[c,1],uvals[c,2],0.3), - # facecolors = (1,1,1,0.5), - linewidth = 2, - ) - if plot_unmatched_peaks: - ax.scatter( - qy_fit[np.logical_not(matches_fit)], - qx_fit[np.logical_not(matches_fit)], - s = scale_markers_calculated * int_fit[np.logical_not(matches_fit)], - marker = mvals[c], - edgecolors = uvals[c,:], - facecolors = (1,1,1,0.5), - linewidth = 2, - ) - - # legend - ax_leg.scatter( - 0, - dx_leg*(a0+1), - s = 200, - marker = mvals[c], - edgecolors = uvals[c,:], - facecolors = (uvals[c,0],uvals[c,1],uvals[c,2],0.3), - # facecolors = (1,1,1,0.5), - ) - # legend text - ax_leg.text( - dy_leg, - (a0+1)*dx_leg, - self.crystal_names[c], - **text_params) + # legend + if m == 0: + ax_leg.text( + dy_leg, + (count_leg+1)*dx_leg, + self.crystal_names[c], + **text_params) + ax_leg.scatter( + 0, + (count_leg+1) * dx_leg, + s = 200, + marker = mvals[c], + facecolor = phase_colors[c,:], + ) + count_leg += 1 + # else: + # ax.scatter( + # qy_fit[matches_fit], + # qx_fit[matches_fit], + # s = scale_markers_calculated * int_fit[matches_fit], + # marker = mvals[c], + # edgecolors = uvals[c,:], + # facecolors = (uvals[c,0],uvals[c,1],uvals[c,2],0.3), + # # facecolors = (1,1,1,0.5), + # linewidth = 2, + # ) + # if plot_unmatched_peaks: + # ax.scatter( + # qy_fit[np.logical_not(matches_fit)], + # qx_fit[np.logical_not(matches_fit)], + # s = scale_markers_calculated * int_fit[np.logical_not(matches_fit)], + # marker = mvals[c], + # edgecolors = uvals[c,:], + # facecolors = (1,1,1,0.5), + # linewidth = 2, + # ) + + # # legend + # ax_leg.scatter( + # 0, + # dx_leg*(a0+1), + # s = 200, + # marker = mvals[c], + # edgecolors = uvals[c,:], + # facecolors = (uvals[c,0],uvals[c,1],uvals[c,2],0.3), + # # facecolors = (1,1,1,0.5), + # ) + # appearance @@ -531,7 +590,7 @@ def quantify_phase( allow_strain = True, include_false_positives = True, weight_false_positives = 1.0, - max_number_phases = 2, + max_number_patterns = 2, sigma_excitation_error = 0.02, power_experiment = 0.25, power_calculated = 0.25, @@ -571,7 +630,7 @@ def quantify_phase( # corr_distance_scale = corr_distance_scale, include_false_positives = include_false_positives, weight_false_positives = weight_false_positives, - max_number_phases = max_number_phases, + max_number_patterns = max_number_patterns, sigma_excitation_error = sigma_excitation_error, power_experiment = power_experiment, power_calculated = power_calculated, From 909b84f4140ca6c0f543df9e93fc5be2e0528dfc Mon Sep 17 00:00:00 2001 From: Colin Date: Wed, 10 Apr 2024 08:32:47 -0700 Subject: [PATCH 21/64] updating plotting function --- py4DSTEM/process/diffraction/crystal_phase.py | 130 +++++++++--------- 1 file changed, 68 insertions(+), 62 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_phase.py b/py4DSTEM/process/diffraction/crystal_phase.py index bc5f10721..b997fb9d1 100644 --- a/py4DSTEM/process/diffraction/crystal_phase.py +++ b/py4DSTEM/process/diffraction/crystal_phase.py @@ -263,15 +263,6 @@ def quantify_single_pattern( bragg_peaks_fit.data['qx'] = qx_copy*m_strain[0,0] + qy_copy*m_strain[1,0] bragg_peaks_fit.data['qy'] = qx_copy*m_strain[0,1] + qy_copy*m_strain[1,1] - # qx_copy = qx.copy() - # qy_copy = qy.copy() - # qx = qx_copy*m_strain[0,0] + qy_copy*m_strain[1,0] - # qy = qx_copy*m_strain[0,1] + qy_copy*m_strain[1,1] - - # dist = np.mean(np.sqrt((bragg_peaks_fit.data['qx'][pair_sub] - qx[pair_inds[pair_sub]])**2 + \ - # (bragg_peaks_fit.data['qx'][pair_sub] - qx[pair_inds[pair_sub]])**2)) - # print(np.round(dist,4)) - # Loop over all peaks, pair experiment to library for a1 in range(bragg_peaks_fit.data.shape[0]): dist2 = (bragg_peaks_fit.data['qx'][a1] - qx)**2 \ @@ -356,9 +347,33 @@ def quantify_single_pattern( inds = np.where(inds_solve)[0] inds_solve[inds[np.argmin(phase_weights_cand)]] = False + # Estimate the phase reliability + inds_solve = np.ones(self.num_fits,dtype='bool') + inds_solve[phase_weights > 1e-8] = False + search = True + while search is True: + phase_weights_cand, phase_residual_cand = nnls( + basis[:,inds_solve], + obs, + ) + if np.count_nonzero(phase_weights_cand > 0.0) <= max_number_patterns: + phase_residual_2nd = phase_residual_cand + search = False + else: + inds = np.where(inds_solve)[0] + inds_solve[inds[np.argmin(phase_weights_cand)]] = False + + # print(inds_solve) + # phase_weights_cand, phase_residual_cand = nnls( + # basis[:,inds_solve], + # obs, + # ) + phase_reliability = phase_residual_2nd - phase_residual + except: phase_weights = np.zeros(self.num_fits) phase_residual = np.sqrt(np.sum(intensity**2)) + phase_reliability = 0.0 if verbose: ind_max = np.argmax(phase_weights) @@ -578,22 +593,24 @@ def quantify_single_pattern( ax_leg.set_axis_off() if returnfig: - return phase_weights, phase_residual, int_total, fig, ax + return phase_weights, phase_residual, phase_reliability, int_total, fig, ax else: - return phase_weights, phase_residual, int_total + return phase_weights, phase_residual, phase_reliability, int_total def quantify_phase( self, pointlistarray: PointListArray, corr_kernel_size = 0.04, - # corr_distance_scale = 1.0, - allow_strain = True, - include_false_positives = True, - weight_false_positives = 1.0, - max_number_patterns = 2, sigma_excitation_error = 0.02, power_experiment = 0.25, power_calculated = 0.25, + max_number_patterns = 3, + single_phase = False, + allow_strain = True, + strain_iterations = 5, + strain_max = 0.02, + include_false_positives = True, + weight_false_positives = 1.0, progress_bar = True, ): """ @@ -610,36 +627,44 @@ def quantify_phase( pointlistarray.shape[0], pointlistarray.shape[1], )) + self.phase_reliability = np.zeros(( + pointlistarray.shape[0], + pointlistarray.shape[1], + )) self.int_total = np.zeros(( pointlistarray.shape[0], pointlistarray.shape[1], )) + self.single_phase = single_phase for rx, ry in tqdmnd( *pointlistarray.shape, - desc="Matching Orientations", + desc="Quantifying Phase", unit=" PointList", disable=not progress_bar, ): # calculate phase weights - phase_weights, phase_residual, int_peaks = self.quantify_single_pattern( + phase_weights, phase_residual, phase_reliability, int_peaks = self.quantify_single_pattern( pointlistarray = pointlistarray, xy_position = (rx,ry), corr_kernel_size = corr_kernel_size, - allow_strain = allow_strain, - # corr_distance_scale = corr_distance_scale, - include_false_positives = include_false_positives, - weight_false_positives = weight_false_positives, - max_number_patterns = max_number_patterns, sigma_excitation_error = sigma_excitation_error, power_experiment = power_experiment, power_calculated = power_calculated, + max_number_patterns = max_number_patterns, + single_phase = single_phase, + allow_strain = allow_strain, + strain_iterations = strain_iterations, + strain_max = strain_max, + include_false_positives = include_false_positives, + weight_false_positives = weight_false_positives, plot_result = False, verbose = False, returnfig = False, ) self.phase_weights[rx,ry] = phase_weights self.phase_residuals[rx,ry] = phase_residual + self.phase_reliability[rx,ry] = phase_reliability self.int_total[rx,ry] = int_peaks @@ -877,7 +902,7 @@ def plot_phase_maps( def plot_dominant_phase( self, - rel_range = (0.0,1.0), + reliability_range = (0.0,1.0), sigma = 0.0, phase_colors = None, figsize = (6,6), @@ -918,52 +943,33 @@ def plot_dominant_phase( mode = 'nearest', ) - # find highest correlation score for each crystal and match index for a0 in range(self.num_crystals): sub = phase_sig[a0] > phase_corr phase_map[sub] = a0 phase_corr[sub] = phase_sig[a0][sub] - # find the second correlation score for each crystal and match index - for a0 in range(self.num_crystals): - corr = phase_sig[a0].copy() - corr[phase_map==a0] = 0.0 - sub = corr > phase_corr_2nd - phase_corr_2nd[sub] = corr[sub] - - - - + if self.single_phase: + phase_scale = np.clip( + (self.phase_reliability - reliability_range[0]) / (reliability_range[1] - reliability_range[0]), + 0, + 1) - # for a1 in range(crystals[a0].orientation_map.corr.shape[2]): - # sub = crystals[a0].orientation_map.corr[:,:,a1] > phase_corr - # phase_map[sub] = a0 - # phase_corr[sub] = crystals[a0].orientation_map.corr[:,:,a1][sub] + else: - # # find the second correlation score for each crystal and match index - # for a0 in range(len(crystals)): - # for a1 in range(crystals[a0].orientation_map.corr.shape[2]): - # corr = crystals[a0].orientation_map.corr[:,:,a1].copy() - # corr[phase_map==a0] = 0.0 - # sub = corr > phase_corr_2nd - # phase_corr_2nd[sub] = corr[sub] - - # Estimate the reliability - phase_rel = phase_corr - phase_corr_2nd - - # # Generate a color plotting image - # c_vals = np.array([ - # [0.7,0.5,0.3], - # [1,0,0], - # [0,0.7,0], - # [0,0.7,1], - # [0,0.7,1], - # ]) - phase_scale = np.clip( - (phase_rel - rel_range[0]) / (rel_range[1] - rel_range[0]), - 0, - 1) + # find the second correlation score for each crystal and match index + for a0 in range(self.num_crystals): + corr = phase_sig[a0].copy() + corr[phase_map==a0] = 0.0 + sub = corr > phase_corr_2nd + phase_corr_2nd[sub] = corr[sub] + + # Estimate the reliability + phase_rel = phase_corr - phase_corr_2nd + phase_scale = np.clip( + (phase_rel - reliability_range[0]) / (reliability_range[1] - reliability_range[0]), + 0, + 1) phase_rgb = np.zeros((scan_shape[0],scan_shape[1],3)) From 932fe5689b24ccac60bfbd48fc77a1fcb2c9377d Mon Sep 17 00:00:00 2001 From: Colin Date: Wed, 10 Apr 2024 13:48:34 -0700 Subject: [PATCH 22/64] Adding new weights for cost function --- py4DSTEM/process/diffraction/crystal_phase.py | 73 ++++++++++++------- 1 file changed, 47 insertions(+), 26 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_phase.py b/py4DSTEM/process/diffraction/crystal_phase.py index b997fb9d1..45219f10f 100644 --- a/py4DSTEM/process/diffraction/crystal_phase.py +++ b/py4DSTEM/process/diffraction/crystal_phase.py @@ -84,7 +84,7 @@ def quantify_single_pattern( max_number_patterns = 3, single_phase = False, allow_strain = True, - strain_iterations = 5, + strain_iterations = 3, strain_max = 0.02, include_false_positives = True, weight_false_positives = 1.0, @@ -270,17 +270,33 @@ def quantify_single_pattern( ind_min = np.argmin(dist2) val_min = dist2[ind_min] - if val_min < radius_max_2: - # weight = 1 - np.sqrt(dist2[ind_min]) / corr_kernel_size - # weight = 1 + corr_distance_scale * \ - # np.sqrt(dist2[ind_min]) / corr_kernel_size - # basis[ind_min,a0] = weight * int_fit[a1] - basis[ind_min,a0] = int_fit[a1] - if plot_result: + if include_false_positives: + weight = np.clip(1 - np.sqrt(dist2[ind_min]) / corr_kernel_size,0,1) + basis[ind_min,a0] = int_fit[a1] * weight + unpaired_peaks.append([ + a0, + int_fit[a1] * (1 - weight), + ]) + if weight > 1e-8 and plot_result: matches[a1] = True - elif include_false_positives: - # unpaired_peaks.append([a0,int_fit[a1]*(1 + corr_distance_scale)]) - unpaired_peaks.append([a0,int_fit[a1]]) + else: + if val_min < radius_max_2: + basis[ind_min,a0] = int_fit[a1] + if plot_result: + matches[a1] = True + + + # if val_min < radius_max_2: + # # weight = 1 - np.sqrt(dist2[ind_min]) / corr_kernel_size + # # weight = 1 + corr_distance_scale * \ + # # np.sqrt(dist2[ind_min]) / corr_kernel_size + # # basis[ind_min,a0] = weight * int_fit[a1] + # basis[ind_min,a0] = int_fit[a1] + # if plot_result: + # matches[a1] = True + # elif include_false_positives: + # # unpaired_peaks.append([a0,int_fit[a1]*(1 + corr_distance_scale)]) + # unpaired_peaks.append([a0,int_fit[a1]]) if plot_result: library_peaks.append(bragg_peaks_fit) @@ -347,34 +363,39 @@ def quantify_single_pattern( inds = np.where(inds_solve)[0] inds_solve[inds[np.argmin(phase_weights_cand)]] = False + # Estimate the phase reliability inds_solve = np.ones(self.num_fits,dtype='bool') inds_solve[phase_weights > 1e-8] = False - search = True - while search is True: + + if np.all(inds_solve == False): + phase_reliability = 0.0 + else: + search = True + while search is True: + phase_weights_cand, phase_residual_cand = nnls( + basis[:,inds_solve], + obs, + ) + if np.count_nonzero(phase_weights_cand > 0.0) <= max_number_patterns: + phase_residual_2nd = phase_residual_cand + search = False + else: + inds = np.where(inds_solve)[0] + inds_solve[inds[np.argmin(phase_weights_cand)]] = False + phase_weights_cand, phase_residual_cand = nnls( basis[:,inds_solve], obs, ) - if np.count_nonzero(phase_weights_cand > 0.0) <= max_number_patterns: - phase_residual_2nd = phase_residual_cand - search = False - else: - inds = np.where(inds_solve)[0] - inds_solve[inds[np.argmin(phase_weights_cand)]] = False - - # print(inds_solve) - # phase_weights_cand, phase_residual_cand = nnls( - # basis[:,inds_solve], - # obs, - # ) - phase_reliability = phase_residual_2nd - phase_residual + phase_reliability = phase_residual_2nd - phase_residual except: phase_weights = np.zeros(self.num_fits) phase_residual = np.sqrt(np.sum(intensity**2)) phase_reliability = 0.0 + if verbose: ind_max = np.argmax(phase_weights) # print() From e306c5dcad9632bbb98b2552e4f5a803a641b9d2 Mon Sep 17 00:00:00 2001 From: Colin Date: Thu, 11 Apr 2024 09:48:34 -0700 Subject: [PATCH 23/64] change to strain iteration default --- py4DSTEM/process/diffraction/crystal_phase.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/process/diffraction/crystal_phase.py b/py4DSTEM/process/diffraction/crystal_phase.py index 45219f10f..f6d916a95 100644 --- a/py4DSTEM/process/diffraction/crystal_phase.py +++ b/py4DSTEM/process/diffraction/crystal_phase.py @@ -628,7 +628,7 @@ def quantify_phase( max_number_patterns = 3, single_phase = False, allow_strain = True, - strain_iterations = 5, + strain_iterations = 3, strain_max = 0.02, include_false_positives = True, weight_false_positives = 1.0, From 217d71f60761bac311a28e2ce9bf8e08046b92c0 Mon Sep 17 00:00:00 2001 From: Colin Date: Sun, 28 Apr 2024 14:43:04 -0700 Subject: [PATCH 24/64] precession electron diffraction added --- py4DSTEM/process/diffraction/crystal.py | 153 ++++++++++++++++++------ 1 file changed, 115 insertions(+), 38 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index e0fe59eee..6667e6cf5 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -712,39 +712,65 @@ def generate_diffraction_pattern( foil_normal_cartesian: Optional[Union[list, tuple, np.ndarray]] = None, sigma_excitation_error: float = 0.02, tol_excitation_error_mult: float = 3, - tol_intensity: float = 1e-4, + tol_intensity: float = 1e-5, k_max: Optional[float] = None, + precession_angle_degrees = None, keep_qz=False, return_orientation_matrix=False, ): """ Generate a single diffraction pattern, return all peaks as a pointlist. - Args: - orientation (Orientation): an Orientation class object - ind_orientation If input is an Orientation class object with multiple orientations, - this input can be used to select a specific orientation. - - orientation_matrix (array): (3,3) orientation matrix, where columns represent projection directions. - zone_axis_lattice (array): (3,) projection direction in lattice indices - proj_x_lattice (array): (3,) x-axis direction in lattice indices - zone_axis_cartesian (array): (3,) cartesian projection direction - proj_x_cartesian (array): (3,) cartesian projection direction - - foil_normal: 3 element foil normal - set to None to use zone_axis - proj_x_axis (np float vector): 3 element vector defining image x axis (vertical) - accel_voltage (float): Accelerating voltage in Volts. If not specified, - we check to see if crystal already has voltage specified. - sigma_excitation_error (float): sigma value for envelope applied to s_g (excitation errors) in units of inverse Angstroms - tol_excitation_error_mult (float): tolerance in units of sigma for s_g inclusion - tol_intensity (np float): tolerance in intensity units for inclusion of diffraction spots - k_max (float): Maximum scattering vector - keep_qz (bool): Flag to return out-of-plane diffraction vectors - return_orientation_matrix (bool): Return the orientation matrix + Parameters + ---------- - Returns: - bragg_peaks (PointList): list of all Bragg peaks with fields [qx, qy, intensity, h, k, l] - orientation_matrix (array): 3x3 orientation matrix (optional) + orientation (Orientation) + an Orientation class object + ind_orientation + If input is an Orientation class object with multiple orientations, + this input can be used to select a specific orientation. + + orientation_matrix (3,3) numpy.array + orientation matrix, where columns represent projection directions. + zone_axis_lattice (3,) numpy.array + projection direction in lattice indices + proj_x_lattice (3,) numpy.array + x-axis direction in lattice indices + zone_axis_cartesian (3,) numpy.array + cartesian projection direction + proj_x_cartesian (3,) numpy.array + cartesian projection direction + + foil_normal + 3 element foil normal - set to None to use zone_axis + proj_x_axis (3,) numpy.array + 3 element vector defining image x axis (vertical) + accel_voltage (float) + Accelerating voltage in Volts. If not specified, + we check to see if crystal already has voltage specified. + + sigma_excitation_error (float) + sigma value for envelope applied to s_g (excitation errors) in units of inverse Angstroms + tol_excitation_error_mult (float) + tolerance in units of sigma for s_g inclusion + tol_intensity (numpy float) + tolerance in intensity units for inclusion of diffraction spots + k_max (float) + Maximum scattering vector + precession_angle_degrees (float) + Precession angle for library calculation. Set to None for no precession. + keep_qz (bool) + Flag to return out-of-plane diffraction vectors + return_orientation_matrix (bool) + Return the orientation matrix + + Returns + ---------- + bragg_peaks (PointList) + list of all Bragg peaks with fields [qx, qy, intensity, h, k, l] + orientation_matrix (array) + 3x3 orientation matrix (optional) + """ if not (hasattr(self, "wavelength") and hasattr(self, "accel_voltage")): @@ -779,17 +805,27 @@ def generate_diffraction_pattern( # Calculate excitation errors if foil_normal is None: - sg = self.excitation_errors(g) + sg = self.excitation_errors( + g, + precession_angle_degrees = precession_angle_degrees, + ) else: foil_normal = ( orientation_matrix.T @ (-1 * foil_normal[:, None] / np.linalg.norm(foil_normal)) ).ravel() - sg = self.excitation_errors(g, foil_normal) + sg = self.excitation_errors( + g, + foil_normal = foil_normal, + precession_angle_degrees = precession_angle_degrees, + ) # Threshold for inclusion in diffraction pattern sg_max = sigma_excitation_error * tol_excitation_error_mult - keep = np.abs(sg) <= sg_max + if precession_angle_degrees is None: + keep = np.abs(sg) <= sg_max + else: + keep = np.min(np.abs(sg),axis=1) <= sg_max # Maximum scattering angle cutoff if k_max is not None: @@ -799,9 +835,17 @@ def generate_diffraction_pattern( g_diff = g[:, keep] # Diffracted peak intensities and labels - g_int = self.struct_factors_int[keep] * np.exp( - (sg[keep] ** 2) / (-2 * sigma_excitation_error**2) - ) + if precession_angle_degrees is None: + g_int = self.struct_factors_int[keep] * np.exp( + (sg[keep] ** 2) / (-2 * sigma_excitation_error**2) + ) + else: + g_int = self.struct_factors_int[keep] * np.mean( + np.exp( + (sg[keep] ** 2) / (-2 * sigma_excitation_error**2) + ), + axis = 1, + ) hkl = self.hkl[:, keep] # Intensity tolerance @@ -1323,21 +1367,54 @@ def excitation_errors( self, g, foil_normal=None, + precession_angle_degrees=None, + precession_steps=180, ): """ Calculate the excitation errors, assuming k0 = [0, 0, -1/lambda]. If foil normal is not specified, we assume it is [0,0,-1]. + + Precession is currently implemented using numerical integration. """ - if foil_normal is None: - return (2 * g[2, :] - self.wavelength * np.sum(g * g, axis=0)) / ( - 2 - 2 * self.wavelength * g[2, :] - ) + + if precession_angle_degrees is None: + if foil_normal is None: + return (2 * g[2, :] - self.wavelength * np.sum(g * g, axis=0)) / ( + 2 - 2 * self.wavelength * g[2, :] + ) + else: + return (2 * g[2, :] - self.wavelength * np.sum(g * g, axis=0)) / ( + 2 * self.wavelength * np.sum(g * foil_normal[:, None], axis=0) + - 2 * foil_normal[2] + ) + else: - return (2 * g[2, :] - self.wavelength * np.sum(g * g, axis=0)) / ( - 2 * self.wavelength * np.sum(g * foil_normal[:, None], axis=0) - - 2 * foil_normal[2] + t = np.deg2rad(precession_angle_degrees) + p = np.linspace( + 0, + 2.0*np.pi, + precession_steps, + endpoint = False, + ) + if foil_normal is None: + foil_normal = np.array((0.0,0.0,-1.0)) + + k = np.reshape( + (-1/self.wavelength) * np.vstack(( + np.sin(t)*np.cos(p), + np.sin(t)*np.sin(p), + np.cos(t)*np.ones(p.size), + )), + (3,1,p.size), ) + term1 = np.sum( (g[:,:,None] + k) * foil_normal[:,None,None], axis=0) + term2 = np.sum( (g[:,:,None] + 2*k) * g[:,:,None], axis=0) + sg = np.sqrt(term1**2 - term2) - term1 + + return sg + + def calculate_bragg_peak_histogram( self, bragg_peaks, From 264deba5e5c9f52554962539c6d5108b207b807a Mon Sep 17 00:00:00 2001 From: Colin Date: Sun, 28 Apr 2024 14:55:52 -0700 Subject: [PATCH 25/64] New default values for diffraction sigma and intensity threshold --- py4DSTEM/process/diffraction/crystal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index 6667e6cf5..9694111c9 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -710,7 +710,7 @@ def generate_diffraction_pattern( zone_axis_cartesian: Optional[np.ndarray] = None, proj_x_cartesian: Optional[np.ndarray] = None, foil_normal_cartesian: Optional[Union[list, tuple, np.ndarray]] = None, - sigma_excitation_error: float = 0.02, + sigma_excitation_error: float = 0.01, tol_excitation_error_mult: float = 3, tol_intensity: float = 1e-5, k_max: Optional[float] = None, From ba5e0a215c3c5fb59c01729b4352f72fa71a24e1 Mon Sep 17 00:00:00 2001 From: Colin Date: Sun, 28 Apr 2024 15:07:15 -0700 Subject: [PATCH 26/64] Adding precession to orientation plans, update docstring --- py4DSTEM/process/diffraction/crystal_ACOM.py | 101 ++++++++++++------- 1 file changed, 67 insertions(+), 34 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index fc6e691c3..f81ad53cc 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -29,6 +29,7 @@ def orientation_plan( corr_kernel_size: float = 0.08, radial_power: float = 1.0, intensity_power: float = 0.25, # New default intensity power scaling + precession_angle_degrees = None, calculate_correlation_array=True, tol_peak_delete=None, tol_distance: float = 0.01, @@ -41,39 +42,60 @@ def orientation_plan( """ Calculate the rotation basis arrays for an SO(3) rotation correlogram. - Args: - zone_axis_range (float): Row vectors give the range for zone axis orientations. - If user specifies 2 vectors (2x3 array), we start at [0,0,1] - to make z-x-z rotation work. - If user specifies 3 vectors (3x3 array), plan will span these vectors. - Setting to 'full' as a string will use a hemispherical range. - Setting to 'half' as a string will use a quarter sphere range. - Setting to 'fiber' as a string will make a spherical cap around a given vector. - Setting to 'auto' will use pymatgen to determine the point group symmetry - of the structure and choose an appropriate zone_axis_range - angle_step_zone_axis (float): Approximate angular step size for zone axis search [degrees] - angle_coarse_zone_axis (float): Coarse step size for zone axis search [degrees]. Setting to - None uses the same value as angle_step_zone_axis. - angle_refine_range (float): Range of angles to use for zone axis refinement. Setting to - None uses same value as angle_coarse_zone_axis. - - angle_step_in_plane (float): Approximate angular step size for in-plane rotation [degrees] - accel_voltage (float): Accelerating voltage for electrons [Volts] - corr_kernel_size (float): Correlation kernel size length. The size of the overlap kernel between the measured Bragg peaks and diffraction library Bragg peaks. [1/Angstroms] - radial_power (float): Power for scaling the correlation intensity as a function of the peak radius - intensity_power (float): Power for scaling the correlation intensity as a function of the peak intensity - calculate_correlation_array (bool): Set to false to skip calculating the correlation array. - This is useful when we only want the angular range / rotation matrices. - tol_peak_delete (float): Distance to delete peaks for multiple matches. - Default is kernel_size * 0.5 - tol_distance (float): Distance tolerance for radial shell assignment [1/Angstroms] - fiber_axis (float): (3,) vector specifying the fiber axis - fiber_angles (float): (2,) vector specifying angle range from fiber axis, and in-plane angular range [degrees] - cartesian_directions (bool): When set to true, all zone axes and projection directions - are specified in Cartesian directions. - figsize (float): (2,) vector giving the figure size - CUDA (bool): Use CUDA for the Fourier operations. - progress_bar (bool): If false no progress bar is displayed + Parameters + ---------- + zone_axis_range (float): + Row vectors give the range for zone axis orientations. + If user specifies 2 vectors (2x3 array), we start at [0,0,1] + to make z-x-z rotation work. + If user specifies 3 vectors (3x3 array), plan will span these vectors. + Setting to 'full' as a string will use a hemispherical range. + Setting to 'half' as a string will use a quarter sphere range. + Setting to 'fiber' as a string will make a spherical cap around a given vector. + Setting to 'auto' will use pymatgen to determine the point group symmetry + of the structure and choose an appropriate zone_axis_range + angle_step_zone_axis (float): + Approximate angular step size for zone axis search [degrees] + angle_coarse_zone_axis (float): + Coarse step size for zone axis search [degrees]. Setting to + None uses the same value as angle_step_zone_axis. + angle_refine_range (float): + Range of angles to use for zone axis refinement. Setting to + None uses same value as angle_coarse_zone_axis. + + angle_step_in_plane (float): + Approximate angular step size for in-plane rotation [degrees] + accel_voltage (float): + Accelerating voltage for electrons [Volts] + corr_kernel_size (float): + Correlation kernel size length. The size of the overlap kernel between the measured Bragg peaks and diffraction library Bragg peaks. [1/Angstroms] + radial_power (float): + Power for scaling the correlation intensity as a function of the peak radius + intensity_power (float): + Power for scaling the correlation intensity as a function of the peak intensity + precession_angle_degrees (float) + Tilt angle of illuminaiton cone in degrees for precession electron diffraction (PED). + calculate_correlation_array (bool): + Set to false to skip calculating the correlation array. + This is useful when we only want the angular range / rotation matrices. + tol_peak_delete (float): + Distance to delete peaks for multiple matches. + Default is kernel_size * 0.5 + tol_distance (float): + Distance tolerance for radial shell assignment [1/Angstroms] + fiber_axis (float): + (3,) vector specifying the fiber axis + fiber_angles (float): + (2,) vector specifying angle range from fiber axis, and in-plane angular range [degrees] + cartesian_directions (bool): + When set to true, all zone axes and projection directions + are specified in Cartesian directions. + figsize (float): + (2,) vector giving the figure size + CUDA (bool): + Use CUDA for the Fourier operations. + progress_bar (bool): + If false no progress bar is displayed, """ # Check to make sure user has calculated the structure factors if needed @@ -717,7 +739,18 @@ def orientation_plan( ): # reciprocal lattice spots and excitation errors g = self.orientation_rotation_matrices[a0, :, :].T @ self.g_vec_all - sg = self.excitation_errors(g) + if precession_angle_degrees is None: + sg = self.excitation_errors(g) + else: + sg = np.min( + np.abs( + self.excitation_errors( + g, + precession_angle_degrees = precession_angle_degrees, + ), + ), + axis = 1, + ) # Keep only points that will contribute to this orientation plan slice keep = np.abs(sg) < self.orientation_kernel_size From 3ee1adf8d6fa0b039221e1b1516f9f31ca7525fb Mon Sep 17 00:00:00 2001 From: Colin Date: Mon, 6 May 2024 14:04:35 -0700 Subject: [PATCH 27/64] adding precession to orientation plans --- py4DSTEM/process/diffraction/crystal_ACOM.py | 50 +++++++++++++------- 1 file changed, 33 insertions(+), 17 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index f81ad53cc..f7406b590 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -27,9 +27,10 @@ def orientation_plan( angle_step_in_plane: float = 2.0, accel_voltage: float = 300e3, corr_kernel_size: float = 0.08, + sigma_excitation_error: float = 0.01, + precession_angle_degrees = None, radial_power: float = 1.0, intensity_power: float = 0.25, # New default intensity power scaling - precession_angle_degrees = None, calculate_correlation_array=True, tol_peak_delete=None, tol_distance: float = 0.01, @@ -68,13 +69,17 @@ def orientation_plan( accel_voltage (float): Accelerating voltage for electrons [Volts] corr_kernel_size (float): - Correlation kernel size length. The size of the overlap kernel between the measured Bragg peaks and diffraction library Bragg peaks. [1/Angstroms] + Correlation kernel size length. The size of the overlap kernel between the + measured Bragg peaks and diffraction library Bragg peaks. [1/Angstroms] + sigma_excitation_error (float): + The out of plane excitation error tolerance. [1/Angstroms] + precession_angle_degrees (float) + Tilt angle of illuminaiton cone in degrees for precession electron diffraction (PED). + radial_power (float): Power for scaling the correlation intensity as a function of the peak radius intensity_power (float): Power for scaling the correlation intensity as a function of the peak intensity - precession_angle_degrees (float) - Tilt angle of illuminaiton cone in degrees for precession electron diffraction (PED). calculate_correlation_array (bool): Set to false to skip calculating the correlation array. This is useful when we only want the angular range / rotation matrices. @@ -108,6 +113,7 @@ def orientation_plan( # Store inputs self.accel_voltage = np.asarray(accel_voltage) self.orientation_kernel_size = np.asarray(corr_kernel_size) + self.orientation_precession_angle_degrees = np.asarray(precession_angle_degrees) if tol_peak_delete is None: self.orientation_tol_peak_delete = self.orientation_kernel_size * 0.5 else: @@ -739,25 +745,35 @@ def orientation_plan( ): # reciprocal lattice spots and excitation errors g = self.orientation_rotation_matrices[a0, :, :].T @ self.g_vec_all - if precession_angle_degrees is None: - sg = self.excitation_errors(g) - else: - sg = np.min( - np.abs( - self.excitation_errors( - g, - precession_angle_degrees = precession_angle_degrees, - ), - ), - axis = 1, - ) + # if precession_angle_degrees is None: + sg = self.excitation_errors(g) + # else: + # sg = np.min( + # np.abs( + # self.excitation_errors( + # g, + # precession_angle_degrees = precession_angle_degrees, + # ), + # ), + # axis = 1, + # ) # Keep only points that will contribute to this orientation plan slice - keep = np.abs(sg) < self.orientation_kernel_size + keep = np.logical_and( + np.abs(sg) < self.orientation_kernel_size, + self.orientation_shell_index >= 0, + ) # in-plane rotation angle phi = np.arctan2(g[1, :], g[0, :]) + # calculate intensity of spots + # if precession_angle_degrees is None: + # Ig = np.exp(sg**2/(-2*sigma_excitation_error**2)) + # else: + # pass + + # Loop over all peaks for a1 in np.arange(self.g_vec_all.shape[1]): ind_radial = self.orientation_shell_index[a1] From 19058a5b4cd1664352cf72ad82f4d3d2f18fc6eb Mon Sep 17 00:00:00 2001 From: Colin Date: Mon, 6 May 2024 16:29:02 -0700 Subject: [PATCH 28/64] precession added to orientation plans! --- py4DSTEM/process/diffraction/crystal_ACOM.py | 104 ++++++++++++------- py4DSTEM/process/diffraction/crystal_viz.py | 1 + 2 files changed, 66 insertions(+), 39 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index f7406b590..960461669 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -113,7 +113,11 @@ def orientation_plan( # Store inputs self.accel_voltage = np.asarray(accel_voltage) self.orientation_kernel_size = np.asarray(corr_kernel_size) - self.orientation_precession_angle_degrees = np.asarray(precession_angle_degrees) + if precession_angle_degrees is None: + self.orientation_precession_angle_degrees = None + else: + self.orientation_precession_angle_degrees = np.asarray(precession_angle_degrees) + self.orientation_precession_angle = np.deg2rad(np.asarray(precession_angle_degrees)) if tol_peak_delete is None: self.orientation_tol_peak_delete = self.orientation_kernel_size * 0.5 else: @@ -764,45 +768,67 @@ def orientation_plan( self.orientation_shell_index >= 0, ) - # in-plane rotation angle - phi = np.arctan2(g[1, :], g[0, :]) - # calculate intensity of spots - # if precession_angle_degrees is None: - # Ig = np.exp(sg**2/(-2*sigma_excitation_error**2)) - # else: - # pass - - - # Loop over all peaks - for a1 in np.arange(self.g_vec_all.shape[1]): - ind_radial = self.orientation_shell_index[a1] - - if keep[a1] and ind_radial >= 0: - # 2D orientation plan - self.orientation_ref[a0, ind_radial, :] += ( - np.power(self.orientation_shell_radii[ind_radial], radial_power) - * np.power(self.struct_factors_int[a1], intensity_power) - * np.maximum( - 1 - - np.sqrt( - sg[a1] ** 2 - + ( - ( - np.mod( - self.orientation_gamma - phi[a1] + np.pi, - 2 * np.pi, - ) - - np.pi - ) - * self.orientation_shell_radii[ind_radial] - ) - ** 2 - ) - / self.orientation_kernel_size, - 0, - ) - ) + if precession_angle_degrees is None: + Ig = np.exp(sg[keep]**2/(-2*sigma_excitation_error**2)) + else: + # precession extension + prec = np.cos(np.linspace(0,2*np.pi,90,endpoint=False)) + dsg = np.tan(self.orientation_precession_angle) * np.sum(g[:2,keep]**2,axis=0) + Ig = np.mean(np.exp((sg[keep,None] + dsg[:,None]*prec[None,:])**2 \ + / (-2*sigma_excitation_error**2)), axis = 1) + + # in-plane rotation angle + phi = np.arctan2(g[1, keep], g[0, keep]) + phi_ind = phi / self.orientation_gamma[1] # step size of annular bins + phi_floor = np.floor(phi_ind).astype('int') + dphi = phi_ind - phi_floor + + # write intensities into orientation plan slice + radial_inds = self.orientation_shell_index[keep] + self.orientation_ref[a0, radial_inds, phi_floor] += \ + (1-dphi) * \ + np.power(self.struct_factors_int[keep] * Ig, intensity_power) * \ + np.power(self.orientation_shell_radii[radial_inds], radial_power) + self.orientation_ref[a0, radial_inds, np.mod(phi_floor+1,self.orientation_in_plane_steps)] += \ + dphi * \ + np.power(self.struct_factors_int[keep] * Ig, intensity_power) * \ + np.power(self.orientation_shell_radii[radial_inds], radial_power) + + + # # Loop over all peaks + # for a1 in np.arange(self.g_vec_all.shape[1]): + # if keep[a1]: + + + # for a1 in np.arange(self.g_vec_all.shape[1]): + # ind_radial = self.orientation_shell_index[a1] + + # if keep[a1] and ind_radial >= 0: + # # 2D orientation plan + # self.orientation_ref[a0, ind_radial, :] += ( + # np.power(self.orientation_shell_radii[ind_radial], radial_power) + # * np.power(self.struct_factors_int[a1], intensity_power) + # * np.maximum( + # 1 + # - np.sqrt( + # sg[a1] ** 2 + # + ( + # ( + # np.mod( + # self.orientation_gamma - phi[a1] + np.pi, + # 2 * np.pi, + # ) + # - np.pi + # ) + # * self.orientation_shell_radii[ind_radial] + # ) + # ** 2 + # ) + # / self.orientation_kernel_size, + # 0, + # ) + # ) orientation_ref_norm = np.sqrt(np.sum(self.orientation_ref[a0, :, :] ** 2)) if orientation_ref_norm > 0: diff --git a/py4DSTEM/process/diffraction/crystal_viz.py b/py4DSTEM/process/diffraction/crystal_viz.py index 47df2e6ca..fdb313cbb 100644 --- a/py4DSTEM/process/diffraction/crystal_viz.py +++ b/py4DSTEM/process/diffraction/crystal_viz.py @@ -848,6 +848,7 @@ def plot_orientation_plan( bragg_peaks = self.generate_diffraction_pattern( orientation_matrix=self.orientation_rotation_matrices[index_plot, :], sigma_excitation_error=self.orientation_kernel_size / 3, + precession_angle_degrees = self.orientation_precession_angle_degrees, ) plot_diffraction_pattern( From 3e656b6df981908e96904616b6c77416528dcb4c Mon Sep 17 00:00:00 2001 From: Colin Date: Mon, 6 May 2024 19:20:31 -0700 Subject: [PATCH 29/64] Fiddling with matching routines --- py4DSTEM/process/diffraction/crystal_ACOM.py | 38 ++++++++++++++++---- 1 file changed, 31 insertions(+), 7 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index 960461669..bdff28de2 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -27,7 +27,7 @@ def orientation_plan( angle_step_in_plane: float = 2.0, accel_voltage: float = 300e3, corr_kernel_size: float = 0.08, - sigma_excitation_error: float = 0.01, + sigma_excitation_error: float = 0.02, precession_angle_degrees = None, radial_power: float = 1.0, intensity_power: float = 0.25, # New default intensity power scaling @@ -1034,12 +1034,7 @@ def match_single_pattern( if np.any(sub): im_polar[ind_radial, :] = np.sum( - np.power(radius, self.orientation_radial_power) - * np.power( - np.maximum(intensity[sub, None], 0.0), - self.orientation_intensity_power, - ) - * np.maximum( + np.maximum( 1 - np.sqrt( dqr[sub, None] ** 2 @@ -1062,6 +1057,35 @@ def match_single_pattern( ), axis=0, ) + # im_polar[ind_radial, :] = np.sum( + # np.power(radius, self.orientation_radial_power) + # * np.power( + # np.maximum(intensity[sub, None], 0.0), + # self.orientation_intensity_power, + # ) + # * np.maximum( + # 1 + # - np.sqrt( + # dqr[sub, None] ** 2 + # + ( + # ( + # np.mod( + # self.orientation_gamma[None, :] + # - qphi[sub, None] + # + np.pi, + # 2 * np.pi, + # ) + # - np.pi + # ) + # * radius + # ) + # ** 2 + # ) + # / self.orientation_kernel_size, + # 0, + # ), + # axis=0, + # ) # Determine the RMS signal from im_polar for the first match. # Note that we use scaling slightly below RMS so that following matches From e768a517285007dc3b2195dc471c95ef6e8d5a5f Mon Sep 17 00:00:00 2001 From: Colin Date: Tue, 7 May 2024 05:18:36 -0700 Subject: [PATCH 30/64] adding intensity scaling back into experiment --- py4DSTEM/process/diffraction/crystal_ACOM.py | 40 ++++++++++++-------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index bdff28de2..b1f64aa58 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -29,8 +29,9 @@ def orientation_plan( corr_kernel_size: float = 0.08, sigma_excitation_error: float = 0.02, precession_angle_degrees = None, - radial_power: float = 1.0, - intensity_power: float = 0.25, # New default intensity power scaling + power_radial: float = 1.0, + power_intensity: float = 0.25, + power_intensity_experiment: float = 0.25, calculate_correlation_array=True, tol_peak_delete=None, tol_distance: float = 0.01, @@ -76,10 +77,12 @@ def orientation_plan( precession_angle_degrees (float) Tilt angle of illuminaiton cone in degrees for precession electron diffraction (PED). - radial_power (float): + power_radial (float): Power for scaling the correlation intensity as a function of the peak radius - intensity_power (float): - Power for scaling the correlation intensity as a function of the peak intensity + power_intensity (float): + Power for scaling the correlation intensity as a function of simulated peak intensity + power_intensity_experiment (float): + Power for scaling the correlation intensity as a function of experimental peak intensity calculate_correlation_array (bool): Set to false to skip calculating the correlation array. This is useful when we only want the angular range / rotation matrices. @@ -136,8 +139,9 @@ def orientation_plan( self.wavelength = electron_wavelength_angstrom(self.accel_voltage) # store the radial and intensity scaling to use later for generating test patterns - self.orientation_radial_power = radial_power - self.orientation_intensity_power = intensity_power + self.orientation_power_radial = power_radial + self.orientation_power_intensity = power_intensity + self.orientation_power_intensity_experiment = power_intensity_experiment # Calculate the ratio between coarse and fine refinement if angle_coarse_zone_axis is not None: @@ -788,12 +792,12 @@ def orientation_plan( radial_inds = self.orientation_shell_index[keep] self.orientation_ref[a0, radial_inds, phi_floor] += \ (1-dphi) * \ - np.power(self.struct_factors_int[keep] * Ig, intensity_power) * \ - np.power(self.orientation_shell_radii[radial_inds], radial_power) + np.power(self.struct_factors_int[keep] * Ig, power_intensity) * \ + np.power(self.orientation_shell_radii[radial_inds], power_radial) self.orientation_ref[a0, radial_inds, np.mod(phi_floor+1,self.orientation_in_plane_steps)] += \ dphi * \ - np.power(self.struct_factors_int[keep] * Ig, intensity_power) * \ - np.power(self.orientation_shell_radii[radial_inds], radial_power) + np.power(self.struct_factors_int[keep] * Ig, power_intensity) * \ + np.power(self.orientation_shell_radii[radial_inds], power_radial) # # Loop over all peaks @@ -807,8 +811,8 @@ def orientation_plan( # if keep[a1] and ind_radial >= 0: # # 2D orientation plan # self.orientation_ref[a0, ind_radial, :] += ( - # np.power(self.orientation_shell_radii[ind_radial], radial_power) - # * np.power(self.struct_factors_int[a1], intensity_power) + # np.power(self.orientation_shell_radii[ind_radial], power_radial) + # * np.power(self.struct_factors_int[a1], power_intensity) # * np.maximum( # 1 # - np.sqrt( @@ -1034,7 +1038,11 @@ def match_single_pattern( if np.any(sub): im_polar[ind_radial, :] = np.sum( - np.maximum( + np.power( + np.maximum(intensity[sub, None], 0.0), + self.orientation_power_intensity_experiment, + ) + * np.maximum( 1 - np.sqrt( dqr[sub, None] ** 2 @@ -1058,10 +1066,10 @@ def match_single_pattern( axis=0, ) # im_polar[ind_radial, :] = np.sum( - # np.power(radius, self.orientation_radial_power) + # np.power(radius, self.orientation_power_radial) # * np.power( # np.maximum(intensity[sub, None], 0.0), - # self.orientation_intensity_power, + # self.orientation_power_intensity, # ) # * np.maximum( # 1 From 8792a7249485b2e04e1b08f3ee309799fa79a21e Mon Sep 17 00:00:00 2001 From: Colin Date: Wed, 8 May 2024 11:35:25 -0700 Subject: [PATCH 31/64] adding normalization testing --- py4DSTEM/process/diffraction/crystal_ACOM.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index b1f64aa58..0c8776010 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -834,6 +834,8 @@ def orientation_plan( # ) # ) + # normalization + # self.orientation_ref[a0, :, :] -= np.mean(self.orientation_ref[a0, :, :]) orientation_ref_norm = np.sqrt(np.sum(self.orientation_ref[a0, :, :] ** 2)) if orientation_ref_norm > 0: self.orientation_ref[a0, :, :] /= orientation_ref_norm @@ -1095,6 +1097,10 @@ def match_single_pattern( # axis=0, # ) + + # # normalization + # self.im_polar -= np.mean(im_polar) + # Determine the RMS signal from im_polar for the first match. # Note that we use scaling slightly below RMS so that following matches # don't have higher correlating scores than previous matches. From 04bbbaf1faf2a2313b9db99f774a073896ee51c4 Mon Sep 17 00:00:00 2001 From: Colin Date: Thu, 9 May 2024 09:32:40 -0700 Subject: [PATCH 32/64] Making kernel a Gaussian function instead of cone --- py4DSTEM/process/diffraction/crystal_ACOM.py | 64 ++++++++++++++------ 1 file changed, 46 insertions(+), 18 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index 0c8776010..3acce396c 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -1039,34 +1039,62 @@ def match_single_pattern( sub = dqr < self.orientation_kernel_size if np.any(sub): - im_polar[ind_radial, :] = np.sum( + + dist = im_polar[ind_radial, :] = np.sum( np.power( np.maximum(intensity[sub, None], 0.0), self.orientation_power_intensity_experiment, ) - * np.maximum( - 1 - - np.sqrt( - dqr[sub, None] ** 2 - + ( - ( - np.mod( - self.orientation_gamma[None, :] - - qphi[sub, None] - + np.pi, - 2 * np.pi, - ) - - np.pi + * np.exp( + (dqr[sub, None] ** 2 + + ( + ( + np.mod( + self.orientation_gamma[None, :] + - qphi[sub, None] + + np.pi, + 2 * np.pi, ) - * radius + - np.pi ) - ** 2 + * radius ) - / self.orientation_kernel_size, - 0, + ** 2) + / (-2*self.orientation_kernel_size**2) ), axis=0, ) + + # im_polar[ind_radial, :] = np.sum( + # np.power( + # np.maximum(intensity[sub, None], 0.0), + # self.orientation_power_intensity_experiment, + # ) + # * np.maximum( + # 1 + # - np.sqrt( + # dqr[sub, None] ** 2 + # + ( + # ( + # np.mod( + # self.orientation_gamma[None, :] + # - qphi[sub, None] + # + np.pi, + # 2 * np.pi, + # ) + # - np.pi + # ) + # * radius + # ) + # ** 2 + # ) + # / self.orientation_kernel_size, + # 0, + # ), + # axis=0, + # ) + + # im_polar[ind_radial, :] = np.sum( # np.power(radius, self.orientation_power_radial) # * np.power( From e8a02ae20ae97515dd877dd5c918958eaed712e6 Mon Sep 17 00:00:00 2001 From: Colin Date: Thu, 9 May 2024 14:59:58 -0700 Subject: [PATCH 33/64] tweaking multiple match finder --- py4DSTEM/process/diffraction/crystal_ACOM.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index 3acce396c..be8ae1f50 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -116,6 +116,7 @@ def orientation_plan( # Store inputs self.accel_voltage = np.asarray(accel_voltage) self.orientation_kernel_size = np.asarray(corr_kernel_size) + self.orientation_sigma_excitation_error = sigma_excitation_error if precession_angle_degrees is None: self.orientation_precession_angle_degrees = None else: @@ -837,6 +838,7 @@ def orientation_plan( # normalization # self.orientation_ref[a0, :, :] -= np.mean(self.orientation_ref[a0, :, :]) orientation_ref_norm = np.sqrt(np.sum(self.orientation_ref[a0, :, :] ** 2)) + # orientation_ref_norm = np.sum(self.orientation_ref[a0, :, :]) if orientation_ref_norm > 0: self.orientation_ref[a0, :, :] /= orientation_ref_norm @@ -1327,6 +1329,7 @@ def match_single_pattern( np.clip( np.sum( self.orientation_vecs + # self.orientation_vecs * np.array([1,-1,-1])[None,:] * self.orientation_vecs[inds_previous[a0], :], axis=1, ), @@ -1625,7 +1628,8 @@ def match_single_pattern( bragg_peaks_fit = self.generate_diffraction_pattern( orientation, ind_orientation=match_ind, - sigma_excitation_error=self.orientation_kernel_size, + sigma_excitation_error = self.orientation_sigma_excitation_error, + precession_angle_degrees = self.orientation_precession_angle_degrees, ) remove = np.zeros_like(qx, dtype="bool") From daae2258aaa9ae43f056356318fcba57fb5f2d4e Mon Sep 17 00:00:00 2001 From: Colin Date: Fri, 17 May 2024 10:28:04 -0700 Subject: [PATCH 34/64] fix typo --- py4DSTEM/process/diffraction/crystal_ACOM.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index be8ae1f50..74e842817 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -1042,7 +1042,7 @@ def match_single_pattern( if np.any(sub): - dist = im_polar[ind_radial, :] = np.sum( + im_polar[ind_radial, :] = np.sum( np.power( np.maximum(intensity[sub, None], 0.0), self.orientation_power_intensity_experiment, @@ -1128,8 +1128,8 @@ def match_single_pattern( # ) - # # normalization - # self.im_polar -= np.mean(im_polar) + # normalization + # im_polar -= np.mean(im_polar) # Determine the RMS signal from im_polar for the first match. # Note that we use scaling slightly below RMS so that following matches From 4734fc9286a25090460ca6cebe639171f5c547e7 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Tue, 28 May 2024 10:08:42 -0700 Subject: [PATCH 35/64] updating README, now that tutorials have been updated --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 3fe6cc745..d62845d4c 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ > :warning: **py4DSTEM version 0.14 update** :warning: Warning: this is a major update and we expect some workflows to break. You can still install previous versions of py4DSTEM [as discussed here](#legacyinstall) -> :warning: **Phase retrieval refactor version 0.14.9** :warning: Warning: The phase-retrieval modules in py4DSTEM (DPC, parallax, and ptychography) underwent a major refactor in version 0.14.9 and as such older tutorial notebooks will not work as expected. Notably, class names have been pruned to remove the trailing "Reconstruction" (`DPCReconstruction` -> `DPC` etc.), and regularization functions have dropped the `_iter` suffix (and are instead specified as boolean flags). We are working on updating the tutorial notebooks to reflect these changes. In the meantime, there's some more information in the relevant pull request [here](https://github.com/py4dstem/py4DSTEM/pull/597#issuecomment-1890325568). +> :warning: **Phase retrieval refactor version 0.14.9** :warning: Warning: The phase-retrieval modules in py4DSTEM (DPC, parallax, and ptychography) underwent a major refactor in version 0.14.9 and as such older tutorial notebooks will not work as expected. Notably, class names have been pruned to remove the trailing "Reconstruction" (`DPCReconstruction` -> `DPC` etc.), and regularization functions have dropped the `_iter` suffix (and are instead specified as boolean flags). See the [updated tutorials](https://github.com/py4dstem/py4DSTEM_tutorials) for more information. ![py4DSTEM logo](/images/py4DSTEM_logo.png) From a87ab8591f6c5cebaa8a4e8e6217b2d5e058dc1b Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Tue, 28 May 2024 17:38:12 -0700 Subject: [PATCH 36/64] better position initialization , single-slice --- py4DSTEM/process/phase/phase_base_class.py | 38 ++++++++----------- .../process/phase/ptychographic_methods.py | 20 +++++++--- .../process/phase/singleslice_ptychography.py | 21 +++++++--- 3 files changed, 46 insertions(+), 33 deletions(-) diff --git a/py4DSTEM/process/phase/phase_base_class.py b/py4DSTEM/process/phase/phase_base_class.py index 22fda4ac9..742013f1e 100644 --- a/py4DSTEM/process/phase/phase_base_class.py +++ b/py4DSTEM/process/phase/phase_base_class.py @@ -1864,42 +1864,34 @@ def _calculate_scan_positions_in_pixels( else: raise ValueError() - if transpose: - x = (x - np.ptp(x) / 2) / sampling[1] - y = (y - np.ptp(y) / 2) / sampling[0] - else: - x = (x - np.ptp(x) / 2) / sampling[0] - y = (y - np.ptp(y) / 2) / sampling[1] x, y = np.meshgrid(x, y, indexing="ij") if positions_offset_ang is not None: - if transpose: - x += positions_offset_ang[0] / sampling[1] - y += positions_offset_ang[1] / sampling[0] - else: - x += positions_offset_ang[0] / sampling[0] - y += positions_offset_ang[1] / sampling[1] + x += positions_offset_ang[0] + y += positions_offset_ang[1] if positions_mask is not None: x = x[positions_mask] y = y[positions_mask] - else: - positions -= np.mean(positions, axis=0) - x = positions[:, 0] / sampling[1] - y = positions[:, 1] / sampling[0] + + positions = np.stack((x.ravel(), y.ravel()), axis=-1) if rotation_angle is not None: - x, y = x * np.cos(rotation_angle) + y * np.sin(rotation_angle), -x * np.sin( - rotation_angle - ) + y * np.cos(rotation_angle) + tf = AffineTransform(angle=rotation_angle) + positions = tf(positions, positions.mean(0)) if transpose: - positions = np.array([y.ravel(), x.ravel()]).T - else: - positions = np.array([x.ravel(), y.ravel()]).T + positions = np.flip(positions, 1) + sampling = sampling[::-1] + + # ensure positive + positions -= np.min(positions, axis=0).clip(-np.inf, 0) - positions -= np.min(positions, axis=0) + # finally, switch to pixels + positions[:, 0] /= sampling[0] + positions[:, 1] /= sampling[1] + # top-left padding if object_padding_px is None: float_padding = region_of_interest_shape / 2 object_padding_px = (float_padding, float_padding) diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py index 283ddb1ba..25a7f7785 100644 --- a/py4DSTEM/process/phase/ptychographic_methods.py +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -48,14 +48,24 @@ def _initialize_object( xp = self._xp object_padding_px = self._object_padding_px + object_fov_ang = self._object_fov_ang region_of_interest_shape = self._region_of_interest_shape if initial_object is None: - pad_x = object_padding_px[0][1] - pad_y = object_padding_px[1][1] - p, q = np.round(np.max(positions_px, axis=0)) - p = np.max([np.round(p + pad_x), region_of_interest_shape[0]]).astype("int") - q = np.max([np.round(q + pad_y), region_of_interest_shape[1]]).astype("int") + if object_fov_ang is None: + pad_x = object_padding_px[0][1] + pad_y = object_padding_px[1][1] + p, q = np.round(np.max(positions_px, axis=0)) + p = np.max([np.round(p + pad_x), region_of_interest_shape[0]]).astype( + "int" + ) + q = np.max([np.round(q + pad_y), region_of_interest_shape[1]]).astype( + "int" + ) + else: + p, q = np.ceil( + np.array(object_fov_ang) / np.array(self.sampling) + ).astype("int") if object_type == "potential": _object = xp.zeros((p, q), dtype=xp.float32) elif object_type == "complex": diff --git a/py4DSTEM/process/phase/singleslice_ptychography.py b/py4DSTEM/process/phase/singleslice_ptychography.py index cc7a5865d..b220ba741 100644 --- a/py4DSTEM/process/phase/singleslice_ptychography.py +++ b/py4DSTEM/process/phase/singleslice_ptychography.py @@ -88,6 +88,9 @@ class SingleslicePtychography( initial_scan_positions: np.ndarray, optional Probe positions in Å for each diffraction intensity If None, initialized to a grid scan + object_fov_ang: Tuple[int,int], optional + Fixed object field of view in Å. If None, the fov is initialized using the + probe positions and object_padding_px positions_offset_ang: np.ndarray, optional Offset of positions in A verbose: bool, optional @@ -124,6 +127,7 @@ def __init__( initial_object_guess: np.ndarray = None, initial_probe_guess: np.ndarray = None, initial_scan_positions: np.ndarray = None, + object_fov_ang: Tuple[float, float] = None, positions_offset_ang: np.ndarray = None, object_padding_px: Tuple[int, int] = None, object_type: str = "complex", @@ -177,6 +181,7 @@ def __init__( self._rolloff = rolloff self._object_type = object_type self._object_padding_px = object_padding_px + self._object_fov_ang = object_fov_ang self._positions_mask = positions_mask self._verbose = verbose self._preprocessed = False @@ -206,6 +211,7 @@ def preprocess( force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, crop_patterns: bool = False, + center_positions_in_fov: bool = True, store_initial_arrays: bool = True, device: str = None, clear_fft_cache: bool = None, @@ -266,6 +272,8 @@ def preprocess( If None, probe_overlap intensity is thresholded crop_patterns: bool if True, crop patterns to avoid wrap around of patterns when centering + center_positions_in_fov: bool + If True (default), probe positions are centered in the fov. store_initial_arrays: bool If True, preprocesed object and probe arrays are stored allowing reset=True in reconstruct. device: str, optional @@ -443,15 +451,18 @@ def preprocess( self._object_shape = self._object.shape - # center probe positions self._positions_px = xp_storage.asarray( self._positions_px, dtype=xp_storage.float32 ) self._positions_px_initial_com = self._positions_px.mean(0) - self._positions_px -= ( - self._positions_px_initial_com - xp_storage.array(self._object_shape) / 2 - ) - self._positions_px_initial_com = self._positions_px.mean(0) + + # center probe positions + if center_positions_in_fov: + self._positions_px -= ( + self._positions_px_initial_com + - xp_storage.array(self._object_shape) / 2 + ) + self._positions_px_initial_com = self._positions_px.mean(0) self._positions_px_initial = self._positions_px.copy() self._positions_initial = self._positions_px_initial.copy() From bb8bdba03b978160910e698aadbb17b1e547a4b4 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Tue, 28 May 2024 18:04:13 -0700 Subject: [PATCH 37/64] add mixed, multislice better positions --- .../mixedstate_multislice_ptychography.py | 21 +++++++--- .../process/phase/mixedstate_ptychography.py | 21 +++++++--- .../process/phase/multislice_ptychography.py | 21 +++++++--- .../process/phase/ptychographic_methods.py | 40 ++++++++++++++----- 4 files changed, 78 insertions(+), 25 deletions(-) diff --git a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py index b22d0a0bb..d82a37eb4 100644 --- a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py @@ -107,6 +107,9 @@ class MixedstateMultislicePtychography( initial_scan_positions: np.ndarray, optional Probe positions in Å for each diffraction intensity If None, initialized to a grid scan + object_fov_ang: Tuple[int,int], optional + Fixed object field of view in Å. If None, the fov is initialized using the + probe positions and object_padding_px positions_offset_ang: np.ndarray, optional Offset of positions in A theta_x: float @@ -159,6 +162,7 @@ def __init__( initial_object_guess: np.ndarray = None, initial_probe_guess: np.ndarray = None, initial_scan_positions: np.ndarray = None, + object_fov_ang: Tuple[float, float] = None, positions_offset_ang: np.ndarray = None, theta_x: float = 0, theta_y: float = 0, @@ -245,6 +249,7 @@ def __init__( self._object_type = object_type self._positions_mask = positions_mask self._object_padding_px = object_padding_px + self._object_fov_ang = object_fov_ang self._verbose = verbose self._preprocessed = False @@ -278,6 +283,7 @@ def preprocess( force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, crop_patterns: bool = False, + center_positions_in_fov: bool = True, store_initial_arrays: bool = True, device: str = None, clear_fft_cache: bool = None, @@ -348,6 +354,8 @@ def preprocess( If None, probe_overlap intensity is thresholded crop_patterns: bool if True, crop patterns to avoid wrap around of patterns when centering + center_positions_in_fov: bool + If True (default), probe positions are centered in the fov. store_initial_arrays: bool If True, preprocesed object and probe arrays are stored allowing reset=True in reconstruct. device: str, optional @@ -524,15 +532,18 @@ def preprocess( self._object_type_initial = self._object_type self._object_shape = self._object.shape[-2:] - # center probe positions self._positions_px = xp_storage.asarray( self._positions_px, dtype=xp_storage.float32 ) self._positions_px_initial_com = self._positions_px.mean(0) - self._positions_px -= ( - self._positions_px_initial_com - xp_storage.array(self._object_shape) / 2 - ) - self._positions_px_initial_com = self._positions_px.mean(0) + + # center probe positions + if center_positions_in_fov: + self._positions_px -= ( + self._positions_px_initial_com + - xp_storage.array(self._object_shape) / 2 + ) + self._positions_px_initial_com = self._positions_px.mean(0) self._positions_px_initial = self._positions_px.copy() self._positions_initial = self._positions_px_initial.copy() diff --git a/py4DSTEM/process/phase/mixedstate_ptychography.py b/py4DSTEM/process/phase/mixedstate_ptychography.py index 436338555..7bbadf114 100644 --- a/py4DSTEM/process/phase/mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_ptychography.py @@ -96,6 +96,9 @@ class MixedstatePtychography( initial_scan_positions: np.ndarray, optional Probe positions in Å for each diffraction intensity If None, initialized to a grid scan + object_fov_ang: Tuple[int,int], optional + Fixed object field of view in Å. If None, the fov is initialized using the + probe positions and object_padding_px positions_offset_ang: np.ndarray, optional Offset of positions in A positions_mask: np.ndarray, optional @@ -127,6 +130,7 @@ def __init__( initial_object_guess: np.ndarray = None, initial_probe_guess: np.ndarray = None, initial_scan_positions: np.ndarray = None, + object_fov_ang: Tuple[float, float] = None, positions_offset_ang: np.ndarray = None, object_type: str = "complex", positions_mask: np.ndarray = None, @@ -194,6 +198,7 @@ def __init__( self._rolloff = rolloff self._object_type = object_type self._object_padding_px = object_padding_px + self._object_fov_ang = object_fov_ang self._positions_mask = positions_mask self._verbose = verbose self._preprocessed = False @@ -224,6 +229,7 @@ def preprocess( force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, crop_patterns: bool = False, + center_positions_in_fov: bool = True, store_initial_arrays: bool = True, device: str = None, clear_fft_cache: bool = None, @@ -294,6 +300,8 @@ def preprocess( If None, probe_overlap intensity is thresholded crop_patterns: bool if True, crop patterns to avoid wrap around of patterns when centering + center_positions_in_fov: bool + If True (default), probe positions are centered in the fov. store_initial_arrays: bool If True, preprocesed object and probe arrays are stored allowing reset=True in reconstruct. device: str, optional @@ -469,15 +477,18 @@ def preprocess( self._object_type_initial = self._object_type self._object_shape = self._object.shape - # center probe positions self._positions_px = xp_storage.asarray( self._positions_px, dtype=xp_storage.float32 ) self._positions_px_initial_com = self._positions_px.mean(0) - self._positions_px -= ( - self._positions_px_initial_com - xp_storage.array(self._object_shape) / 2 - ) - self._positions_px_initial_com = self._positions_px.mean(0) + + # center probe positions + if center_positions_in_fov: + self._positions_px -= ( + self._positions_px_initial_com + - xp_storage.array(self._object_shape) / 2 + ) + self._positions_px_initial_com = self._positions_px.mean(0) self._positions_px_initial = self._positions_px.copy() self._positions_initial = self._positions_px_initial.copy() diff --git a/py4DSTEM/process/phase/multislice_ptychography.py b/py4DSTEM/process/phase/multislice_ptychography.py index 6de8e7970..87e3c1fe4 100644 --- a/py4DSTEM/process/phase/multislice_ptychography.py +++ b/py4DSTEM/process/phase/multislice_ptychography.py @@ -99,6 +99,9 @@ class MultislicePtychography( initial_scan_positions: np.ndarray, optional Probe positions in Å for each diffraction intensity If None, initialized to a grid scan + object_fov_ang: Tuple[int,int], optional + Fixed object field of view in Å. If None, the fov is initialized using the + probe positions and object_padding_px positions_offset_ang: np.ndarray, optional Offset of positions in A theta_x: float @@ -149,6 +152,7 @@ def __init__( initial_object_guess: np.ndarray = None, initial_probe_guess: np.ndarray = None, initial_scan_positions: np.ndarray = None, + object_fov_ang: Tuple[float, float] = None, positions_offset_ang: np.ndarray = None, theta_x: float = None, theta_y: float = None, @@ -220,6 +224,7 @@ def __init__( self._object_type = object_type self._positions_mask = positions_mask self._object_padding_px = object_padding_px + self._object_fov_ang = object_fov_ang self._verbose = verbose self._preprocessed = False @@ -252,6 +257,7 @@ def preprocess( force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, crop_patterns: bool = False, + center_positions_in_fov: bool = True, store_initial_arrays: bool = True, device: str = None, clear_fft_cache: bool = None, @@ -322,6 +328,8 @@ def preprocess( If None, probe_overlap intensity is thresholded crop_patterns: bool if True, crop patterns to avoid wrap around of patterns when centering + center_positions_in_fov: bool + If True (default), probe positions are centered in the fov. store_initial_arrays: bool If True, preprocesed object and probe arrays are stored allowing reset=True in reconstruct. device: str, optional @@ -498,15 +506,18 @@ def preprocess( self._object_type_initial = self._object_type self._object_shape = self._object.shape[-2:] - # center probe positions self._positions_px = xp_storage.asarray( self._positions_px, dtype=xp_storage.float32 ) self._positions_px_initial_com = self._positions_px.mean(0) - self._positions_px -= ( - self._positions_px_initial_com - xp_storage.array(self._object_shape) / 2 - ) - self._positions_px_initial_com = self._positions_px.mean(0) + + # center probe positions + if center_positions_in_fov: + self._positions_px -= ( + self._positions_px_initial_com + - xp_storage.array(self._object_shape) / 2 + ) + self._positions_px_initial_com = self._positions_px.mean(0) self._positions_px_initial = self._positions_px.copy() self._positions_initial = self._positions_px_initial.copy() diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py index 25a7f7785..b9eae9385 100644 --- a/py4DSTEM/process/phase/ptychographic_methods.py +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -412,14 +412,24 @@ def _initialize_object( xp = self._xp object_padding_px = self._object_padding_px + object_fov_ang = self._object_fov_ang region_of_interest_shape = self._region_of_interest_shape if initial_object is None: - pad_x = object_padding_px[0][1] - pad_y = object_padding_px[1][1] - p, q = np.round(np.max(positions_px, axis=0)) - p = np.max([np.round(p + pad_x), region_of_interest_shape[0]]).astype("int") - q = np.max([np.round(q + pad_y), region_of_interest_shape[1]]).astype("int") + if object_fov_ang is None: + pad_x = object_padding_px[0][1] + pad_y = object_padding_px[1][1] + p, q = np.round(np.max(positions_px, axis=0)) + p = np.max([np.round(p + pad_x), region_of_interest_shape[0]]).astype( + "int" + ) + q = np.max([np.round(q + pad_y), region_of_interest_shape[1]]).astype( + "int" + ) + else: + p, q = np.ceil( + np.array(object_fov_ang) / np.array(self.sampling) + ).astype("int") if object_type == "potential": _object = xp.zeros((num_slices, p, q), dtype=xp.float32) elif object_type == "complex": @@ -868,14 +878,24 @@ def _initialize_object( # explicit read-only self attributes up-front xp = self._xp object_padding_px = self._object_padding_px + object_fov_ang = self._object_fov_ang region_of_interest_shape = self._region_of_interest_shape if initial_object is None: - pad_x = object_padding_px[0][1] - pad_y = object_padding_px[1][1] - p, q = np.round(np.max(positions_px, axis=0)) - p = np.max([np.round(p + pad_x), region_of_interest_shape[0]]).astype("int") - q = np.max([np.round(q + pad_y), region_of_interest_shape[1]]).astype("int") + if object_fov_ang is None: + pad_x = object_padding_px[0][1] + pad_y = object_padding_px[1][1] + p, q = np.round(np.max(positions_px, axis=0)) + p = np.max([np.round(p + pad_x), region_of_interest_shape[0]]).astype( + "int" + ) + q = np.max([np.round(q + pad_y), region_of_interest_shape[1]]).astype( + "int" + ) + else: + p, q = np.ceil( + np.array(object_fov_ang) / np.array(self.sampling) + ).astype("int") if main_tilt_axis == "vertical": _object = xp.zeros((q, p, q), dtype=xp.float32) From 92f35e1675b4ea033cf71a7662009ddaf9f565a2 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Tue, 28 May 2024 21:11:10 -0700 Subject: [PATCH 38/64] magnetic positions, absolute --- .../process/phase/magnetic_ptychography.py | 25 +++++++++++++------ 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/py4DSTEM/process/phase/magnetic_ptychography.py b/py4DSTEM/process/phase/magnetic_ptychography.py index 19f306188..58c826b87 100644 --- a/py4DSTEM/process/phase/magnetic_ptychography.py +++ b/py4DSTEM/process/phase/magnetic_ptychography.py @@ -101,6 +101,9 @@ class MagneticPtychography( initial_scan_positions: np.ndarray, optional Probe positions in Å for each diffraction intensity If None, initialized to a grid scan + object_fov_ang: Tuple[int,int], optional + Fixed object field of view in Å. If None, the fov is initialized using the + probe positions and object_padding_px positions_offset_ang: np.ndarray, optional Offset of positions in A verbose: bool, optional @@ -138,6 +141,7 @@ def __init__( initial_object_guess: np.ndarray = None, initial_probe_guess: np.ndarray = None, initial_scan_positions: np.ndarray = None, + object_fov_ang: Tuple[float, float] = None, positions_offset_ang: np.ndarray = None, object_type: str = "complex", verbose: bool = True, @@ -189,6 +193,7 @@ def __init__( self._rolloff = rolloff self._object_type = object_type self._object_padding_px = object_padding_px + self._object_fov_ang = object_fov_ang self._positions_mask = positions_mask self._verbose = verbose self._preprocessed = False @@ -219,6 +224,7 @@ def preprocess( progress_bar: bool = True, object_fov_mask: np.ndarray = True, crop_patterns: bool = False, + center_positions_in_fov: bool = True, store_initial_arrays: bool = True, device: str = None, clear_fft_cache: bool = None, @@ -286,6 +292,8 @@ def preprocess( If None, probe_overlap intensity is thresholded crop_patterns: bool if True, crop patterns to avoid wrap around of patterns when centering + center_positions_in_fov: bool + If True (default), probe positions are centered in the fov. store_initial_arrays: bool If True, preprocesed object and probe arrays are stored allowing reset=True in reconstruct. device: str, optional @@ -610,14 +618,17 @@ def preprocess( self._positions_px_all, dtype=xp_storage.float32 ) - for index in range(self._num_measurements): - idx_start = self._cum_probes_per_measurement[index] - idx_end = self._cum_probes_per_measurement[index + 1] + if center_positions_in_fov: + for index in range(self._num_measurements): + idx_start = self._cum_probes_per_measurement[index] + idx_end = self._cum_probes_per_measurement[index + 1] - positions_px = self._positions_px_all[idx_start:idx_end] - positions_px_com = positions_px.mean(0) - positions_px -= positions_px_com - xp_storage.array(self._object_shape) / 2 - self._positions_px_all[idx_start:idx_end] = positions_px.copy() + positions_px = self._positions_px_all[idx_start:idx_end] + positions_px_com = positions_px.mean(0) + positions_px -= ( + positions_px_com - xp_storage.array(self._object_shape) / 2 + ) + self._positions_px_all[idx_start:idx_end] = positions_px.copy() self._positions_px_initial_all = self._positions_px_all.copy() self._positions_initial_all = self._positions_px_initial_all.copy() From c1557bb50724e7c84fabb2f258376ec2b47ed6f5 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Wed, 29 May 2024 17:17:45 -0700 Subject: [PATCH 39/64] add xmcd class --- py4DSTEM/process/phase/__init__.py | 1 + .../phase/xray_magnetic_ptychography.py | 1849 +++++++++++++++++ 2 files changed, 1850 insertions(+) create mode 100644 py4DSTEM/process/phase/xray_magnetic_ptychography.py diff --git a/py4DSTEM/process/phase/__init__.py b/py4DSTEM/process/phase/__init__.py index ecfeaa1d2..500a7cfea 100644 --- a/py4DSTEM/process/phase/__init__.py +++ b/py4DSTEM/process/phase/__init__.py @@ -12,5 +12,6 @@ from py4DSTEM.process.phase.ptychographic_tomography import PtychographicTomography from py4DSTEM.process.phase.singleslice_ptychography import SingleslicePtychography from py4DSTEM.process.phase.parameter_optimize import OptimizationParameter, PtychographyOptimizer +from py4DSTEM.process.phase.xray_magnetic_ptychography import XRayMagneticPtychography # fmt: on diff --git a/py4DSTEM/process/phase/xray_magnetic_ptychography.py b/py4DSTEM/process/phase/xray_magnetic_ptychography.py new file mode 100644 index 000000000..77e1cb835 --- /dev/null +++ b/py4DSTEM/process/phase/xray_magnetic_ptychography.py @@ -0,0 +1,1849 @@ +""" +Module for reconstructing phase objects from 4DSTEM datasets using iterative methods, +namely x-ray magnetic ptychography. +""" + +import warnings +from typing import Mapping, Sequence, Tuple + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.gridspec import GridSpec +from mpl_toolkits.axes_grid1 import make_axes_locatable +from py4DSTEM.visualize.vis_special import ( + Complex2RGB, + add_colorbar_arg, + return_scaled_histogram_ordering, +) + +try: + import cupy as cp +except (ImportError, ModuleNotFoundError): + cp = np + +from emdfile import Custom, tqdmnd +from py4DSTEM import DataCube +from py4DSTEM.process.phase.phase_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.ptychographic_constraints import ( + ObjectNDConstraintsMixin, + PositionsConstraintsMixin, + ProbeConstraintsMixin, +) +from py4DSTEM.process.phase.ptychographic_methods import ( + MultipleMeasurementsMethodsMixin, + ObjectNDMethodsMixin, + ObjectNDProbeMethodsMixin, + ProbeMethodsMixin, +) +from py4DSTEM.process.phase.ptychographic_visualizations import VisualizationsMixin +from py4DSTEM.process.phase.utils import ( + ComplexProbe, + copy_to_device, + fft_shift, + generate_batches, + polar_aliases, + polar_symbols, +) + + +class XRayMagneticPtychography( + VisualizationsMixin, + PositionsConstraintsMixin, + ProbeConstraintsMixin, + ObjectNDConstraintsMixin, + MultipleMeasurementsMethodsMixin, + ObjectNDProbeMethodsMixin, + ProbeMethodsMixin, + ObjectNDMethodsMixin, + PtychographicReconstruction, +): + """ + Iterative X-Ray Magnetic Ptychographic Reconstruction Class. + + Diffraction intensities dimensions : (Rx,Ry,Qx,Qy) (for each measurement) + Reconstructed probe dimensions : (Sx,Sy) + Reconstructed electrostatic dimensions : (Px,Py) + Reconstructed magnetic dimensions : (Px,Py) + + such that (Sx,Sy) is the region-of-interest (ROI) size of our probe + and (Px,Py) is the padded-object size we position our ROI around in. + + Parameters + ---------- + datacube: Sequence[DataCube] + Tuple of input 4D diffraction pattern intensities + energy: float + The electron energy of the wave functions in eV + magnetic_contribution_sign: str, optional + One of '-+', '-0+', '0+' + semiangle_cutoff: float, optional + Semiangle cutoff for the initial probe guess in mrad + semiangle_cutoff_pixels: float, optional + Semiangle cutoff for the initial probe guess in pixels + rolloff: float, optional + Semiangle rolloff for the initial probe guess + vacuum_probe_intensity: np.ndarray, optional + Vacuum probe to use as intensity aperture for initial probe guess + polar_parameters: dict, optional + Mapping from aberration symbols to their corresponding values. All aberration + magnitudes should be given in Å and angles should be given in radians. + object_padding_px: Tuple[int,int], optional + Pixel dimensions to pad objects with + If None, the padding is set to half the probe ROI dimensions + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction + initial_object_guess: np.ndarray, optional + Initial guess for complex-valued object of dimensions (2,Px,Py) + If None, initialized to 1.0j for complex objects and 0.0 for potential objects + initial_probe_guess: np.ndarray, optional + Initial guess for complex-valued probe of dimensions (Sx,Sy). If None, + initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations + initial_scan_positions: np.ndarray, optional + Probe positions in Å for each diffraction intensity + If None, initialized to a grid scan + object_fov_ang: Tuple[int,int], optional + Fixed object field of view in Å. If None, the fov is initialized using the + probe positions and object_padding_px + positions_offset_ang: np.ndarray, optional + Offset of positions in A + verbose: bool, optional + If True, class methods will inherit this and print additional information + device: str, optional + Calculation device will be perfomed on. Must be 'cpu' or 'gpu' + storage: str, optional + Device non-frequent arrays will be stored on. Must be 'cpu' or 'gpu' + clear_fft_cache: bool, optional + If True, and device = 'gpu', clears the cached fft plan at the end of function calls + object_type: str, optional + The object can be reconstructed as a real potential ('potential') or a complex + object ('complex') + name: str, optional + Class name + kwargs: + Provide the aberration coefficients as keyword arguments. + """ + + # Class-specific Metadata + _class_specific_metadata = ("_magnetic_contribution_sign",) + + def __init__( + self, + energy: float, + datacube: Sequence[DataCube] = None, + magnetic_contribution_sign: str = "-+", + semiangle_cutoff: float = None, + semiangle_cutoff_pixels: float = None, + rolloff: float = 2.0, + vacuum_probe_intensity: np.ndarray = None, + polar_parameters: Mapping[str, float] = None, + object_padding_px: Tuple[int, int] = None, + positions_mask: np.ndarray = None, + initial_object_guess: np.ndarray = None, + initial_probe_guess: np.ndarray = None, + initial_scan_positions: np.ndarray = None, + object_fov_ang: Tuple[float, float] = None, + positions_offset_ang: np.ndarray = None, + object_type: str = "complex", + verbose: bool = True, + device: str = "cpu", + storage: str = None, + clear_fft_cache: bool = True, + name: str = "xray_magnetic_ptychographic_reconstruction", + **kwargs, + ): + Custom.__init__(self, name=name) + + if storage is None: + storage = device + + self.set_device(device, clear_fft_cache) + self.set_storage(storage) + + for key in kwargs.keys(): + if (key not in polar_symbols) and (key not in polar_aliases.keys()): + raise ValueError("{} not a recognized parameter".format(key)) + + self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) + + if polar_parameters is None: + polar_parameters = {} + + polar_parameters.update(kwargs) + self._set_polar_parameters(polar_parameters) + + if object_type != "complex": + raise NotImplementedError() + + self.set_save_defaults() + + # Data + self._datacube = datacube + self._object = initial_object_guess + self._probe_init = initial_probe_guess + + # Common Metadata + self._vacuum_probe_intensity = vacuum_probe_intensity + self._scan_positions = initial_scan_positions + self._positions_offset_ang = positions_offset_ang + self._energy = energy + self._semiangle_cutoff = semiangle_cutoff + self._semiangle_cutoff_pixels = semiangle_cutoff_pixels + self._rolloff = rolloff + self._object_type = object_type + self._object_padding_px = object_padding_px + self._object_fov_ang = object_fov_ang + self._positions_mask = positions_mask + self._verbose = verbose + self._preprocessed = False + + # Class-specific Metadata + self._magnetic_contribution_sign = magnetic_contribution_sign + + def preprocess( + self, + diffraction_intensities_shape: Tuple[int, int] = None, + reshaping_method: str = "bilinear", + padded_diffraction_intensities_shape: Tuple[int, int] = None, + region_of_interest_shape: Tuple[int, int] = None, + dp_mask: np.ndarray = None, + fit_function: str = "plane", + plot_rotation: bool = True, + maximize_divergence: bool = False, + rotation_angles_deg: np.ndarray = None, + plot_probe_overlaps: bool = True, + force_com_rotation: float = None, + force_com_transpose: float = None, + force_com_shifts: Sequence[np.ndarray] = None, + force_com_measured: Sequence[np.ndarray] = None, + vectorized_com_calculation: bool = True, + force_scan_sampling: float = None, + force_angular_sampling: float = None, + force_reciprocal_sampling: float = None, + progress_bar: bool = True, + object_fov_mask: np.ndarray = True, + crop_patterns: bool = False, + center_positions_in_fov: bool = True, + store_initial_arrays: bool = True, + device: str = None, + clear_fft_cache: bool = None, + max_batch_size: int = None, + **kwargs, + ): + """ + Ptychographic preprocessing step. + Calls the base class methods: + + _extract_intensities_and_calibrations_from_datacube, + _compute_center_of_mass(), + _solve_CoM_rotation(), + _normalize_diffraction_intensities() + _calculate_scan_positions_in_px() + + Additionally, it initializes an (Px,Py) array of 1.0j + and a complex probe using the specified polar parameters. + + Parameters + ---------- + diffraction_intensities_shape: Tuple[int,int], optional + Pixel dimensions (Qx',Qy') of the resampled diffraction intensities + If None, no resampling of diffraction intenstities is performed + reshaping_method: str, optional + Method to use for reshaping, either 'bin, 'bilinear', or 'fourier' (default) + padded_diffraction_intensities_shape: (int,int), optional + Padded diffraction intensities shape. + If None, no padding is performed + region_of_interest_shape: (int,int), optional + If not None, explicitly sets region_of_interest_shape and resamples exit_waves + at the diffraction plane to allow comparison with experimental data + dp_mask: ndarray, optional + Mask for datacube intensities (Qx,Qy) + fit_function: str, optional + 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' + plot_rotation: bool, optional + If True, the CoM curl minimization search result will be displayed + maximize_divergence: bool, optional + If True, the divergence of the CoM gradient vector field is maximized + rotation_angles_deg: np.darray, optional + Array of angles in degrees to perform curl minimization over + plot_probe_overlaps: bool, optional + If True, initial probe overlaps scanned over the object will be displayed + force_com_rotation: float (degrees), optional + Force relative rotation angle between real and reciprocal space + force_com_transpose: bool, optional + Force whether diffraction intensities need to be transposed. + force_com_shifts: sequence of tuples of ndarrays (CoMx, CoMy) + Amplitudes come from diffraction patterns shifted with + the CoM in the upper left corner for each probe unless + shift is overwritten. + force_com_measured: tuple of ndarrays (CoMx measured, CoMy measured) + Force CoM measured shifts + vectorized_com_calculation: bool, optional + If True (default), the memory-intensive CoM calculation is vectorized + force_scan_sampling: float, optional + Override DataCube real space scan pixel size calibrations, in Angstrom + force_angular_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in mrad + force_reciprocal_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in A^-1 + object_fov_mask: np.ndarray (boolean) + Boolean mask of FOV. Used to calculate additional shrinkage of object + If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering + center_positions_in_fov: bool + If True (default), probe positions are centered in the fov. + store_initial_arrays: bool + If True, preprocesed object and probe arrays are stored allowing reset=True in reconstruct. + device: str, optional + if not none, overwrites self._device to set device preprocess will be perfomed on. + clear_fft_cache: bool, optional + if true, and device = 'gpu', clears the cached fft plan at the end of function calls + max_batch_size: int, optional + Max number of probes to use at once in computing probe overlaps + + Returns + -------- + self: PtychographicReconstruction + Self to accommodate chaining + """ + # handle device/storage + self.set_device(device, clear_fft_cache) + + xp = self._xp + device = self._device + xp_storage = self._xp_storage + storage = self._storage + asnumpy = self._asnumpy + + # set additional metadata + self._diffraction_intensities_shape = diffraction_intensities_shape + self._reshaping_method = reshaping_method + self._padded_diffraction_intensities_shape = ( + padded_diffraction_intensities_shape + ) + self._dp_mask = dp_mask + + if self._datacube is None: + raise ValueError( + ( + "The preprocess() method requires a DataCube. " + "Please run ptycho.attach_datacube(DataCube) first." + ) + ) + + if self._magnetic_contribution_sign == "-+": + self._recon_mode = 0 + self._num_measurements = 2 + magnetic_contribution_msg = ( + "Magnetic contribution sign in first meaurement assumed to be negative.\n" + "Magnetic contribution sign in second meaurement assumed to be positive." + ) + + elif self._magnetic_contribution_sign == "-0+": + self._recon_mode = 1 + self._num_measurements = 3 + magnetic_contribution_msg = ( + "Magnetic contribution sign in first meaurement assumed to be negative.\n" + "Magnetic contribution assumed to be zero in second meaurement.\n" + "Magnetic contribution sign in third meaurement assumed to be positive." + ) + + elif self._magnetic_contribution_sign == "0+": + self._recon_mode = 2 + self._num_measurements = 2 + magnetic_contribution_msg = ( + "Magnetic contribution assumed to be zero in first meaurement.\n" + "Magnetic contribution sign in second meaurement assumed to be positive." + ) + else: + raise ValueError( + f"magnetic_contribution_sign must be either '-+', '-0+', or '0+', not {self._magnetic_contribution_sign}" + ) + + if self._verbose: + warnings.warn( + magnetic_contribution_msg, + UserWarning, + ) + + if len(self._datacube) != self._num_measurements: + raise ValueError( + f"datacube must be the same length as magnetic_contribution_sign, not length {len(self._datacube)}." + ) + + dc_shapes = [dc.shape for dc in self._datacube] + if dc_shapes.count(dc_shapes[0]) != self._num_measurements: + raise ValueError("datacube intensities must be the same size.") + + if self._positions_mask is not None: + self._positions_mask = np.asarray(self._positions_mask, dtype="bool") + + if self._positions_mask.ndim == 2: + warnings.warn( + "2D `positions_mask` assumed the same for all measurements.", + UserWarning, + ) + self._positions_mask = np.tile( + self._positions_mask, (self._num_measurements, 1, 1) + ) + + num_probes_per_measurement = np.insert( + self._positions_mask.sum(axis=(-2, -1)), 0, 0 + ) + + else: + self._positions_mask = [None] * self._num_measurements + num_probes_per_measurement = [0] + [dc.R_N for dc in self._datacube] + num_probes_per_measurement = np.array(num_probes_per_measurement) + + # prepopulate relevant arrays + self._mean_diffraction_intensity = [] + self._num_diffraction_patterns = num_probes_per_measurement.sum() + self._cum_probes_per_measurement = np.cumsum(num_probes_per_measurement) + self._positions_px_all = np.empty((self._num_diffraction_patterns, 2)) + + # calculate roi_shape + roi_shape = self._datacube[0].Qshape + if diffraction_intensities_shape is not None: + roi_shape = diffraction_intensities_shape + if padded_diffraction_intensities_shape is not None: + roi_shape = tuple( + max(q, s) + for q, s in zip(roi_shape, padded_diffraction_intensities_shape) + ) + + self._amplitudes = xp_storage.empty( + (self._num_diffraction_patterns,) + roi_shape + ) + + self._amplitudes_shape = np.array(self._amplitudes.shape[-2:]) + if region_of_interest_shape is not None: + self._resample_exit_waves = True + self._region_of_interest_shape = np.array(region_of_interest_shape) + else: + self._resample_exit_waves = False + self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) + + # TO-DO: generalize this + if force_com_shifts is None: + force_com_shifts = [None] * self._num_measurements + + if force_com_measured is None: + force_com_measured = [None] * self._num_measurements + + if self._scan_positions is None: + self._scan_positions = [None] * self._num_measurements + + if self._positions_offset_ang is None: + self._positions_offset_ang = [None] * self._num_measurements + + # Ensure plot_center_of_mass is not in kwargs + kwargs.pop("plot_center_of_mass", None) + + if progress_bar: + # turn off verbosity to play nice with tqdm + verbose = self._verbose + self._verbose = False + + # loop over DPs for preprocessing + for index in tqdmnd( + self._num_measurements, + desc="Preprocessing data", + unit="measurement", + disable=not progress_bar, + ): + # preprocess datacube, vacuum and masks only for first measurement + if index == 0: + ( + self._datacube[index], + self._vacuum_probe_intensity, + self._dp_mask, + force_com_shifts[index], + force_com_measured[index], + ) = self._preprocess_datacube_and_vacuum_probe( + self._datacube[index], + diffraction_intensities_shape=self._diffraction_intensities_shape, + reshaping_method=self._reshaping_method, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, + vacuum_probe_intensity=self._vacuum_probe_intensity, + dp_mask=self._dp_mask, + com_shifts=force_com_shifts[index], + com_measured=force_com_measured[index], + ) + + else: + ( + self._datacube[index], + _, + _, + force_com_shifts[index], + force_com_measured[index], + ) = self._preprocess_datacube_and_vacuum_probe( + self._datacube[index], + diffraction_intensities_shape=self._diffraction_intensities_shape, + reshaping_method=self._reshaping_method, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, + vacuum_probe_intensity=None, + dp_mask=None, + com_shifts=force_com_shifts[index], + com_measured=force_com_measured[index], + ) + + # calibrations + intensities = self._extract_intensities_and_calibrations_from_datacube( + self._datacube[index], + require_calibrations=True, + force_scan_sampling=force_scan_sampling, + force_angular_sampling=force_angular_sampling, + force_reciprocal_sampling=force_reciprocal_sampling, + ) + + # calculate CoM + ( + com_measured_x, + com_measured_y, + com_fitted_x, + com_fitted_y, + com_normalized_x, + com_normalized_y, + ) = self._calculate_intensities_center_of_mass( + intensities, + dp_mask=self._dp_mask, + fit_function=fit_function, + com_shifts=force_com_shifts[index], + vectorized_calculation=vectorized_com_calculation, + com_measured=force_com_measured[index], + ) + + # estimate rotation / transpose using first measurement + if index == 0: + # silence warnings to play nice with progress bar + verbose = self._verbose + self._verbose = False + + ( + self._rotation_best_rad, + self._rotation_best_transpose, + _com_x, + _com_y, + ) = self._solve_for_center_of_mass_relative_rotation( + com_measured_x, + com_measured_y, + com_normalized_x, + com_normalized_y, + rotation_angles_deg=rotation_angles_deg, + plot_rotation=plot_rotation, + plot_center_of_mass=False, + maximize_divergence=maximize_divergence, + force_com_rotation=force_com_rotation, + force_com_transpose=force_com_transpose, + **kwargs, + ) + self._verbose = verbose + + # corner-center amplitudes + idx_start = self._cum_probes_per_measurement[index] + idx_end = self._cum_probes_per_measurement[index + 1] + + ( + amplitudes, + mean_diffraction_intensity_temp, + self._crop_mask, + ) = self._normalize_diffraction_intensities( + intensities, + com_fitted_x, + com_fitted_y, + self._positions_mask[index], + crop_patterns, + ) + + self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp) + + # explicitly transfer arrays to storage + self._amplitudes[idx_start:idx_end] = copy_to_device(amplitudes, storage) + + del ( + intensities, + amplitudes, + com_measured_x, + com_measured_y, + com_fitted_x, + com_fitted_y, + com_normalized_x, + com_normalized_y, + ) + + # initialize probe positions + ( + self._positions_px_all[idx_start:idx_end], + self._object_padding_px, + ) = self._calculate_scan_positions_in_pixels( + self._scan_positions[index], + self._positions_mask[index], + self._object_padding_px, + self._positions_offset_ang[index], + ) + + if progress_bar: + # reset verbosity + self._verbose = verbose + + # handle semiangle specified in pixels + if self._semiangle_cutoff_pixels: + self._semiangle_cutoff = ( + self._semiangle_cutoff_pixels * self._angular_sampling[0] + ) + + # Object Initialization + obj = self._initialize_object( + self._object, + self._positions_px_all, + self._object_type, + ) + + if self._object is None: + self._object = xp.full((2,) + obj.shape, obj) + else: + self._object = obj + + if store_initial_arrays: + self._object_initial = self._object.copy() + self._object_type_initial = self._object_type + self._object_shape = self._object.shape[-2:] + + # center probe positions + self._positions_px_all = xp_storage.asarray( + self._positions_px_all, dtype=xp_storage.float32 + ) + + if center_positions_in_fov: + for index in range(self._num_measurements): + idx_start = self._cum_probes_per_measurement[index] + idx_end = self._cum_probes_per_measurement[index + 1] + + positions_px = self._positions_px_all[idx_start:idx_end] + positions_px_com = positions_px.mean(0) + positions_px -= ( + positions_px_com - xp_storage.array(self._object_shape) / 2 + ) + self._positions_px_all[idx_start:idx_end] = positions_px.copy() + + self._positions_px_initial_all = self._positions_px_all.copy() + self._positions_initial_all = self._positions_px_initial_all.copy() + self._positions_initial_all[:, 0] *= self.sampling[0] + self._positions_initial_all[:, 1] *= self.sampling[1] + + self._positions_initial = self._return_average_positions() + if self._positions_initial is not None: + self._positions_initial[:, 0] *= self.sampling[0] + self._positions_initial[:, 1] *= self.sampling[1] + + # initialize probe + self._probes_all = [] + list_Q = isinstance(self._probe_init, (list, tuple)) + + if store_initial_arrays: + self._probes_all_initial = [] + self._probes_all_initial_aperture = [] + else: + self._probes_all_initial_aperture = [None] * self._num_measurements + + for index in range(self._num_measurements): + _probe, self._semiangle_cutoff = self._initialize_probe( + self._probe_init[index] if list_Q else self._probe_init, + self._vacuum_probe_intensity, + self._mean_diffraction_intensity[index], + self._semiangle_cutoff, + crop_patterns, + ) + + self._probes_all.append(_probe) + if store_initial_arrays: + self._probes_all_initial.append(_probe.copy()) + self._probes_all_initial_aperture.append(xp.abs(xp.fft.fft2(_probe))) + + del self._probe_init + + # initialize aberrations + self._known_aberrations_array = ComplexProbe( + energy=self._energy, + gpts=self._region_of_interest_shape, + sampling=self.sampling, + parameters=self._polar_parameters, + device=self._device, + )._evaluate_ctf() + + if object_fov_mask is None or plot_probe_overlaps: + # overlaps + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + + probe_overlap = xp.zeros(self._object_shape, dtype=xp.float32) + + for start, end in generate_batches( + self._cum_probes_per_measurement[1], max_batch=max_batch_size + ): + # batch indices + positions_px = self._positions_px_all[start:end] + positions_px_fractional = positions_px - xp_storage.round(positions_px) + + shifted_probes = fft_shift( + self._probes_all[0], positions_px_fractional, xp + ) + probe_overlap += self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes) ** 2, positions_px + ) + + del shifted_probes + + # initialize object_fov_mask + if object_fov_mask is None: + gaussian_filter = self._scipy.ndimage.gaussian_filter + probe_overlap_blurred = gaussian_filter(probe_overlap, 1.0) + self._object_fov_mask = asnumpy( + probe_overlap_blurred > 0.25 * probe_overlap_blurred.max() + ) + del probe_overlap_blurred + elif object_fov_mask is True: + self._object_fov_mask = np.full(self._object_shape, True) + else: + self._object_fov_mask = np.asarray(object_fov_mask) + self._object_fov_mask_inverse = np.invert(self._object_fov_mask) + + # plot probe overlaps + if plot_probe_overlaps: + probe_overlap = asnumpy(probe_overlap) + figsize = kwargs.pop("figsize", (9, 4)) + chroma_boost = kwargs.pop("chroma_boost", 1) + power = kwargs.pop("power", 2) + + # initial probe + complex_probe_rgb = Complex2RGB( + self.probe_centered[0], + power=power, + chroma_boost=chroma_boost, + ) + + extent = [ + 0, + self.sampling[1] * self._object_shape[1], + self.sampling[0] * self._object_shape[0], + 0, + ] + + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize) + + ax1.imshow( + complex_probe_rgb, + extent=probe_extent, + ) + + divider = make_axes_locatable(ax1) + cax1 = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(cax1, chroma_boost=chroma_boost) + ax1.set_ylabel("x [A]") + ax1.set_xlabel("y [A]") + ax1.set_title("Initial probe intensity") + + ax2.imshow( + probe_overlap, + extent=extent, + cmap="gray", + ) + ax2.scatter( + self.positions[0, :, 1], + self.positions[0, :, 0], + s=2.5, + color=(1, 0, 0, 1), + ) + ax2.set_ylabel("x [A]") + ax2.set_xlabel("y [A]") + ax2.set_xlim((extent[0], extent[1])) + ax2.set_ylim((extent[2], extent[3])) + ax2.set_title("Object field of view") + + fig.tight_layout() + + self._preprocessed = True + self.clear_device_mem(self._device, self._clear_fft_cache) + + return self + + def _overlap_projection( + self, + current_object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + ): + """ + Ptychographic overlap projection method. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + + Returns + -------- + shifted_probes:np.ndarray + fractionally-shifted probes + object_patches: np.ndarray + Patched object view + overlap: np.ndarray + shifted_probes * object_patches + """ + + xp = self._xp + + object_patches = xp.empty( + (self._num_measurements,) + shifted_probes.shape, dtype=current_object.dtype + ) + object_patches[0] = current_object[ + 0, vectorized_patch_indices_row, vectorized_patch_indices_col + ] + object_patches[1] = current_object[ + 1, vectorized_patch_indices_row, vectorized_patch_indices_col + ] + + overlap_base = shifted_probes * object_patches[0] + + match (self._recon_mode, self._active_measurement_index): + case (0, 0) | (1, 0): # reverse + overlap = overlap_base * xp.exp(-1j * object_patches[1]) + case (0, 1) | (1, 2) | (2, 1): # forward + overlap = overlap_base * xp.exp(1j * object_patches[1]) + case (1, 1) | (2, 0): # neutral + overlap = overlap_base + case _: + raise ValueError() + + return shifted_probes, object_patches, overlap + + def _gradient_descent_adjoint( + self, + current_object, + current_probe, + object_patches, + shifted_probes, + positions_px, + exit_waves, + step_size, + normalization_min, + fix_probe, + ): + """ + Ptychographic adjoint operator for GD method. + Computes object and probe update steps. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + shifted_probes:np.ndarray + fractionally-shifted probes + exit_waves:np.ndarray + Updated exit_waves + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + xp = self._xp + + probe_conj = xp.conj(shifted_probes) # P* + electrostatic_conj = xp.conj(object_patches[0]) # C* = exp(-i c) + + probe_electrostatic_abs = xp.abs(shifted_probes * object_patches[0]) + probe_electrostatic_normalization = self._sum_overlapping_patches_bincounts( + probe_electrostatic_abs**2, + positions_px, + ) + probe_electrostatic_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_electrostatic_normalization) ** 2 + + (normalization_min * xp.max(probe_electrostatic_normalization)) ** 2 + ) + + probe_magnetic_abs = xp.abs(shifted_probes * object_patches[1]) + probe_magnetic_normalization = self._sum_overlapping_patches_bincounts( + probe_magnetic_abs**2, + positions_px, + ) + probe_magnetic_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_magnetic_normalization) ** 2 + + (normalization_min * xp.max(probe_magnetic_normalization)) ** 2 + ) + + if not fix_probe: + electrostatic_magnetic_abs = xp.abs( + object_patches[0] * xp.exp(1j * object_patches[1]) + ) + electrostatic_magnetic_normalization = xp.sum( + electrostatic_magnetic_abs**2, + axis=0, + ) + electrostatic_magnetic_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * electrostatic_magnetic_normalization) ** 2 + + (normalization_min * xp.max(electrostatic_magnetic_normalization)) + ** 2 + ) + + if self._recon_mode > 0: + electrostatic_abs = xp.abs(object_patches[0]) + electrostatic_normalization = xp.sum( + electrostatic_abs**2, + axis=0, + ) + electrostatic_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * electrostatic_normalization) ** 2 + + (normalization_min * xp.max(electrostatic_normalization)) ** 2 + ) + + match (self._recon_mode, self._active_measurement_index): + case (0, 0) | (1, 0): # reverse + + magnetic_exp = xp.exp(1j * xp.conj(object_patches[1])) + + # P* exp(i M*) + electrostatic_update = self._sum_overlapping_patches_bincounts( + probe_conj * magnetic_exp * exit_waves, + positions_px, + ) + + # i exp(i M*) C* P* + magnetic_update = self._sum_overlapping_patches_bincounts( + 1j * magnetic_exp * probe_conj * electrostatic_conj * exit_waves, + positions_px, + ) + + current_object[0] += ( + step_size * electrostatic_update * probe_magnetic_normalization + ) + current_object[1] += ( + step_size * magnetic_update * probe_electrostatic_normalization + ) + + if not fix_probe: + # exp(i M*) C* + current_probe += step_size * ( + xp.sum( + electrostatic_conj * magnetic_exp * exit_waves, + axis=0, + ) + * electrostatic_magnetic_normalization + ) + + case (0, 1) | (1, 2) | (2, 1): # forward + + magnetic_exp = xp.exp(-1j * xp.conj(object_patches[1])) + + # P* exp(-i M*) + electrostatic_update = self._sum_overlapping_patches_bincounts( + probe_conj * magnetic_exp * exit_waves, + positions_px, + ) + + # -i exp (-i M*) P* + magnetic_update = self._sum_overlapping_patches_bincounts( + -1j * magnetic_exp * probe_conj * electrostatic_conj * exit_waves, + positions_px, + ) + + current_object[0] += ( + step_size * electrostatic_update * probe_magnetic_normalization + ) + current_object[1] += ( + step_size * magnetic_update * probe_electrostatic_normalization + ) + + if not fix_probe: + # exp(-i M*) C* + current_probe += step_size * ( + xp.sum( + electrostatic_conj * magnetic_exp * exit_waves, + axis=0, + ) + * electrostatic_magnetic_normalization + ) + + case (1, 1) | (2, 0): # neutral + probe_abs = xp.abs(shifted_probes) + probe_normalization = self._sum_overlapping_patches_bincounts( + probe_abs**2, + positions_px, + ) + probe_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_normalization) ** 2 + + (normalization_min * xp.max(probe_normalization)) ** 2 + ) + + # P* + electrostatic_update = self._sum_overlapping_patches_bincounts( + probe_conj * exit_waves, + positions_px, + ) + + current_object[0] += ( + step_size * electrostatic_update * probe_normalization + ) + + if not fix_probe: + # V* + current_probe += step_size * ( + xp.sum( + electrostatic_conj * exit_waves, + axis=0, + ) + * electrostatic_normalization + ) + + case _: + raise ValueError() + + return current_object, current_probe + + def _object_constraints( + self, + current_object, + gaussian_filter, + gaussian_filter_sigma_e, + gaussian_filter_sigma_m, + butterworth_filter, + butterworth_order, + q_lowpass_e, + q_lowpass_m, + q_highpass_e, + q_highpass_m, + tv_denoise, + tv_denoise_weight, + tv_denoise_inner_iter, + **kwargs, + ): + """MagneticObjectNDConstraints wrapper function""" + + # smoothness + if gaussian_filter: + current_object[0] = self._object_gaussian_constraint( + current_object[0], gaussian_filter_sigma_e, False + ) + current_object[1] = self._object_gaussian_constraint( + current_object[1], gaussian_filter_sigma_m, False + ) + if butterworth_filter: + current_object[0] = self._object_butterworth_constraint( + current_object[0], + q_lowpass_e, + q_highpass_e, + butterworth_order, + ) + current_object[1] = self._object_butterworth_constraint( + current_object[1], + q_lowpass_m, + q_highpass_m, + butterworth_order, + ) + if tv_denoise: + current_object[0] = self._object_denoise_tv_pylops( + current_object[0], tv_denoise_weight, tv_denoise_inner_iter + ) + + # amplitude threshold + # current_object[0] = self._object_threshold_constraint( + # current_object[0], False + # ) + + return current_object + + def reconstruct( + self, + num_iter: int = 8, + reconstruction_method: str = "gradient-descent", + reconstruction_parameter: float = 1.0, + reconstruction_parameter_a: float = None, + reconstruction_parameter_b: float = None, + reconstruction_parameter_c: float = None, + max_batch_size: int = None, + seed_random: int = None, + step_size: float = 0.5, + normalization_min: float = 1, + positions_step_size: float = 0.9, + fix_probe_com: bool = True, + fix_probe: bool = False, + fix_probe_aperture: bool = False, + constrain_probe_amplitude: bool = False, + constrain_probe_amplitude_relative_radius: float = 0.5, + constrain_probe_amplitude_relative_width: float = 0.05, + constrain_probe_fourier_amplitude: bool = False, + constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, + constrain_probe_fourier_amplitude_constant_intensity: bool = False, + fix_positions: bool = True, + fix_positions_com: bool = True, + max_position_update_distance: float = None, + max_position_total_distance: float = None, + global_affine_transformation: bool = False, + gaussian_filter_sigma_e: float = None, + gaussian_filter_sigma_m: float = None, + gaussian_filter: bool = True, + fit_probe_aberrations: bool = False, + fit_probe_aberrations_max_angular_order: int = 4, + fit_probe_aberrations_max_radial_order: int = 4, + fit_probe_aberrations_remove_initial: bool = False, + fit_probe_aberrations_using_scikit_image: bool = True, + butterworth_filter: bool = True, + q_lowpass_e: float = None, + q_lowpass_m: float = None, + q_highpass_e: float = None, + q_highpass_m: float = None, + butterworth_order: float = 2, + tv_denoise: bool = True, + tv_denoise_weight: float = None, + tv_denoise_inner_iter: float = 40, + detector_fourier_mask: np.ndarray = None, + store_iterations: bool = False, + collective_measurement_updates: bool = True, + progress_bar: bool = True, + reset: bool = None, + device: str = None, + clear_fft_cache: bool = None, + object_type: str = None, + ): + """ + Ptychographic reconstruction main method. + + Parameters + -------- + num_iter: int, optional + Number of iterations to run + reconstruction_method: str, optional + Specifies which reconstruction algorithm to use, one of: + "generalized-projections", + "DM_AP" (or "difference-map_alternating-projections"), + "RAAR" (or "relaxed-averaged-alternating-reflections"), + "RRR" (or "relax-reflect-reflect"), + "SUPERFLIP" (or "charge-flipping"), or + "GD" (or "gradient_descent") + reconstruction_parameter: float, optional + Reconstruction parameter for various reconstruction methods above. + reconstruction_parameter_a: float, optional + Reconstruction parameter a for reconstruction_method='generalized-projections'. + reconstruction_parameter_b: float, optional + Reconstruction parameter b for reconstruction_method='generalized-projections'. + reconstruction_parameter_c: float, optional + Reconstruction parameter c for reconstruction_method='generalized-projections'. + max_batch_size: int, optional + Max number of probes to update at once + seed_random: int, optional + Seeds the random number generator, only applicable when max_batch_size is not None + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + positions_step_size: float, optional + Positions update step size + pure_phase_object: bool, optional + If True, object amplitude is set to unity + fix_probe_com: bool, optional + If True, fixes center of mass of probe + fix_probe: bool, optional + If True, probe is fixed + fix_probe_aperture: bool, optional + If True, vaccum probe is used to fix Fourier amplitude + constrain_probe_amplitude: bool, optional + If True, real-space probe is constrained with a top-hat support. + constrain_probe_amplitude_relative_radius: float + Relative location of top-hat inflection point, between 0 and 0.5 + constrain_probe_amplitude_relative_width: float + Relative width of top-hat sigmoid, between 0 and 0.5 + constrain_probe_fourier_amplitude: bool, optional + If True, Fourier-probe is constrained by fitting a sigmoid for each angular frequency + constrain_probe_fourier_amplitude_max_width_pixels: float + Maximum pixel width of fitted sigmoid functions. + constrain_probe_fourier_amplitude_constant_intensity: bool + If True, the probe aperture is additionally constrained to a constant intensity. + fix_positions: bool, optional + If True, probe-positions are fixed + fix_positions_com: bool, optional + If True, fixes the positions CoM to the middle of the fov + max_position_update_distance: float, optional + Maximum allowed distance for update in A + max_position_total_distance: float, optional + Maximum allowed distance from initial positions + global_affine_transformation: bool, optional + If True, positions are assumed to be a global affine transform from initial scan + gaussian_filter_sigma_e: float + Standard deviation of gaussian kernel for electrostatic object in A + gaussian_filter_sigma_m: float + Standard deviation of gaussian kernel for magnetic object in A + gaussian_filter: bool, optional + If True and gaussian_filter_sigma is not None, object is smoothed using gaussian filtering + fit_probe_aberrations: bool, optional + If True, probe aberrations are fitted to a low-order expansion + fit_probe_aberrations_max_angular_order: bool + Max angular order of probe aberrations basis functions + fit_probe_aberrations_max_radial_order: bool + Max radial order of probe aberrations basis functions + fit_probe_aberrations_remove_initial: bool + If true, initial probe aberrations are removed before fitting + fit_probe_aberrations_using_scikit_image: bool + If true, the necessary phase unwrapping is performed using scikit-image. This is more stable, but occasionally leads + to a documented bug where the kernel hangs.. + If false, a poisson-based solver is used for phase unwrapping. This won't hang, but tends to underestimate aberrations. + butterworth_filter: bool, optional + If True and q_lowpass or q_highpass is not None, object is smoothed using butterworth filtering + q_lowpass_e: float + Cut-off frequency in A^-1 for low-pass filtering electrostatic object + q_lowpass_m: float + Cut-off frequency in A^-1 for low-pass filtering magnetic object + q_highpass_e: float + Cut-off frequency in A^-1 for high-pass filtering electrostatic object + q_highpass_m: float + Cut-off frequency in A^-1 for high-pass filtering magnetic object + butterworth_order: float + Butterworth filter order. Smaller gives a smoother filter + tv_denoise: bool, optional + If True and tv_denoise_weight is not None, object is smoothed using TV denoising + tv_denoise_weight: float + Denoising weight. The greater `weight`, the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising + object_positivity: bool, optional + If True, forces object to be positive + shrinkage_rad: float + Phase shift in radians to be subtracted from the potential at each iteration + fix_potential_baseline: bool + If true, the potential mean outside the FOV is forced to zero at each iteration + detector_fourier_mask: np.ndarray + Corner-centered mask to multiply the detector-plane gradients with (a value of zero supresses those pixels). + Useful when detector has artifacts such as dead-pixels. Usually binary. + store_iterations: bool, optional + If True, reconstructed objects and probes are stored at each iteration + collective_measurement_updates: bool + if True perform collective updates for all measurements + progress_bar: bool, optional + If True, reconstruction progress is displayed + reset: bool, optional + If True, previous reconstructions are ignored + device: str, optional + if not none, overwrites self._device to set device preprocess will be perfomed on. + clear_fft_cache: bool, optional + if true, and device = 'gpu', clears the cached fft plan at the end of function calls + object_type: str, optional + Overwrites self._object_type + + Returns + -------- + self: PtychographicReconstruction + Self to accommodate chaining + """ + # handle device/storage + self.set_device(device, clear_fft_cache) + + if device is not None: + attrs = [ + "_known_aberrations_array", + "_object", + "_object_initial", + "_probes_all", + "_probes_all_initial", + "_probes_all_initial_aperture", + ] + self.copy_attributes_to_device(attrs, device) + + xp = self._xp + xp_storage = self._xp_storage + device = self._device + asnumpy = self._asnumpy + + if not collective_measurement_updates and self._verbose: + warnings.warn( + "Magnetic ptychography is much more robust with `collective_measurement_updates=True`.", + UserWarning, + ) + + # set and report reconstruction method + ( + use_projection_scheme, + projection_a, + projection_b, + projection_c, + reconstruction_parameter, + step_size, + ) = self._set_reconstruction_method_parameters( + reconstruction_method, + reconstruction_parameter, + reconstruction_parameter_a, + reconstruction_parameter_b, + reconstruction_parameter_c, + step_size, + ) + + if use_projection_scheme: + raise NotImplementedError( + "Magnetic ptychography is currently only implemented for gradient descent." + ) + + # initialization + self._reset_reconstruction(store_iterations, reset, use_projection_scheme) + + if object_type is not None: + self._switch_object_type(object_type) + + if self._verbose: + self._report_reconstruction_summary( + num_iter, + use_projection_scheme, + reconstruction_method, + reconstruction_parameter, + projection_a, + projection_b, + projection_c, + normalization_min, + step_size, + max_batch_size, + ) + + if max_batch_size is not None: + np.random.seed(seed_random) + else: + max_batch_size = self._num_diffraction_patterns + + if detector_fourier_mask is not None: + detector_fourier_mask = xp.asarray(detector_fourier_mask) + + if gaussian_filter_sigma_m is None: + gaussian_filter_sigma_m = gaussian_filter_sigma_e + + if q_lowpass_m is None: + q_lowpass_m = q_lowpass_e + + # main loop + for a0 in tqdmnd( + num_iter, + desc="Reconstructing object and probe", + unit=" iter", + disable=not progress_bar, + ): + error = 0.0 + + if collective_measurement_updates: + collective_object = xp.zeros_like(self._object) + + # randomize + measurement_indices = np.arange(self._num_measurements) + np.random.shuffle(measurement_indices) + + for measurement_index in measurement_indices: + self._active_measurement_index = measurement_index + + measurement_error = 0.0 + + _probe = self._probes_all[self._active_measurement_index] + _probe_initial_aperture = self._probes_all_initial_aperture[ + self._active_measurement_index + ] + + start_idx = self._cum_probes_per_measurement[ + self._active_measurement_index + ] + end_idx = self._cum_probes_per_measurement[ + self._active_measurement_index + 1 + ] + + num_diffraction_patterns = end_idx - start_idx + shuffled_indices = np.arange(start_idx, end_idx) + + # randomize + if not use_projection_scheme: + np.random.shuffle(shuffled_indices) + + for start, end in generate_batches( + num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + batch_indices = shuffled_indices[start:end] + positions_px = self._positions_px_all[batch_indices] + positions_px_initial = self._positions_px_initial_all[batch_indices] + positions_px_fractional = positions_px - xp_storage.round( + positions_px + ) + + ( + vectorized_patch_indices_row, + vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices(positions_px) + + amplitudes_device = copy_to_device( + self._amplitudes[batch_indices], device + ) + + # forward operator + ( + shifted_probes, + object_patches, + overlap, + self._exit_waves, + batch_error, + ) = self._forward( + self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + _probe, + positions_px_fractional, + amplitudes_device, + self._exit_waves, + detector_fourier_mask, + use_projection_scheme=use_projection_scheme, + projection_a=projection_a, + projection_b=projection_b, + projection_c=projection_c, + ) + + # adjoint operator + object_update, _probe = self._adjoint( + self._object.copy(), + _probe, + object_patches, + shifted_probes, + positions_px, + self._exit_waves, + use_projection_scheme=use_projection_scheme, + step_size=step_size, + normalization_min=normalization_min, + fix_probe=fix_probe, + ) + + object_update -= self._object + + # position correction + if not fix_positions and a0 > 0: + self._positions_px_all[batch_indices] = ( + self._position_correction( + self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + overlap, + amplitudes_device, + positions_px, + positions_px_initial, + positions_step_size, + max_position_update_distance, + max_position_total_distance, + ) + ) + + measurement_error += batch_error + + if collective_measurement_updates: + collective_object += object_update + else: + self._object += object_update + + # Normalize Error + measurement_error /= ( + self._mean_diffraction_intensity[self._active_measurement_index] + * num_diffraction_patterns + ) + error += measurement_error + + # constraints + + if collective_measurement_updates: + # probe and positions + _probe = self._probe_constraints( + _probe, + fix_probe_com=fix_probe_com and not fix_probe, + constrain_probe_amplitude=constrain_probe_amplitude + and not fix_probe, + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=constrain_probe_fourier_amplitude + and not fix_probe, + constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=fit_probe_aberrations and not fix_probe, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, + fix_probe_aperture=fix_probe_aperture and not fix_probe, + initial_probe_aperture=_probe_initial_aperture, + ) + + self._positions_px_all[batch_indices] = self._positions_constraints( + self._positions_px_all[batch_indices], + self._positions_px_initial_all[batch_indices], + fix_positions=fix_positions, + fix_positions_com=fix_positions_com and not fix_positions, + global_affine_transformation=global_affine_transformation, + ) + + else: + # object, probe, and positions + ( + self._object, + _probe, + self._positions_px_all[batch_indices], + ) = self._constraints( + self._object, + _probe, + self._positions_px_all[batch_indices], + self._positions_px_initial_all[batch_indices], + fix_probe_com=fix_probe_com and not fix_probe, + constrain_probe_amplitude=constrain_probe_amplitude + and not fix_probe, + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=constrain_probe_fourier_amplitude + and not fix_probe, + constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=fit_probe_aberrations and not fix_probe, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, + fix_probe_aperture=fix_probe_aperture and not fix_probe, + initial_probe_aperture=_probe_initial_aperture, + fix_positions=fix_positions, + fix_positions_com=fix_positions_com and not fix_positions, + global_affine_transformation=global_affine_transformation, + gaussian_filter=gaussian_filter + and gaussian_filter_sigma_m is not None, + gaussian_filter_sigma_e=gaussian_filter_sigma_e, + gaussian_filter_sigma_m=gaussian_filter_sigma_m, + butterworth_filter=butterworth_filter + and (q_lowpass_m is not None or q_highpass_m is not None), + q_lowpass_e=q_lowpass_e, + q_lowpass_m=q_lowpass_m, + q_highpass_e=q_highpass_e, + q_highpass_m=q_highpass_m, + butterworth_order=butterworth_order, + tv_denoise=tv_denoise and tv_denoise_weight is not None, + tv_denoise_weight=tv_denoise_weight, + tv_denoise_inner_iter=tv_denoise_inner_iter, + ) + + # Normalize Error Over Tilts + error /= self._num_measurements + + if collective_measurement_updates: + self._object += collective_object / self._num_measurements + + # object only + self._object = self._object_constraints( + self._object, + gaussian_filter=gaussian_filter + and gaussian_filter_sigma_m is not None, + gaussian_filter_sigma_e=gaussian_filter_sigma_e, + gaussian_filter_sigma_m=gaussian_filter_sigma_m, + butterworth_filter=butterworth_filter + and (q_lowpass_m is not None or q_highpass_m is not None), + q_lowpass_e=q_lowpass_e, + q_lowpass_m=q_lowpass_m, + q_highpass_e=q_highpass_e, + q_highpass_m=q_highpass_m, + butterworth_order=butterworth_order, + tv_denoise=tv_denoise and tv_denoise_weight is not None, + tv_denoise_weight=tv_denoise_weight, + tv_denoise_inner_iter=tv_denoise_inner_iter, + ) + + self.error_iterations.append(error.item()) + + if store_iterations: + self.object_iterations.append(asnumpy(self._object.copy())) + self.probe_iterations.append(self.probe_centered) + + # store result + self.object = asnumpy(self._object) + self.probe = self.probe_centered + self.error = error.item() + + # remove _exit_waves attr from self for GD + if not use_projection_scheme: + self._exit_waves = None + + self.clear_device_mem(self._device, self._clear_fft_cache) + + return self + + def _visualize_all_iterations(self, **kwargs): + raise NotImplementedError() + + def _visualize_last_iteration( + self, + fig, + cbar: bool, + plot_convergence: bool, + plot_probe: bool, + plot_fourier_probe: bool, + remove_initial_probe_aberrations: bool, + **kwargs, + ): + """ + Displays last reconstructed object and probe iterations. + + Parameters + -------- + fig: Figure + Matplotlib figure to place Gridspec in + plot_convergence: bool, optional + If true, the normalized mean squared error (NMSE) plot is displayed + cbar: bool, optional + If true, displays a colorbar + plot_probe: bool, optional + If true, the reconstructed complex probe is displayed + plot_fourier_probe: bool, optional + If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes + """ + + asnumpy = self._asnumpy + + figsize = kwargs.pop("figsize", (12, 5)) + cmap_real = kwargs.pop("cmap_real", "PiYG") + cmap_imag = kwargs.pop("cmap_imag", "PuOr") + chroma_boost = kwargs.pop("chroma_boost", 1) + + # get scaled arrays + probe = self._return_single_probe() + obj = self.object_cropped + + _, _, _vmax_real = return_scaled_histogram_ordering(obj[1].real) + vmin_real = kwargs.pop("vmin_real", -_vmax_real) + vmax_real = kwargs.pop("vmax_real", _vmax_real) + + _, _, _vmax_imag = return_scaled_histogram_ordering(obj[1].imag) + vmin_imag = kwargs.pop("vmin_imag", -_vmax_imag) + vmax_imag = kwargs.pop("vmax_imag", _vmax_imag) + + extent = [ + 0, + self.sampling[1] * obj.shape[2], + self.sampling[0] * obj.shape[1], + 0, + ] + + if plot_fourier_probe: + probe_extent = [ + -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + ] + + elif plot_probe: + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + if plot_convergence: + if plot_probe or plot_fourier_probe: + spec = GridSpec( + ncols=3, + nrows=2, + height_ratios=[4, 1], + hspace=0.15, + width_ratios=[ + (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), + (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), + 1, + ], + wspace=0.35, + ) + + else: + spec = GridSpec(ncols=2, nrows=2, height_ratios=[4, 1], hspace=0.15) + + else: + if plot_probe or plot_fourier_probe: + spec = GridSpec( + ncols=3, + nrows=1, + width_ratios=[ + (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), + (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), + 1, + ], + wspace=0.35, + ) + + else: + spec = GridSpec(ncols=2, nrows=1) + + if fig is None: + fig = plt.figure(figsize=figsize) + + if plot_probe or plot_fourier_probe: + # Object_real + ax = fig.add_subplot(spec[0, 0]) + im = ax.imshow( + obj[1].real, + extent=extent, + cmap=cmap_real, + vmin=vmin_real, + vmax=vmax_real, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + ax.set_title("Real magnetic refractive index") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + # Object_imag + ax = fig.add_subplot(spec[0, 1]) + im = ax.imshow( + obj[1].imag, + extent=extent, + cmap=cmap_imag, + vmin=vmin_imag, + vmax=vmax_imag, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + ax.set_title("Imaginary magnetic refractive index") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + # Probe + ax = fig.add_subplot(spec[0, 2]) + if plot_fourier_probe: + probe = asnumpy( + self._return_fourier_probe( + probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) + ) + + probe_array = Complex2RGB( + probe, + chroma_boost=chroma_boost, + ) + + ax.set_title("Reconstructed Fourier probe") + ax.set_ylabel("kx [mrad]") + ax.set_xlabel("ky [mrad]") + else: + probe_array = Complex2RGB( + asnumpy(self._return_centered_probe(probe)), + power=2, + chroma_boost=chroma_boost, + ) + ax.set_title("Reconstructed probe intensity") + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + im = ax.imshow( + probe_array, + extent=probe_extent, + ) + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) + + else: + # Object_real + ax = fig.add_subplot(spec[0, 0]) + im = ax.imshow( + obj[1].real, + extent=extent, + cmap=cmap_real, + vmin=vmin_real, + vmax=vmax_real, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + ax.set_title("Real magnetic refractive index") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + # Object_imag + ax = fig.add_subplot(spec[0, 1]) + im = ax.imshow( + obj[1].imag, + extent=extent, + cmap=cmap_imag, + vmin=vmin_imag, + vmax=vmax_imag, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + ax.set_title("Imaginary magnetic refractive index") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + if plot_convergence and hasattr(self, "error_iterations"): + errors = np.array(self.error_iterations) + + ax = fig.add_subplot(spec[1, :]) + ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) + ax.set_ylabel("NMSE") + ax.set_xlabel("Iteration number") + ax.yaxis.tick_right() + + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") + spec.tight_layout(fig) + + @property + def object_cropped(self): + """Cropped and rotated object""" + avg_pos = self._return_average_positions() + cropped_e = self._crop_rotate_object_fov(self._object[0], positions_px=avg_pos) + cropped_m = self._crop_rotate_object_fov(self._object[1], positions_px=avg_pos) + + return np.array([cropped_e, cropped_m]) From 7ed3d2e680b71969ab5ff27025b7bb4e8f5c1283 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Wed, 29 May 2024 21:06:12 -0700 Subject: [PATCH 40/64] xmcd init, adjoint typos --- .../phase/xray_magnetic_ptychography.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/py4DSTEM/process/phase/xray_magnetic_ptychography.py b/py4DSTEM/process/phase/xray_magnetic_ptychography.py index 77e1cb835..a40dd0f8a 100644 --- a/py4DSTEM/process/phase/xray_magnetic_ptychography.py +++ b/py4DSTEM/process/phase/xray_magnetic_ptychography.py @@ -602,7 +602,7 @@ def preprocess( ) if self._object is None: - self._object = xp.full((2,) + obj.shape, obj) + self._object = xp.stack((obj, xp.zeros_like(obj)), 0) else: self._object = obj @@ -819,9 +819,9 @@ def _overlap_projection( match (self._recon_mode, self._active_measurement_index): case (0, 0) | (1, 0): # reverse - overlap = overlap_base * xp.exp(-1j * object_patches[1]) + overlap = overlap_base * xp.exp(-1.0j * object_patches[1]) case (0, 1) | (1, 2) | (2, 1): # forward - overlap = overlap_base * xp.exp(1j * object_patches[1]) + overlap = overlap_base * xp.exp(1.0j * object_patches[1]) case (1, 1) | (2, 0): # neutral overlap = overlap_base case _: @@ -874,7 +874,7 @@ def _gradient_descent_adjoint( xp = self._xp probe_conj = xp.conj(shifted_probes) # P* - electrostatic_conj = xp.conj(object_patches[0]) # C* = exp(-i c) + electrostatic_conj = xp.conj(object_patches[0]) # C* probe_electrostatic_abs = xp.abs(shifted_probes * object_patches[0]) probe_electrostatic_normalization = self._sum_overlapping_patches_bincounts( @@ -887,7 +887,7 @@ def _gradient_descent_adjoint( + (normalization_min * xp.max(probe_electrostatic_normalization)) ** 2 ) - probe_magnetic_abs = xp.abs(shifted_probes * object_patches[1]) + probe_magnetic_abs = xp.abs(shifted_probes * xp.exp(1.0j * object_patches[1])) probe_magnetic_normalization = self._sum_overlapping_patches_bincounts( probe_magnetic_abs**2, positions_px, @@ -900,7 +900,7 @@ def _gradient_descent_adjoint( if not fix_probe: electrostatic_magnetic_abs = xp.abs( - object_patches[0] * xp.exp(1j * object_patches[1]) + object_patches[0] * xp.exp(1.0j * object_patches[1]) ) electrostatic_magnetic_normalization = xp.sum( electrostatic_magnetic_abs**2, @@ -928,7 +928,7 @@ def _gradient_descent_adjoint( match (self._recon_mode, self._active_measurement_index): case (0, 0) | (1, 0): # reverse - magnetic_exp = xp.exp(1j * xp.conj(object_patches[1])) + magnetic_exp = xp.exp(1.0j * xp.conj(object_patches[1])) # P* exp(i M*) electrostatic_update = self._sum_overlapping_patches_bincounts( @@ -938,7 +938,7 @@ def _gradient_descent_adjoint( # i exp(i M*) C* P* magnetic_update = self._sum_overlapping_patches_bincounts( - 1j * magnetic_exp * probe_conj * electrostatic_conj * exit_waves, + 1.0j * magnetic_exp * probe_conj * electrostatic_conj * exit_waves, positions_px, ) @@ -961,7 +961,7 @@ def _gradient_descent_adjoint( case (0, 1) | (1, 2) | (2, 1): # forward - magnetic_exp = xp.exp(-1j * xp.conj(object_patches[1])) + magnetic_exp = xp.exp(-1.0j * xp.conj(object_patches[1])) # P* exp(-i M*) electrostatic_update = self._sum_overlapping_patches_bincounts( @@ -971,7 +971,7 @@ def _gradient_descent_adjoint( # -i exp (-i M*) P* magnetic_update = self._sum_overlapping_patches_bincounts( - -1j * magnetic_exp * probe_conj * electrostatic_conj * exit_waves, + -1.0j * magnetic_exp * probe_conj * electrostatic_conj * exit_waves, positions_px, ) From 936df0e1ff9f04a65631fdae92c236ed7950c67e Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Wed, 29 May 2024 22:56:06 -0700 Subject: [PATCH 41/64] more careful normalizations --- .../phase/xray_magnetic_ptychography.py | 114 ++++++++++++------ 1 file changed, 77 insertions(+), 37 deletions(-) diff --git a/py4DSTEM/process/phase/xray_magnetic_ptychography.py b/py4DSTEM/process/phase/xray_magnetic_ptychography.py index a40dd0f8a..48ec32a49 100644 --- a/py4DSTEM/process/phase/xray_magnetic_ptychography.py +++ b/py4DSTEM/process/phase/xray_magnetic_ptychography.py @@ -887,49 +887,22 @@ def _gradient_descent_adjoint( + (normalization_min * xp.max(probe_electrostatic_normalization)) ** 2 ) - probe_magnetic_abs = xp.abs(shifted_probes * xp.exp(1.0j * object_patches[1])) - probe_magnetic_normalization = self._sum_overlapping_patches_bincounts( - probe_magnetic_abs**2, - positions_px, - ) - probe_magnetic_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_magnetic_normalization) ** 2 - + (normalization_min * xp.max(probe_magnetic_normalization)) ** 2 - ) + match (self._recon_mode, self._active_measurement_index): + case (0, 0) | (1, 0): # reverse - if not fix_probe: - electrostatic_magnetic_abs = xp.abs( - object_patches[0] * xp.exp(1.0j * object_patches[1]) - ) - electrostatic_magnetic_normalization = xp.sum( - electrostatic_magnetic_abs**2, - axis=0, - ) - electrostatic_magnetic_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * electrostatic_magnetic_normalization) ** 2 - + (normalization_min * xp.max(electrostatic_magnetic_normalization)) - ** 2 - ) + magnetic_exp = xp.exp(1.0j * xp.conj(object_patches[1])) - if self._recon_mode > 0: - electrostatic_abs = xp.abs(object_patches[0]) - electrostatic_normalization = xp.sum( - electrostatic_abs**2, - axis=0, + probe_magnetic_abs = xp.abs(shifted_probes * magnetic_exp) + probe_magnetic_normalization = self._sum_overlapping_patches_bincounts( + probe_magnetic_abs**2, + positions_px, ) - electrostatic_normalization = 1 / xp.sqrt( + probe_magnetic_normalization = 1 / xp.sqrt( 1e-16 - + ((1 - normalization_min) * electrostatic_normalization) ** 2 - + (normalization_min * xp.max(electrostatic_normalization)) ** 2 + + ((1 - normalization_min) * probe_magnetic_normalization) ** 2 + + (normalization_min * xp.max(probe_magnetic_normalization)) ** 2 ) - match (self._recon_mode, self._active_measurement_index): - case (0, 0) | (1, 0): # reverse - - magnetic_exp = xp.exp(1.0j * xp.conj(object_patches[1])) - # P* exp(i M*) electrostatic_update = self._sum_overlapping_patches_bincounts( probe_conj * magnetic_exp * exit_waves, @@ -950,6 +923,28 @@ def _gradient_descent_adjoint( ) if not fix_probe: + + electrostatic_magnetic_abs = xp.abs( + object_patches[0] * magnetic_exp + ) + electrostatic_magnetic_normalization = xp.sum( + electrostatic_magnetic_abs**2, + axis=0, + ) + electrostatic_magnetic_normalization = 1 / xp.sqrt( + 1e-16 + + ( + (1 - normalization_min) + * electrostatic_magnetic_normalization + ) + ** 2 + + ( + normalization_min + * xp.max(electrostatic_magnetic_normalization) + ) + ** 2 + ) + # exp(i M*) C* current_probe += step_size * ( xp.sum( @@ -963,6 +958,17 @@ def _gradient_descent_adjoint( magnetic_exp = xp.exp(-1.0j * xp.conj(object_patches[1])) + probe_magnetic_abs = xp.abs(shifted_probes * magnetic_exp) + probe_magnetic_normalization = self._sum_overlapping_patches_bincounts( + probe_magnetic_abs**2, + positions_px, + ) + probe_magnetic_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_magnetic_normalization) ** 2 + + (normalization_min * xp.max(probe_magnetic_normalization)) ** 2 + ) + # P* exp(-i M*) electrostatic_update = self._sum_overlapping_patches_bincounts( probe_conj * magnetic_exp * exit_waves, @@ -983,6 +989,28 @@ def _gradient_descent_adjoint( ) if not fix_probe: + + electrostatic_magnetic_abs = xp.abs( + object_patches[0] * magnetic_exp + ) + electrostatic_magnetic_normalization = xp.sum( + electrostatic_magnetic_abs**2, + axis=0, + ) + electrostatic_magnetic_normalization = 1 / xp.sqrt( + 1e-16 + + ( + (1 - normalization_min) + * electrostatic_magnetic_normalization + ) + ** 2 + + ( + normalization_min + * xp.max(electrostatic_magnetic_normalization) + ) + ** 2 + ) + # exp(-i M*) C* current_probe += step_size * ( xp.sum( @@ -1015,6 +1043,18 @@ def _gradient_descent_adjoint( ) if not fix_probe: + + electrostatic_abs = xp.abs(object_patches[0]) + electrostatic_normalization = xp.sum( + electrostatic_abs**2, + axis=0, + ) + electrostatic_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * electrostatic_normalization) ** 2 + + (normalization_min * xp.max(electrostatic_normalization)) ** 2 + ) + # V* current_probe += step_size * ( xp.sum( From 8095fa06368840f79c614f996d8b9a5411068fdd Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Thu, 30 May 2024 17:42:05 -0700 Subject: [PATCH 42/64] switched to real-imag electronic contribution --- .../phase/xray_magnetic_ptychography.py | 77 +++++++++---------- 1 file changed, 37 insertions(+), 40 deletions(-) diff --git a/py4DSTEM/process/phase/xray_magnetic_ptychography.py b/py4DSTEM/process/phase/xray_magnetic_ptychography.py index 48ec32a49..1912c93b1 100644 --- a/py4DSTEM/process/phase/xray_magnetic_ptychography.py +++ b/py4DSTEM/process/phase/xray_magnetic_ptychography.py @@ -602,7 +602,8 @@ def preprocess( ) if self._object is None: - self._object = xp.stack((obj, xp.zeros_like(obj)), 0) + # complex zeros instead of ones, since we store pre-exponential terms + self._object = xp.zeros((2,) + obj.shape, dtype=obj.dtype) else: self._object = obj @@ -815,7 +816,7 @@ def _overlap_projection( 1, vectorized_patch_indices_row, vectorized_patch_indices_col ] - overlap_base = shifted_probes * object_patches[0] + overlap_base = shifted_probes * xp.exp(1.0j * object_patches[0]) match (self._recon_mode, self._active_measurement_index): case (0, 0) | (1, 0): # reverse @@ -874,9 +875,9 @@ def _gradient_descent_adjoint( xp = self._xp probe_conj = xp.conj(shifted_probes) # P* - electrostatic_conj = xp.conj(object_patches[0]) # C* + electrostatic_conj = xp.exp(-1.0j * xp.conj(object_patches[0])) # exp[-i c] - probe_electrostatic_abs = xp.abs(shifted_probes * object_patches[0]) + probe_electrostatic_abs = xp.abs(probe_conj * electrostatic_conj) probe_electrostatic_normalization = self._sum_overlapping_patches_bincounts( probe_electrostatic_abs**2, positions_px, @@ -890,9 +891,9 @@ def _gradient_descent_adjoint( match (self._recon_mode, self._active_measurement_index): case (0, 0) | (1, 0): # reverse - magnetic_exp = xp.exp(1.0j * xp.conj(object_patches[1])) + magnetic_conj = xp.exp(1.0j * xp.conj(object_patches[1])) - probe_magnetic_abs = xp.abs(shifted_probes * magnetic_exp) + probe_magnetic_abs = xp.abs(shifted_probes * magnetic_conj) probe_magnetic_normalization = self._sum_overlapping_patches_bincounts( probe_magnetic_abs**2, positions_px, @@ -903,15 +904,19 @@ def _gradient_descent_adjoint( + (normalization_min * xp.max(probe_magnetic_normalization)) ** 2 ) - # P* exp(i M*) + # - i * exp(i m*) * exp(-i c*) * P electrostatic_update = self._sum_overlapping_patches_bincounts( - probe_conj * magnetic_exp * exit_waves, + -1.0j + * magnetic_conj + * electrostatic_conj + * probe_conj + * exit_waves, positions_px, ) - # i exp(i M*) C* P* + # i * exp(i m*) * exp(-i c*) * P magnetic_update = self._sum_overlapping_patches_bincounts( - 1.0j * magnetic_exp * probe_conj * electrostatic_conj * exit_waves, + 1.0j * magnetic_conj * electrostatic_conj * probe_conj * exit_waves, positions_px, ) @@ -925,7 +930,7 @@ def _gradient_descent_adjoint( if not fix_probe: electrostatic_magnetic_abs = xp.abs( - object_patches[0] * magnetic_exp + electrostatic_conj * magnetic_conj ) electrostatic_magnetic_normalization = xp.sum( electrostatic_magnetic_abs**2, @@ -945,10 +950,10 @@ def _gradient_descent_adjoint( ** 2 ) - # exp(i M*) C* + # exp(i m*) * exp(-i c*) current_probe += step_size * ( xp.sum( - electrostatic_conj * magnetic_exp * exit_waves, + magnetic_conj * electrostatic_conj * exit_waves, axis=0, ) * electrostatic_magnetic_normalization @@ -956,9 +961,9 @@ def _gradient_descent_adjoint( case (0, 1) | (1, 2) | (2, 1): # forward - magnetic_exp = xp.exp(-1.0j * xp.conj(object_patches[1])) + magnetic_conj = xp.exp(-1.0j * xp.conj(object_patches[1])) - probe_magnetic_abs = xp.abs(shifted_probes * magnetic_exp) + probe_magnetic_abs = xp.abs(shifted_probes * magnetic_conj) probe_magnetic_normalization = self._sum_overlapping_patches_bincounts( probe_magnetic_abs**2, positions_px, @@ -969,29 +974,25 @@ def _gradient_descent_adjoint( + (normalization_min * xp.max(probe_magnetic_normalization)) ** 2 ) - # P* exp(-i M*) - electrostatic_update = self._sum_overlapping_patches_bincounts( - probe_conj * magnetic_exp * exit_waves, - positions_px, - ) - - # -i exp (-i M*) P* - magnetic_update = self._sum_overlapping_patches_bincounts( - -1.0j * magnetic_exp * probe_conj * electrostatic_conj * exit_waves, + # - i * exp(-i m*) * exp(-i c*) * P + update = self._sum_overlapping_patches_bincounts( + -1.0j + * magnetic_conj + * electrostatic_conj + * probe_conj + * exit_waves, positions_px, ) - current_object[0] += ( - step_size * electrostatic_update * probe_magnetic_normalization - ) + current_object[0] += step_size * update * probe_magnetic_normalization current_object[1] += ( - step_size * magnetic_update * probe_electrostatic_normalization + step_size * update * probe_electrostatic_normalization ) if not fix_probe: electrostatic_magnetic_abs = xp.abs( - object_patches[0] * magnetic_exp + electrostatic_conj * magnetic_conj ) electrostatic_magnetic_normalization = xp.sum( electrostatic_magnetic_abs**2, @@ -1011,16 +1012,17 @@ def _gradient_descent_adjoint( ** 2 ) - # exp(-i M*) C* + # exp(i m*) * exp(-i c*) current_probe += step_size * ( xp.sum( - electrostatic_conj * magnetic_exp * exit_waves, + magnetic_conj * electrostatic_conj * exit_waves, axis=0, ) * electrostatic_magnetic_normalization ) case (1, 1) | (2, 0): # neutral + probe_abs = xp.abs(shifted_probes) probe_normalization = self._sum_overlapping_patches_bincounts( probe_abs**2, @@ -1032,9 +1034,9 @@ def _gradient_descent_adjoint( + (normalization_min * xp.max(probe_normalization)) ** 2 ) - # P* + # -i exp(-i c*) * P* electrostatic_update = self._sum_overlapping_patches_bincounts( - probe_conj * exit_waves, + -1.0j * electrostatic_conj * probe_conj * exit_waves, positions_px, ) @@ -1044,7 +1046,7 @@ def _gradient_descent_adjoint( if not fix_probe: - electrostatic_abs = xp.abs(object_patches[0]) + electrostatic_abs = xp.abs(electrostatic_conj) electrostatic_normalization = xp.sum( electrostatic_abs**2, axis=0, @@ -1055,7 +1057,7 @@ def _gradient_descent_adjoint( + (normalization_min * xp.max(electrostatic_normalization)) ** 2 ) - # V* + # exp(-i c*) current_probe += step_size * ( xp.sum( electrostatic_conj * exit_waves, @@ -1114,11 +1116,6 @@ def _object_constraints( current_object[0], tv_denoise_weight, tv_denoise_inner_iter ) - # amplitude threshold - # current_object[0] = self._object_threshold_constraint( - # current_object[0], False - # ) - return current_object def reconstruct( From e6c0dba48d4cdc0605695a405dfb1746e7db09ef Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Thu, 30 May 2024 19:17:39 -0700 Subject: [PATCH 43/64] better xray visualizations --- .../phase/xray_magnetic_ptychography.py | 283 +++++++++++------- 1 file changed, 182 insertions(+), 101 deletions(-) diff --git a/py4DSTEM/process/phase/xray_magnetic_ptychography.py b/py4DSTEM/process/phase/xray_magnetic_ptychography.py index 1912c93b1..91c0bbfaa 100644 --- a/py4DSTEM/process/phase/xray_magnetic_ptychography.py +++ b/py4DSTEM/process/phase/xray_magnetic_ptychography.py @@ -1658,8 +1658,6 @@ def _visualize_last_iteration( If true, displays a colorbar plot_probe: bool, optional If true, the reconstructed complex probe is displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed remove_initial_probe_aberrations: bool, optional If true, when plotting fourier probe, removes initial probe to visualize changes @@ -1667,22 +1665,34 @@ def _visualize_last_iteration( asnumpy = self._asnumpy - figsize = kwargs.pop("figsize", (12, 5)) - cmap_real = kwargs.pop("cmap_real", "PiYG") - cmap_imag = kwargs.pop("cmap_imag", "PuOr") + figsize = kwargs.pop("figsize", (12, 8)) + cmap_e_real = kwargs.pop("cmap_e_real", "cividis") + cmap_e_imag = kwargs.pop("cmap_e_imag", "magma") + cmap_m_real = kwargs.pop("cmap_m_real", "PuOr") + cmap_m_imag = kwargs.pop("cmap_m_imag", "PiYG") chroma_boost = kwargs.pop("chroma_boost", 1) # get scaled arrays - probe = self._return_single_probe() obj = self.object_cropped - _, _, _vmax_real = return_scaled_histogram_ordering(obj[1].real) - vmin_real = kwargs.pop("vmin_real", -_vmax_real) - vmax_real = kwargs.pop("vmax_real", _vmax_real) + vmin_e_real = kwargs.pop("vmin_e_real", None) + vmax_e_real = kwargs.pop("vmax_e_real", None) + vmin_e_imag = kwargs.pop("vmin_e_imag", None) + vmax_e_imag = kwargs.pop("vmax_e_imag", None) + _, vmin_e_real, vmax_e_real = return_scaled_histogram_ordering( + obj[0].real, vmin_e_real, vmax_e_real + ) + _, vmin_e_imag, vmax_e_iamg = return_scaled_histogram_ordering( + obj[0].imag, vmin_e_imag, vmax_e_imag + ) + + _, _, _vmax_m_real = return_scaled_histogram_ordering(obj[1].real) + vmin_m_real = kwargs.pop("vmin_m_real", -_vmax_m_real) + vmax_m_real = kwargs.pop("vmax_m_real", _vmax_m_real) - _, _, _vmax_imag = return_scaled_histogram_ordering(obj[1].imag) - vmin_imag = kwargs.pop("vmin_imag", -_vmax_imag) - vmax_imag = kwargs.pop("vmax_imag", _vmax_imag) + _, _, _vmax_m_imag = return_scaled_histogram_ordering(obj[1].imag) + vmin_m_imag = kwargs.pop("vmin_m_imag", -_vmax_m_imag) + vmax_m_imag = kwargs.pop("vmax_m_imag", _vmax_m_imag) extent = [ 0, @@ -1698,7 +1708,6 @@ def _visualize_last_iteration( self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, ] - elif plot_probe: probe_extent = [ 0, @@ -1708,11 +1717,11 @@ def _visualize_last_iteration( ] if plot_convergence: - if plot_probe or plot_fourier_probe: + if plot_probe: spec = GridSpec( ncols=3, - nrows=2, - height_ratios=[4, 1], + nrows=3, + height_ratios=[4, 4, 1], hspace=0.15, width_ratios=[ (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), @@ -1723,13 +1732,13 @@ def _visualize_last_iteration( ) else: - spec = GridSpec(ncols=2, nrows=2, height_ratios=[4, 1], hspace=0.15) + spec = GridSpec(ncols=2, nrows=3, height_ratios=[4, 4, 1], hspace=0.15) else: - if plot_probe or plot_fourier_probe: + if plot_probe: spec = GridSpec( ncols=3, - nrows=1, + nrows=2, width_ratios=[ (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), @@ -1739,79 +1748,143 @@ def _visualize_last_iteration( ) else: - spec = GridSpec(ncols=2, nrows=1) + spec = GridSpec(ncols=2, nrows=2, wspace=0.35) if fig is None: fig = plt.figure(figsize=figsize) - if plot_probe or plot_fourier_probe: - # Object_real - ax = fig.add_subplot(spec[0, 0]) - im = ax.imshow( - obj[1].real, - extent=extent, - cmap=cmap_real, - vmin=vmin_real, - vmax=vmax_real, - **kwargs, + # Electronic real + ax = fig.add_subplot(spec[0, 0]) + im = ax.imshow( + obj[0].real, + extent=extent, + cmap=cmap_e_real, + vmin=vmin_e_real, + vmax=vmax_e_real, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + ax.set_title("Real elec. optical index") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + # Electronic imag + ax = fig.add_subplot(spec[0, 1]) + im = ax.imshow( + obj[0].imag, + extent=extent, + cmap=cmap_e_imag, + vmin=vmin_e_imag, + vmax=vmax_e_imag, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + ax.set_title("Imag elec. optical index") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + # Magnetic real + ax = fig.add_subplot(spec[1, 0]) + im = ax.imshow( + obj[1].real, + extent=extent, + cmap=cmap_m_real, + vmin=vmin_m_real, + vmax=vmax_m_real, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + ax.set_title("Real mag. optical index") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + # Magnetic imag + ax = fig.add_subplot(spec[1, 1]) + im = ax.imshow( + obj[1].imag, + extent=extent, + cmap=cmap_m_imag, + vmin=vmin_m_imag, + vmax=vmax_m_imag, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + ax.set_title("Imag mag. optical index") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + if plot_fourier_probe: + # Fourier probe + intensities = self._return_probe_intensities(None) + titles = [ + f"{sign}ve Fourier probe: {ratio*100:.1f}%" + for sign, ratio in zip(self._magnetic_contribution_sign, intensities) + ] + ax = fig.add_subplot(spec[0, 2]) + + probe_fourier = asnumpy( + self._return_fourier_probe( + self._probes_all[0], + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - ax.set_title("Real magnetic refractive index") - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) + probe_array = Complex2RGB( + probe_fourier, + chroma_boost=chroma_boost, + ) + + ax.set_title(titles[0]) + ax.set_ylabel("kx [mrad]") + ax.set_xlabel("ky [mrad]") - # Object_imag - ax = fig.add_subplot(spec[0, 1]) im = ax.imshow( - obj[1].imag, - extent=extent, - cmap=cmap_imag, - vmin=vmin_imag, - vmax=vmax_imag, - **kwargs, + probe_array, + extent=probe_extent, ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - ax.set_title("Imaginary magnetic refractive index") if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) - # Probe - ax = fig.add_subplot(spec[0, 2]) - if plot_fourier_probe: - probe = asnumpy( - self._return_fourier_probe( - probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - ) - ) + ax = fig.add_subplot(spec[1, 2]) - probe_array = Complex2RGB( - probe, - chroma_boost=chroma_boost, + probe_fourier = asnumpy( + self._return_fourier_probe( + self._probes_all[-1], + remove_initial_probe_aberrations=remove_initial_probe_aberrations, ) + ) - ax.set_title("Reconstructed Fourier probe") - ax.set_ylabel("kx [mrad]") - ax.set_xlabel("ky [mrad]") - else: - probe_array = Complex2RGB( - asnumpy(self._return_centered_probe(probe)), - power=2, - chroma_boost=chroma_boost, - ) - ax.set_title("Reconstructed probe intensity") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") + probe_array = Complex2RGB( + probe_fourier, + chroma_boost=chroma_boost, + ) + + ax.set_title(titles[-1]) + ax.set_ylabel("kx [mrad]") + ax.set_xlabel("ky [mrad]") im = ax.imshow( probe_array, @@ -1823,51 +1896,59 @@ def _visualize_last_iteration( ax_cb = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) - else: - # Object_real - ax = fig.add_subplot(spec[0, 0]) - im = ax.imshow( - obj[1].real, - extent=extent, - cmap=cmap_real, - vmin=vmin_real, - vmax=vmax_real, - **kwargs, + elif plot_probe: + # Real probe + intensities = self._return_probe_intensities(None) + titles = [ + f"{sign}ve probe intensity: {ratio*100:.1f}%" + for sign, ratio in zip(self._magnetic_contribution_sign, intensities) + ] + ax = fig.add_subplot(spec[0, 2]) + + probe_array = Complex2RGB( + asnumpy(self._return_centered_probe(self._probes_all[0])), + power=2, + chroma_boost=chroma_boost, ) + ax.set_title(titles[0]) ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") - ax.set_title("Real magnetic refractive index") + + im = ax.imshow( + probe_array, + extent=probe_extent, + ) if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) - # Object_imag - ax = fig.add_subplot(spec[0, 1]) - im = ax.imshow( - obj[1].imag, - extent=extent, - cmap=cmap_imag, - vmin=vmin_imag, - vmax=vmax_imag, - **kwargs, + ax = fig.add_subplot(spec[1, 2]) + + probe_array = Complex2RGB( + asnumpy(self._return_centered_probe(self._probes_all[-1])), + power=2, + chroma_boost=chroma_boost, ) + ax.set_title(titles[-1]) ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") - ax.set_title("Imaginary magnetic refractive index") + + im = ax.imshow( + probe_array, + extent=probe_extent, + ) if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) if plot_convergence and hasattr(self, "error_iterations"): errors = np.array(self.error_iterations) - ax = fig.add_subplot(spec[1, :]) + ax = fig.add_subplot(spec[2, :]) ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax.set_ylabel("NMSE") ax.set_xlabel("Iteration number") From c964ab7773fe8e98c253f8ef2105fd956633efb4 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Thu, 30 May 2024 19:17:53 -0700 Subject: [PATCH 44/64] cleanup magnetic viz --- .../process/phase/magnetic_ptychography.py | 139 ++++++------------ 1 file changed, 45 insertions(+), 94 deletions(-) diff --git a/py4DSTEM/process/phase/magnetic_ptychography.py b/py4DSTEM/process/phase/magnetic_ptychography.py index 58c826b87..aa89d09b3 100644 --- a/py4DSTEM/process/phase/magnetic_ptychography.py +++ b/py4DSTEM/process/phase/magnetic_ptychography.py @@ -1796,55 +1796,55 @@ def _visualize_last_iteration( if fig is None: fig = plt.figure(figsize=figsize) - if plot_probe or plot_fourier_probe: - # Object_e - ax = fig.add_subplot(spec[0, 0]) - im = ax.imshow( - obj[0], - extent=extent, - cmap=cmap_e, - vmin=vmin_e, - vmax=vmax_e, - **kwargs, - ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - - if self._object_type == "potential": - ax.set_title("Electrostatic potential") - elif self._object_type == "complex": - ax.set_title("Electrostatic phase") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) + # Object_e + ax = fig.add_subplot(spec[0, 0]) + im = ax.imshow( + obj[0], + extent=extent, + cmap=cmap_e, + vmin=vmin_e, + vmax=vmax_e, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") - # Object_m - ax = fig.add_subplot(spec[0, 1]) - im = ax.imshow( - obj[1], - extent=extent, - cmap=cmap_m, - vmin=vmin_m, - vmax=vmax_m, - **kwargs, - ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") + if self._object_type == "potential": + ax.set_title("Electrostatic potential") + elif self._object_type == "complex": + ax.set_title("Electrostatic phase") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + # Object_m + ax = fig.add_subplot(spec[0, 1]) + im = ax.imshow( + obj[1], + extent=extent, + cmap=cmap_m, + vmin=vmin_m, + vmax=vmax_m, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") - if self._object_type == "potential": - ax.set_title("Magnetic potential") - elif self._object_type == "complex": - ax.set_title("Magnetic phase") + if self._object_type == "potential": + ax.set_title("Magnetic potential") + elif self._object_type == "complex": + ax.set_title("Magnetic phase") - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + if plot_probe or plot_fourier_probe: # Probe ax = fig.add_subplot(spec[0, 2]) if plot_fourier_probe: @@ -1883,55 +1883,6 @@ def _visualize_last_iteration( ax_cb = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) - else: - # Object_e - ax = fig.add_subplot(spec[0, 0]) - im = ax.imshow( - obj[0], - extent=extent, - cmap=cmap_e, - vmin=vmin_e, - vmax=vmax_e, - **kwargs, - ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - - if self._object_type == "potential": - ax.set_title("Electrostatic potential") - elif self._object_type == "complex": - ax.set_title("Electrostatic phase") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - # Object_e - ax = fig.add_subplot(spec[0, 1]) - im = ax.imshow( - obj[1], - extent=extent, - cmap=cmap_m, - vmin=vmin_m, - vmax=vmax_m, - **kwargs, - ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - - if self._object_type == "potential": - ax.set_title("Magnetic potential") - elif self._object_type == "complex": - ax.set_title("Magnetic phase") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - if plot_convergence and hasattr(self, "error_iterations"): errors = np.array(self.error_iterations) From 6a48ab0cda125e400fd4386ad66cb1f3a91dd6c2 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Fri, 31 May 2024 11:14:21 -0400 Subject: [PATCH 45/64] small bug in dp_mask for gpu --- py4DSTEM/process/phase/phase_base_class.py | 1 + 1 file changed, 1 insertion(+) diff --git a/py4DSTEM/process/phase/phase_base_class.py b/py4DSTEM/process/phase/phase_base_class.py index 742013f1e..27476cb43 100644 --- a/py4DSTEM/process/phase/phase_base_class.py +++ b/py4DSTEM/process/phase/phase_base_class.py @@ -686,6 +686,7 @@ def _calculate_intensities_center_of_mass( # calculate CoM if dp_mask is not None: + dp_mask = copy_to_device(dp_mask, device) intensities_mask = intensities * dp_mask else: intensities_mask = intensities From d99138010b7bca543e9e60acbbab469e2336a93c Mon Sep 17 00:00:00 2001 From: Yael_Tsarfati Date: Wed, 5 Jun 2024 14:43:18 -0700 Subject: [PATCH 46/64] plot radial peaks v lines added --- py4DSTEM/process/polar/polar_peaks.py | 28 +++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/py4DSTEM/process/polar/polar_peaks.py b/py4DSTEM/process/polar/polar_peaks.py index 535ae7143..73232fa8e 100644 --- a/py4DSTEM/process/polar/polar_peaks.py +++ b/py4DSTEM/process/polar/polar_peaks.py @@ -660,11 +660,28 @@ def plot_radial_peaks( qstep=None, label_y_axis=False, figsize=(8, 4), + v_lines=None, returnfig=False, ): """ Calculate and plot the total peak signal as a function of the radial coordinate. + q_pixel_units + If True, plot in reciprocal units instead of pixels. + qmin + The minimum q for plotting. + qmax + The maximum q for plotting. + qstep + The bin width. + label_y_axis + If True, label y axis. + figsize + Plot size. + v_lines: tuple + x coordinates for plotting vertical lines. + returnfig + If True, returns figure. """ # Get all peak data @@ -741,6 +758,17 @@ def plot_radial_peaks( if not label_y_axis: ax.tick_params(left=False, labelleft=False) + plt.tight_layout() + + if v_lines is not None: + y_min, y_max = ax.get_ylim() + + if np.isscalar(v_lines): + ax.vlines(v_lines, y_min, y_max, color="g") + else: + for a0 in range(len(v_lines)): + ax.vlines(v_lines[a0], y_min, y_max, color="g") + if returnfig: return fig, ax From 9973e208b2e3957f25166b1af1635329eb5b6330 Mon Sep 17 00:00:00 2001 From: Yael_Tsarfati Date: Wed, 5 Jun 2024 14:44:52 -0700 Subject: [PATCH 47/64] Typo --- py4DSTEM/process/polar/polar_peaks.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/py4DSTEM/process/polar/polar_peaks.py b/py4DSTEM/process/polar/polar_peaks.py index 73232fa8e..cdf3cbca6 100644 --- a/py4DSTEM/process/polar/polar_peaks.py +++ b/py4DSTEM/process/polar/polar_peaks.py @@ -758,8 +758,6 @@ def plot_radial_peaks( if not label_y_axis: ax.tick_params(left=False, labelleft=False) - plt.tight_layout() - if v_lines is not None: y_min, y_max = ax.get_ylim() From 4a183dfb2c7bd868e2b060f73e71178d21d7309c Mon Sep 17 00:00:00 2001 From: Colin Date: Fri, 21 Jun 2024 13:11:15 -0700 Subject: [PATCH 48/64] Changing default settings, fiddling with plots --- py4DSTEM/process/diffraction/crystal.py | 2 +- py4DSTEM/process/diffraction/crystal_phase.py | 219 ++++++++++++++---- py4DSTEM/process/diffraction/crystal_viz.py | 44 +++- 3 files changed, 213 insertions(+), 52 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index 9694111c9..ed4f9ca6b 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -1368,7 +1368,7 @@ def excitation_errors( g, foil_normal=None, precession_angle_degrees=None, - precession_steps=180, + precession_steps=72, ): """ Calculate the excitation errors, assuming k0 = [0, 0, -1/lambda]. diff --git a/py4DSTEM/process/diffraction/crystal_phase.py b/py4DSTEM/process/diffraction/crystal_phase.py index f6d916a95..c18865071 100644 --- a/py4DSTEM/process/diffraction/crystal_phase.py +++ b/py4DSTEM/process/diffraction/crystal_phase.py @@ -78,16 +78,19 @@ def quantify_single_pattern( pointlistarray: PointListArray, xy_position = (0,0), corr_kernel_size = 0.04, - sigma_excitation_error = 0.02, - power_experiment = 0.25, - power_calculated = 0.25, - max_number_patterns = 3, + sigma_excitation_error: float = 0.02, + precession_angle_degrees = None, + power_intensity: float = 0.25, + power_intensity_experiment: float = 0.25, + k_max = None, + max_number_patterns = 2, single_phase = False, - allow_strain = True, + allow_strain = False, strain_iterations = 3, strain_max = 0.02, include_false_positives = True, weight_false_positives = 1.0, + weight_unmatched_peaks = 1.0, plot_result = True, plot_only_nonzero_phases = True, plot_unmatched_peaks = False, @@ -102,10 +105,32 @@ def quantify_single_pattern( ): """ Quantify the phase for a single diffraction pattern. + + Parameters + ---------- + + corr_kernel_size: (float) + Correlation kernel size length. The size of the overlap kernel between the + measured Bragg peaks and diffraction library Bragg peaks. [1/Angstroms] + sigma_excitation_error: (float) + The out of plane excitation error tolerance. [1/Angstroms] + precession_angle_degrees: (float) + Tilt angle of illuminaiton cone in degrees for precession electron diffraction (PED). + + power_radial: (float) + Power for scaling the correlation intensity as a function of the peak radius + power_intensity: (float) + Power for scaling the correlation intensity as a function of simulated peak intensity + power_intensity_experiment: (float): + Power for scaling the correlation intensity as a function of experimental peak intensity + k_max: (float) + Max k values included in fits, for both x and y directions. + """ + # tolerance - tol2 = 4e-4 + tol2 = 1e-6 # calibrations center = pointlistarray.calstate['center'] @@ -137,7 +162,16 @@ def quantify_single_pattern( pixel = pixel, rotate = rotate) # bragg_peaks = pointlistarray.get_pointlist(xy_position[0],xy_position[1]).copy() - keep = bragg_peaks.data["qx"]**2 + bragg_peaks.data["qy"]**2 > tol2 + if k_max is None: + keep = bragg_peaks.data["qx"]**2 + bragg_peaks.data["qy"]**2 > tol2 + else: + keep = np.logical_and.reduce(( + bragg_peaks.data["qx"]**2 + bragg_peaks.data["qy"]**2 > tol2, + np.abs(bragg_peaks.data["qx"]) < k_max, + np.abs(bragg_peaks.data["qy"]) < k_max, + )) + + # ind_center_beam = np.argmin( # bragg_peaks.data["qx"]**2 + bragg_peaks.data["qy"]**2) # mask = np.ones_like(bragg_peaks.data["qx"], dtype='bool') @@ -147,12 +181,12 @@ def quantify_single_pattern( qy = bragg_peaks.data["qy"][keep] qx0 = bragg_peaks.data["qx"][np.logical_not(keep)] qy0 = bragg_peaks.data["qy"][np.logical_not(keep)] - if power_experiment == 0: + if power_intensity_experiment == 0: intensity = np.ones_like(qx) intensity0 = np.ones_like(qx0) else: - intensity = bragg_peaks.data["intensity"][keep]**power_experiment - intensity0 = bragg_peaks.data["intensity"][np.logical_not(keep)]**power_experiment + intensity = bragg_peaks.data["intensity"][keep]**power_intensity_experiment + intensity0 = bragg_peaks.data["intensity"][np.logical_not(keep)]**power_intensity_experiment int_total = np.sum(intensity) # init basis array @@ -190,16 +224,25 @@ def quantify_single_pattern( ), ind_orientation = m, sigma_excitation_error = sigma_excitation_error, + precession_angle_degrees = precession_angle_degrees, ) - del_peak = bragg_peaks_fit.data["qx"]**2 \ - + bragg_peaks_fit.data["qy"]**2 < tol2 + if k_max is None: + del_peak = bragg_peaks_fit.data["qx"]**2 \ + + bragg_peaks_fit.data["qy"]**2 < tol2 + else: + del_peak = np.logical_or.reduce(( + bragg_peaks_fit.data["qx"]**2 \ + + bragg_peaks_fit.data["qy"]**2 < tol2, + np.abs(bragg_peaks_fit.data["qx"]) > k_max, + np.abs(bragg_peaks_fit.data["qy"]) > k_max, + )) bragg_peaks_fit.remove(del_peak) # peak intensities - if power_calculated == 0: + if power_intensity == 0: int_fit = np.ones_like(bragg_peaks_fit.data["qx"]) else: - int_fit = bragg_peaks_fit.data['intensity']**power_calculated + int_fit = bragg_peaks_fit.data['intensity']**power_intensity # Pair peaks to experiment if plot_result: @@ -321,20 +364,38 @@ def quantify_single_pattern( if single_phase: # loop through each crystal structure and determine the best fit structure, - # which can contain multiple orienations. + # which can contain multiple orientations up to max_number_patterns crystal_res = np.zeros(self.num_crystals) for a0 in range(self.num_crystals): - sub = self.crystal_identity[:,0] == a0 + inds_solve = self.crystal_identity[:,0] == a0 + search = True - phase_weights_cand, phase_residual_cand = nnls( - basis[:,sub], - obs, - ) - phase_weights[sub] = phase_weights_cand - crystal_res[a0] = phase_residual_cand + while search is True: + + basis_solve = basis[:,inds_solve] + obs_solve = obs.copy() + + if weight_unmatched_peaks > 1.0: + sub_unmatched = np.sum(basis_solve,axis=1)<1e-8 + obs_solve[sub_unmatched] *= weight_unmatched_peaks + + phase_weights_cand, phase_residual_cand = nnls( + basis_solve, + obs_solve, + ) + + if np.count_nonzero(phase_weights_cand > 0.0) <= max_number_patterns: + phase_weights[inds_solve] = phase_weights_cand + crystal_res[a0] = phase_residual_cand + search = False + else: + inds = np.where(inds_solve)[0] + inds_solve[inds[np.argmin(phase_weights_cand)]] = False ind_best_fit = np.argmin(crystal_res) + # ind_best_fit = np.argmax(phase_weights) + phase_residual = crystal_res[ind_best_fit] sub = np.logical_not( self.crystal_identity[:,0] == ind_best_fit @@ -452,19 +513,22 @@ def quantify_single_pattern( ax.scatter( qy0, qx0, - s = scale_markers_experiment * intensity0, + # s = scale_markers_experiment * intensity0, + s = scale_markers_experiment * bragg_peaks.data["intensity"][np.logical_not(keep)], marker = "o", facecolor = [0.7, 0.7, 0.7], ) ax.scatter( qy, qx, - s = scale_markers_experiment * intensity, + # s = scale_markers_experiment * intensity, + s = scale_markers_experiment * bragg_peaks.data["intensity"][keep], marker = "o", facecolor = [0.7, 0.7, 0.7], ) # legend - k_max = np.max(self.k_max) + if k_max is None: + k_max = np.max(self.k_max) dx_leg = -0.05*k_max dy_leg = 0.04*k_max text_params = { @@ -607,7 +671,7 @@ def quantify_single_pattern( # appearance ax.set_xlim((-k_max, k_max)) - ax.set_ylim((-k_max, k_max)) + ax.set_ylim((k_max, -k_max)) ax_leg.set_xlim((-0.1*k_max, 0.4*k_max)) ax_leg.set_ylim((-0.5*k_max, 0.5*k_max)) @@ -623,9 +687,14 @@ def quantify_phase( pointlistarray: PointListArray, corr_kernel_size = 0.04, sigma_excitation_error = 0.02, - power_experiment = 0.25, - power_calculated = 0.25, - max_number_patterns = 3, + + precession_angle_degrees = None, + power_intensity: float = 0.25, + power_intensity_experiment: float = 0.25, + k_max = None, + # power_experiment = 0.25, + # power_calculated = 0.25, + max_number_patterns = 2, single_phase = False, allow_strain = True, strain_iterations = 3, @@ -670,8 +739,12 @@ def quantify_phase( xy_position = (rx,ry), corr_kernel_size = corr_kernel_size, sigma_excitation_error = sigma_excitation_error, - power_experiment = power_experiment, - power_calculated = power_calculated, + + precession_angle_degrees = precession_angle_degrees, + power_intensity = power_intensity, + power_intensity_experiment = power_intensity_experiment, + k_max = k_max, + max_number_patterns = max_number_patterns, single_phase = single_phase, allow_strain = allow_strain, @@ -923,10 +996,15 @@ def plot_phase_maps( def plot_dominant_phase( self, + use_correlation_scores = False, reliability_range = (0.0,1.0), sigma = 0.0, phase_colors = None, + ticks = True, figsize = (6,6), + legend_add = True, + legend_fraction = 0.2, + print_fractions = False, ): """ Plot a combined figure showing the primary phase at each probe position. @@ -950,10 +1028,18 @@ def plot_dominant_phase( phase_corr_2nd = np.zeros(scan_shape) phase_sig = np.zeros((self.num_crystals,scan_shape[0],scan_shape[1])) - # sum up phase weights by crystal type - for a0 in range(self.num_fits): - ind = self.crystal_identity[a0,0] - phase_sig[ind] += self.phase_weights[:,:,a0] + if use_correlation_scores: + # Calculate scores from highest correlation match + for a0 in range(self.num_crystals): + phase_sig[a0] = np.maximum( + phase_sig[a0], + np.max(self.crystals[a0].orientation_map.corr,axis=2), + ) + else: + # sum up phase weights by crystal type + for a0 in range(self.num_fits): + ind = self.crystal_identity[a0,0] + phase_sig[ind] += self.phase_weights[:,:,a0] # smoothing of the outputs if sigma > 0.0: @@ -992,25 +1078,68 @@ def plot_dominant_phase( 0, 1) + # Print the total area of fraction of each phase + if print_fractions: + phase_mask = phase_scale >= 0.5 + phase_total = np.sum(phase_mask) + + print('Phase Fractions') + print('---------------') + for a0 in range(self.num_crystals): + phase_frac = np.sum((phase_map == a0) * phase_mask) / phase_total - phase_rgb = np.zeros((scan_shape[0],scan_shape[1],3)) + print(self.crystal_names[a0] + ' - ' + f'{phase_frac*100:.4f}' + '%') + + + self.phase_rgb = np.zeros((scan_shape[0],scan_shape[1],3)) for a0 in range(self.num_crystals): sub = phase_map==a0 for a1 in range(3): - phase_rgb[:,:,a1][sub] = phase_colors[a0,a1] * phase_scale[sub] + self.phase_rgb[:,:,a1][sub] = phase_colors[a0,a1] * phase_scale[sub] # normalize - # phase_rgb = np.clip( - # (phase_rgb - rel_range[0]) / (rel_range[1] - rel_range[0]), + # self.phase_rgb = np.clip( + # (self.phase_rgb - rel_range[0]) / (rel_range[1] - rel_range[0]), # 0,1) - fig,ax = plt.subplots(figsize=figsize) + fig = plt.figure(figsize=figsize) + if legend_add: + width = 1 + + ax = fig.add_axes((0,legend_fraction,1,1-legend_fraction)) + ax_leg = fig.add_axes((0,0,1,legend_fraction)) + + for a0 in range(self.num_crystals): + ax_leg.scatter( + a0*width, + 0, + s = 200, + marker = 's', + edgecolor = (0,0,0,1), + facecolor = phase_colors[a0], + ) + ax_leg.text( + a0*width+0.1, + 0, + self.crystal_names[a0], + fontsize = 16, + verticalalignment = 'center', + ) + ax_leg.axis('off') + ax_leg.set_xlim(( + width * -0.5, + width * (self.num_crystals+0.5), + )) + + else: + ax = fig.add_axes((0,0,1,1)) + ax.imshow( - phase_rgb, + self.phase_rgb, # vmin = 0, # vmax = 5, - # phase_rgb, + # self.phase_rgb, # phase_corr - phase_corr_2nd, # cmap = 'turbo', # vmin = 0, @@ -1018,6 +1147,12 @@ def plot_dominant_phase( # cmap = 'gray', ) + if ticks is False: + ax.set_xticks([]) + ax.set_yticks([]) + + + return fig,ax diff --git a/py4DSTEM/process/diffraction/crystal_viz.py b/py4DSTEM/process/diffraction/crystal_viz.py index fdb313cbb..649f26faa 100644 --- a/py4DSTEM/process/diffraction/crystal_viz.py +++ b/py4DSTEM/process/diffraction/crystal_viz.py @@ -393,7 +393,7 @@ def plot_scattering_intensity( k_step=0.001, k_broadening=0.0, k_power_scale=0.0, - int_power_scale=0.5, + int_power_scale=1.0, int_scale=1.0, remove_origin=True, bragg_peaks=None, @@ -948,12 +948,27 @@ def plot_diffraction_pattern( scale_markers: float = 500, scale_markers_compare: Optional[float] = None, power_markers: float = 1, + power_markers_compare: float = 1, + + color = (0.0, 0.0, 0.0), + color_compare = None, + facecolor = None, + facecolor_compare = (0.0, 0.7, 1.0), + edgecolor = None, + edgecolor_compare = None, + linewidth = 1, + linewidth_compare = 1, + + marker = '+', + marker_compare = 'o', + plot_range_kx_ky: Optional[Union[list, tuple, np.ndarray]] = None, add_labels: bool = True, shift_labels: float = 0.08, shift_marker: float = 0.005, min_marker_size: float = 1e-6, max_marker_size: float = 1000, + show_axes: bool = True, figsize: Union[list, tuple, np.ndarray] = (12, 6), returnfig: bool = False, input_fig_handle=None, @@ -990,7 +1005,7 @@ def plot_diffraction_pattern( marker_size = scale_markers * bragg_peaks.data["intensity"] else: marker_size = scale_markers * ( - bragg_peaks.data["intensity"] ** (power_markers / 2) + bragg_peaks.data["intensity"] ** power_markers ) # Apply marker size limits to primary plot @@ -1013,7 +1028,7 @@ def plot_diffraction_pattern( else: marker_size_compare = np.clip( scale_markers_compare - * (bragg_peaks_compare.data["intensity"] ** (power_markers / 2)), + * (bragg_peaks_compare.data["intensity"] ** power_markers), min_marker_size, max_marker_size, ) @@ -1022,19 +1037,30 @@ def plot_diffraction_pattern( bragg_peaks_compare.data["qy"], bragg_peaks_compare.data["qx"], s=marker_size_compare, - marker="o", - facecolor=[0.0, 0.7, 1.0], + marker=marker_compare, + facecolor=facecolor_compare, + edgecolor=edgecolor_compare, + color=color_compare, + linewidth=linewidth_compare, ) ax.scatter( bragg_peaks.data["qy"], bragg_peaks.data["qx"], s=marker_size, - marker="+", - facecolor="k", + marker=marker, + facecolor=facecolor, + edgecolor=edgecolor, + color=color, + linewidth=linewidth, ) - ax.set_xlabel("$q_y$ [Å$^{-1}$]") - ax.set_ylabel("$q_x$ [Å$^{-1}$]") + if show_axes: + ax.set_xlabel("$q_y$ [Å$^{-1}$]") + ax.set_ylabel("$q_x$ [Å$^{-1}$]") + else: + ax.get_xaxis().set_visible(False) + ax.get_yaxis().set_visible(False) + if plot_range_kx_ky is not None: plot_range_kx_ky = np.array(plot_range_kx_ky) From 7f3eca729c618322b24ee1f07821f061e3202487 Mon Sep 17 00:00:00 2001 From: Colin Date: Mon, 24 Jun 2024 09:01:21 -0700 Subject: [PATCH 49/64] adding robust fitting to polar --- py4DSTEM/process/polar/polar_fits.py | 43 ++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/py4DSTEM/process/polar/polar_fits.py b/py4DSTEM/process/polar/polar_fits.py index 7760726a7..ea83cc866 100644 --- a/py4DSTEM/process/polar/polar_fits.py +++ b/py4DSTEM/process/polar/polar_fits.py @@ -16,6 +16,9 @@ def fit_amorphous_ring( show_fit_mask=False, fit_all_images=False, maxfev=None, + robust=False, + robust_steps = 3, + robust_thresh = 1.0, verbose=False, plot_result=True, plot_log_scale=False, @@ -50,6 +53,13 @@ def fit_amorphous_ring( Fit the elliptic parameters to all images maxfev: int Max number of fitting evaluations for curve_fit. + robust: bool + Set to True to use robust fitting. + robust_steps: int + Number of robust fitting steps. + robust_thresh: float + Threshold for relative errors for outlier detection. Setting to 1.0 means all points beyond + one standard deviation of the median error will be excluded from the next fit. verbose: bool Print fit results plot_result: bool @@ -206,8 +216,41 @@ def fit_amorphous_ring( maxfev=maxfev, )[0] coefs[4] = np.mod(coefs[4], 2 * np.pi) + + if robust: + for a0 in range(robust_steps): + # find outliers + int_fit = amorphous_model(basis, *coefs) + int_diff = vals / int_mean - int_fit + int_diff /= np.median(np.abs(int_diff)) + sub_fit = int_diff**2 < robust_thresh**2 + + # redo fits excluding the outliers + if maxfev is None: + coefs = curve_fit( + amorphous_model, + basis[:,sub_fit], + vals[sub_fit] / int_mean, + p0=coefs, + xtol=1e-8, + bounds=(lb, ub), + )[0] + else: + coefs = curve_fit( + amorphous_model, + basis[:,sub_fit], + vals[sub_fit] / int_mean, + p0=coefs, + xtol=1e-8, + bounds=(lb, ub), + maxfev=maxfev, + )[0] + coefs[4] = np.mod(coefs[4], 2 * np.pi) + + # Scale intensity coefficients coefs[5:8] *= int_mean + # Perform the fit on each individual diffration pattern if fit_all_images: coefs_all = np.zeros((datacube.shape[0], datacube.shape[1], coefs.size)) From d52eaed4be9bac79d3036d3c9f685bd29cb0e9ce Mon Sep 17 00:00:00 2001 From: Colin Date: Wed, 26 Jun 2024 21:09:47 -0700 Subject: [PATCH 50/64] Fix for the infinite projection error --- py4DSTEM/process/diffraction/crystal.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index 8b59c3c92..9765d9d22 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -978,6 +978,7 @@ def generate_projected_potential( potential_radius_angstroms = 3.0, sigma_image_blur_angstroms = 0.1, thickness_angstroms = 100, + max_num_proj = 200, power_scale = 1.0, plot_result = False, figsize = (6,6), @@ -1006,6 +1007,9 @@ def generate_projected_potential( thickness_angstroms: float Thickness of the sample in Angstroms. Set thickness_thickness_angstroms = 0 to skip thickness projection. + max_num_proj: int + This value prevents this function from projecting a large number of unit + cells along the beam direction, which could be potentially quite slow. power_scale: float Power law scaling of potentials. Set to 2.0 to approximate Z^2 images. plot_result: bool @@ -1055,10 +1059,17 @@ def generate_projected_potential( lat_real = self.lat_real.copy() @ orientation_matrix # Determine unit cell axes to tile over, by selecting 2/3 with largest in-plane component + # inds_tile = np.argsort( + # np.linalg.norm(lat_real[:,0:2],axis=1) + # )[1:3] + # m_tile = lat_real[inds_tile,:] + + # Determine unit cell axes to tile over, by selecting 2/3 with smallest out-of-plane component inds_tile = np.argsort( - np.linalg.norm(lat_real[:,0:2],axis=1) - )[1:3] + np.abs(lat_real[:,2]) + )[0:2] m_tile = lat_real[inds_tile,:] + # Vector projected along optic axis m_proj = np.squeeze(np.delete(lat_real,inds_tile,axis=0)) @@ -1126,6 +1137,9 @@ def generate_projected_potential( thickness_proj = thickness_angstroms / m_proj[2] vec_proj = thickness_proj / pixel_size_angstroms * m_proj[:2] num_proj = (np.ceil(np.linalg.norm(vec_proj))+1).astype('int') + + num_proj = np.minimum(num_proj, max_num_proj) + x_proj = np.linspace(-0.5,0.5,num_proj)*vec_proj[0] y_proj = np.linspace(-0.5,0.5,num_proj)*vec_proj[1] From 0763f9836ff3a02e50dbbda73b74498e931fb391 Mon Sep 17 00:00:00 2001 From: Colin Date: Wed, 26 Jun 2024 21:10:40 -0700 Subject: [PATCH 51/64] cleaning up --- py4DSTEM/process/diffraction/crystal.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index 9765d9d22..81962d010 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -1058,12 +1058,6 @@ def generate_projected_potential( # Rotate unit cell into projection direction lat_real = self.lat_real.copy() @ orientation_matrix - # Determine unit cell axes to tile over, by selecting 2/3 with largest in-plane component - # inds_tile = np.argsort( - # np.linalg.norm(lat_real[:,0:2],axis=1) - # )[1:3] - # m_tile = lat_real[inds_tile,:] - # Determine unit cell axes to tile over, by selecting 2/3 with smallest out-of-plane component inds_tile = np.argsort( np.abs(lat_real[:,2]) @@ -1137,7 +1131,6 @@ def generate_projected_potential( thickness_proj = thickness_angstroms / m_proj[2] vec_proj = thickness_proj / pixel_size_angstroms * m_proj[:2] num_proj = (np.ceil(np.linalg.norm(vec_proj))+1).astype('int') - num_proj = np.minimum(num_proj, max_num_proj) x_proj = np.linspace(-0.5,0.5,num_proj)*vec_proj[0] From be4d9194ce1ea8987650d3aa62e231b046a37b08 Mon Sep 17 00:00:00 2001 From: Colin Date: Wed, 26 Jun 2024 21:11:53 -0700 Subject: [PATCH 52/64] Black formatting --- py4DSTEM/process/diffraction/crystal.py | 171 +++++++++--------- py4DSTEM/process/utils/single_atom_scatter.py | 10 +- 2 files changed, 92 insertions(+), 89 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index 81962d010..98e8c95c7 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -968,20 +968,17 @@ def generate_ring_pattern( if return_calc is True: return radii_unique, intensity_unique - - - def generate_projected_potential( self, - im_size = (256,256), - pixel_size_angstroms = 0.1, - potential_radius_angstroms = 3.0, - sigma_image_blur_angstroms = 0.1, - thickness_angstroms = 100, - max_num_proj = 200, - power_scale = 1.0, - plot_result = False, - figsize = (6,6), + im_size=(256, 256), + pixel_size_angstroms=0.1, + potential_radius_angstroms=3.0, + sigma_image_blur_angstroms=0.1, + thickness_angstroms=100, + max_num_proj=200, + power_scale=1.0, + plot_result=False, + figsize=(6, 6), orientation: Optional[Orientation] = None, ind_orientation: Optional[int] = 0, orientation_matrix: Optional[np.ndarray] = None, @@ -1005,7 +1002,7 @@ def generate_projected_potential( sigma_image_blur_angstroms: float Image blurring in Angstroms. thickness_angstroms: float - Thickness of the sample in Angstroms. + Thickness of the sample in Angstroms. Set thickness_thickness_angstroms = 0 to skip thickness projection. max_num_proj: int This value prevents this function from projecting a large number of unit @@ -1016,26 +1013,26 @@ def generate_projected_potential( Plot the projected potential image. figsize: (2,) vector giving the size of the output. - - orientation: Orientation + + orientation: Orientation An Orientation class object ind_orientation: int If input is an Orientation class object with multiple orientations, this input can be used to select a specific orientation. - orientation_matrix: array + orientation_matrix: array (3,3) orientation matrix, where columns represent projection directions. - zone_axis_lattice: array + zone_axis_lattice: array (3,) projection direction in lattice indices - proj_x_lattice: array) + proj_x_lattice: array) (3,) x-axis direction in lattice indices - zone_axis_cartesian: array + zone_axis_cartesian: array (3,) cartesian projection direction - proj_x_cartesian: array + proj_x_cartesian: array (3,) cartesian projection direction Returns -------- - im_potential: (np.array) + im_potential: (np.array) Output image of the projected potential. """ @@ -1059,82 +1056,82 @@ def generate_projected_potential( lat_real = self.lat_real.copy() @ orientation_matrix # Determine unit cell axes to tile over, by selecting 2/3 with smallest out-of-plane component - inds_tile = np.argsort( - np.abs(lat_real[:,2]) - )[0:2] - m_tile = lat_real[inds_tile,:] + inds_tile = np.argsort(np.abs(lat_real[:, 2]))[0:2] + m_tile = lat_real[inds_tile, :] # Vector projected along optic axis - m_proj = np.squeeze(np.delete(lat_real,inds_tile,axis=0)) + m_proj = np.squeeze(np.delete(lat_real, inds_tile, axis=0)) # Determine tiling range - p_corners = np.array([ - [-im_size_Ang[0]*0.5,-im_size_Ang[1]*0.5, 0.0], - [ im_size_Ang[0]*0.5,-im_size_Ang[1]*0.5, 0.0], - [-im_size_Ang[0]*0.5, im_size_Ang[1]*0.5, 0.0], - [ im_size_Ang[0]*0.5, im_size_Ang[1]*0.5, 0.0], - ]) - ab = np.linalg.lstsq( - m_tile[:,:2].T, - p_corners[:,:2].T, - rcond=None - )[0] + p_corners = np.array( + [ + [-im_size_Ang[0] * 0.5, -im_size_Ang[1] * 0.5, 0.0], + [im_size_Ang[0] * 0.5, -im_size_Ang[1] * 0.5, 0.0], + [-im_size_Ang[0] * 0.5, im_size_Ang[1] * 0.5, 0.0], + [im_size_Ang[0] * 0.5, im_size_Ang[1] * 0.5, 0.0], + ] + ) + ab = np.linalg.lstsq(m_tile[:, :2].T, p_corners[:, :2].T, rcond=None)[0] ab = np.floor(ab) - a_range = np.array((np.min(ab[0])-1,np.max(ab[0])+2)) - b_range = np.array((np.min(ab[1])-1,np.max(ab[1])+2)) + a_range = np.array((np.min(ab[0]) - 1, np.max(ab[0]) + 2)) + b_range = np.array((np.min(ab[1]) - 1, np.max(ab[1]) + 2)) # Tile unit cell a_ind, b_ind, atoms_ind = np.meshgrid( - np.arange(a_range[0],a_range[1]), - np.arange(b_range[0],b_range[1]), + np.arange(a_range[0], a_range[1]), + np.arange(b_range[0], b_range[1]), np.arange(self.positions.shape[0]), ) - abc_atoms = self.positions[atoms_ind.ravel(),:] - abc_atoms[:,inds_tile[0]] += a_ind.ravel() - abc_atoms[:,inds_tile[1]] += b_ind.ravel() + abc_atoms = self.positions[atoms_ind.ravel(), :] + abc_atoms[:, inds_tile[0]] += a_ind.ravel() + abc_atoms[:, inds_tile[1]] += b_ind.ravel() xyz_atoms_ang = abc_atoms @ lat_real atoms_ID_all = self.numbers[atoms_ind.ravel()] # Center atoms on image plane - x = xyz_atoms_ang[:,0] / pixel_size_angstroms + im_size[0]/2.0 - y = xyz_atoms_ang[:,1] / pixel_size_angstroms + im_size[1]/2.0 - atoms_del = np.logical_or.reduce(( - x <= -potential_radius_angstroms/2, - y <= -potential_radius_angstroms/2, - x >= im_size[0] + potential_radius_angstroms/2, - y >= im_size[1] + potential_radius_angstroms/2, - )) + x = xyz_atoms_ang[:, 0] / pixel_size_angstroms + im_size[0] / 2.0 + y = xyz_atoms_ang[:, 1] / pixel_size_angstroms + im_size[1] / 2.0 + atoms_del = np.logical_or.reduce( + ( + x <= -potential_radius_angstroms / 2, + y <= -potential_radius_angstroms / 2, + x >= im_size[0] + potential_radius_angstroms / 2, + y >= im_size[1] + potential_radius_angstroms / 2, + ) + ) x = np.delete(x, atoms_del) y = np.delete(y, atoms_del) atoms_ID_all = np.delete(atoms_ID_all, atoms_del) # Coordinate system for atomic projected potentials potential_radius = np.ceil(potential_radius_angstroms / pixel_size_angstroms) - R = np.arange(0.5-potential_radius,potential_radius+0.5) - R_ind = R.astype('int') - R_2D = np.sqrt(R[:,None]**2 + R[None,:]**2) + R = np.arange(0.5 - potential_radius, potential_radius + 0.5) + R_ind = R.astype("int") + R_2D = np.sqrt(R[:, None] ** 2 + R[None, :] ** 2) # Lookup table for atomic projected potentials atoms_ID = np.unique(self.numbers) - atoms_lookup = np.zeros(( - atoms_ID.shape[0], - R_2D.shape[0], - R_2D.shape[1], - )) + atoms_lookup = np.zeros( + ( + atoms_ID.shape[0], + R_2D.shape[0], + R_2D.shape[1], + ) + ) for a0 in range(atoms_ID.shape[0]): atom_sf = single_atom_scatter([atoms_ID[a0]]) - atoms_lookup[a0,:,:] = atom_sf.projected_potential(atoms_ID[a0], R_2D) + atoms_lookup[a0, :, :] = atom_sf.projected_potential(atoms_ID[a0], R_2D) atoms_lookup **= power_scale - # Thickness + # Thickness if thickness_angstroms > 0: thickness_proj = thickness_angstroms / m_proj[2] vec_proj = thickness_proj / pixel_size_angstroms * m_proj[:2] - num_proj = (np.ceil(np.linalg.norm(vec_proj))+1).astype('int') + num_proj = (np.ceil(np.linalg.norm(vec_proj)) + 1).astype("int") num_proj = np.minimum(num_proj, max_num_proj) - x_proj = np.linspace(-0.5,0.5,num_proj)*vec_proj[0] - y_proj = np.linspace(-0.5,0.5,num_proj)*vec_proj[1] + x_proj = np.linspace(-0.5, 0.5, num_proj) * vec_proj[0] + y_proj = np.linspace(-0.5, 0.5, num_proj) * vec_proj[1] # initialize potential im_potential = np.zeros(im_size) @@ -1145,8 +1142,8 @@ def generate_projected_potential( if thickness_angstroms > 0: for a1 in range(num_proj): - x_ind = np.round(x[a0]+x_proj[a1]).astype('int') + R_ind - y_ind = np.round(y[a0]+y_proj[a1]).astype('int') + R_ind + x_ind = np.round(x[a0] + x_proj[a1]).astype("int") + R_ind + y_ind = np.round(y[a0] + y_proj[a1]).astype("int") + R_ind x_sub = np.logical_and( x_ind >= 0, x_ind < im_size[0], @@ -1156,11 +1153,13 @@ def generate_projected_potential( y_ind < im_size[1], ) - im_potential[x_ind[x_sub][:,None],y_ind[y_sub][None,:]] += atoms_lookup[ind][x_sub,:][:,y_sub] - + im_potential[ + x_ind[x_sub][:, None], y_ind[y_sub][None, :] + ] += atoms_lookup[ind][x_sub, :][:, y_sub] + else: - x_ind = np.round(x[a0]).astype('int') + R_ind - y_ind = np.round(y[a0]).astype('int') + R_ind + x_ind = np.round(x[a0]).astype("int") + R_ind + y_ind = np.round(y[a0]).astype("int") + R_ind x_sub = np.logical_and( x_ind >= 0, x_ind < im_size[0], @@ -1170,7 +1169,9 @@ def generate_projected_potential( y_ind < im_size[1], ) - im_potential[x_ind[x_sub][:,None],y_ind[y_sub][None,:]] += atoms_lookup[ind][x_sub,:][:,y_sub] + im_potential[ + x_ind[x_sub][:, None], y_ind[y_sub][None, :] + ] += atoms_lookup[ind][x_sub, :][:, y_sub] if thickness_angstroms > 0: im_potential /= num_proj @@ -1181,26 +1182,28 @@ def generate_projected_potential( im_potential = gaussian_filter( im_potential, sigma_image_blur, - mode = 'nearest', - ) + mode="nearest", + ) if plot_result: # quick plotting of the result int_vals = np.sort(im_potential.ravel()) - int_range = np.array(( - int_vals[np.round(0.02*int_vals.size).astype('int')], - int_vals[np.round(0.98*int_vals.size).astype('int')], - )) + int_range = np.array( + ( + int_vals[np.round(0.02 * int_vals.size).astype("int")], + int_vals[np.round(0.98 * int_vals.size).astype("int")], + ) + ) - fig,ax = plt.subplots(figsize = figsize) + fig, ax = plt.subplots(figsize=figsize) ax.imshow( im_potential, - cmap = 'turbo', - vmin = int_range[0], - vmax = int_range[1], - ) + cmap="turbo", + vmin=int_range[0], + vmax=int_range[1], + ) ax.set_axis_off() - ax.set_aspect('equal') + ax.set_aspect("equal") return im_potential diff --git a/py4DSTEM/process/utils/single_atom_scatter.py b/py4DSTEM/process/utils/single_atom_scatter.py index 9085afd93..90397560f 100644 --- a/py4DSTEM/process/utils/single_atom_scatter.py +++ b/py4DSTEM/process/utils/single_atom_scatter.py @@ -2,6 +2,7 @@ import os from scipy.special import kn + class single_atom_scatter(object): """ This class calculates the composition averaged single atom scattering factor for a @@ -58,18 +59,17 @@ def projected_potential(self, Z, R): qe = 1.60217662e-19 # Permittivity of vacuum eps_0 = 8.85418782e-12 - # Bohr's constant + # Bohr's constant a_0 = 5.29177210903e-11 fe = np.zeros_like(R) for i in range(5): # fe += ai[i] * (2 + bi[i] * gsq) / (1 + bi[i] * gsq) ** 2 - pre = 2*np.pi/bi[i]**0.5 - fe += (ai[i] / bi[i]**1.5) * \ - (kn(0, pre * R) + R * kn(1, pre * R)) + pre = 2 * np.pi / bi[i] ** 0.5 + fe += (ai[i] / bi[i] ** 1.5) * (kn(0, pre * R) + R * kn(1, pre * R)) # kappa = (4*np.pi*eps_0) / (2*np.pi*a_0*me) - return fe * 2 * np.pi**2# / kappa + return fe * 2 * np.pi**2 # / kappa # if units == "VA": # return h**2 / (2 * np.pi * me * qe) * 1e18 * fe # elif units == "A": From b921a47b12ff004aedf0dda00fc68ee362887c59 Mon Sep 17 00:00:00 2001 From: Colin Date: Thu, 27 Jun 2024 09:08:20 -0700 Subject: [PATCH 53/64] Adding docstrings, cleaning up code --- py4DSTEM/process/diffraction/crystal.py | 12 ++-- py4DSTEM/process/diffraction/crystal_phase.py | 72 +++++++++++++++++-- 2 files changed, 73 insertions(+), 11 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index ed4f9ca6b..9aa436cbc 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -719,7 +719,11 @@ def generate_diffraction_pattern( return_orientation_matrix=False, ): """ - Generate a single diffraction pattern, return all peaks as a pointlist. + Generate a single diffraction pattern, return all peaks as a pointlist. This function performs a + kinematical calculation, with optional precession of the beam. + + TODO - switch from numerical precession to analytic (requires geometry projection). + TODO - verify projection geometry for 2D material diffraction. Parameters ---------- @@ -740,7 +744,6 @@ def generate_diffraction_pattern( cartesian projection direction proj_x_cartesian (3,) numpy.array cartesian projection direction - foil_normal 3 element foil normal - set to None to use zone_axis proj_x_axis (3,) numpy.array @@ -748,7 +751,6 @@ def generate_diffraction_pattern( accel_voltage (float) Accelerating voltage in Volts. If not specified, we check to see if crystal already has voltage specified. - sigma_excitation_error (float) sigma value for envelope applied to s_g (excitation errors) in units of inverse Angstroms tol_excitation_error_mult (float) @@ -768,8 +770,8 @@ def generate_diffraction_pattern( ---------- bragg_peaks (PointList) list of all Bragg peaks with fields [qx, qy, intensity, h, k, l] - orientation_matrix (array) - 3x3 orientation matrix (optional) + orientation_matrix (array, optional) + 3x3 orientation matrix """ diff --git a/py4DSTEM/process/diffraction/crystal_phase.py b/py4DSTEM/process/diffraction/crystal_phase.py index c18865071..cb99a99a9 100644 --- a/py4DSTEM/process/diffraction/crystal_phase.py +++ b/py4DSTEM/process/diffraction/crystal_phase.py @@ -106,9 +106,15 @@ def quantify_single_pattern( """ Quantify the phase for a single diffraction pattern. + TODO - determine the difference between false positive peaks and unmatched peaks (if any). + Parameters ---------- + pointlistarray: (PointListArray) + Full array of all calibrated experimental bragg peaks, with shape = (num_x,num_y) + xy_position: (int,int) + The (x,y) or (row,column) position to be quantified. corr_kernel_size: (float) Correlation kernel size length. The size of the overlap kernel between the measured Bragg peaks and diffraction library Bragg peaks. [1/Angstroms] @@ -116,19 +122,73 @@ def quantify_single_pattern( The out of plane excitation error tolerance. [1/Angstroms] precession_angle_degrees: (float) Tilt angle of illuminaiton cone in degrees for precession electron diffraction (PED). - - power_radial: (float) - Power for scaling the correlation intensity as a function of the peak radius power_intensity: (float) - Power for scaling the correlation intensity as a function of simulated peak intensity + Power for scaling the correlation intensity as a function of simulated peak intensity. power_intensity_experiment: (float): - Power for scaling the correlation intensity as a function of experimental peak intensity + Power for scaling the correlation intensity as a function of experimental peak intensity. k_max: (float) Max k values included in fits, for both x and y directions. + max_number_patterns: int + Max number of orientations which can be included in a match. + single_phase: bool + Set to true to force result to output only the best-fit phase (minimum intensity residual). + allow_strain: bool, + Allow the simulated diffraction patterns to be distorted to improve the matches. + strain_iterations: int + Number of pattern position refinement iterations. + strain_max: float + Maximum strain fraction allowed - this value should be low, typically a few percent (~0.02). + include_false_positives: bool + Penalize patterns which generate false positive peaks. + weight_false_positives: float + Weight strength of false positive peaks. + weight_unmatched_peaks: float + Penalize unmatched peaks. + plot_result: bool + Plot the resulting fit. + plot_only_nonzero_phases: bool + Only plot phases with phase weights > 0. + plot_unmatched_peaks: bool + Plot the false postive peaks. + plot_correlation_radius: bool + In the visualization, draw the correlation radius. + scale_markers_experiment: float + Size of experimental diffraction peak markers. + scale_markers_calculated: float + Size of the calculate diffraction peak markers. + crystal_inds_plot: tuple of ints + Which crystal index / indices to plot. + phase_colors: np.array + Color of each phase, should have shape = (num_phases, 3) + figsize: (float,float) + Size of the output figure. + verbose: bool + Print the resulting fit weights to console. + returnfig: bool + Return the figure and axis handles for the plot. + + + Returns + ------- + phase_weights: (np.array) + Estimated relative fraction of each phase for all probe positions. + shape = (num_x, num_y, num_orientations) + where num_orientations is the total number of all orientations for all phases. + phase_residual: (np.array) + Residual intensity not represented by the best fit phase weighting for all probe positions. + shape = (num_x, num_y) + phase_reliability: (np.array) + Estimated reliability of match(es) for all probe positions. + Typically calculated as the best fit score minus the second best fit. + shape = (num_x, num_y) + int_total: (np.array) + Sum of experimental peak intensities for all probe positions. + shape = (num_x, num_y) + fig,ax: (optional) + matplotlib figure and axis handles """ - # tolerance tol2 = 1e-6 From cd48bea268ac02ee29f70372ed2e19e731966690 Mon Sep 17 00:00:00 2001 From: Colin Date: Thu, 27 Jun 2024 15:53:17 -0700 Subject: [PATCH 54/64] more docstrings added --- py4DSTEM/process/diffraction/crystal_phase.py | 141 +++++++++++++++++- 1 file changed, 136 insertions(+), 5 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_phase.py b/py4DSTEM/process/diffraction/crystal_phase.py index cb99a99a9..f05d304c1 100644 --- a/py4DSTEM/process/diffraction/crystal_phase.py +++ b/py4DSTEM/process/diffraction/crystal_phase.py @@ -747,13 +747,10 @@ def quantify_phase( pointlistarray: PointListArray, corr_kernel_size = 0.04, sigma_excitation_error = 0.02, - precession_angle_degrees = None, power_intensity: float = 0.25, power_intensity_experiment: float = 0.25, k_max = None, - # power_experiment = 0.25, - # power_calculated = 0.25, max_number_patterns = 2, single_phase = False, allow_strain = True, @@ -765,6 +762,44 @@ def quantify_phase( ): """ Quantify phase of all diffraction patterns. + + Parameters + ---------- + pointlistarray: (PointListArray) + Full array of all calibrated experimental bragg peaks, with shape = (num_x,num_y) + corr_kernel_size: (float) + Correlation kernel size length. The size of the overlap kernel between the + measured Bragg peaks and diffraction library Bragg peaks. [1/Angstroms] + sigma_excitation_error: (float) + The out of plane excitation error tolerance. [1/Angstroms] + precession_angle_degrees: (float) + Tilt angle of illuminaiton cone in degrees for precession electron diffraction (PED). + power_intensity: (float) + Power for scaling the correlation intensity as a function of simulated peak intensity. + power_intensity_experiment: (float): + Power for scaling the correlation intensity as a function of experimental peak intensity. + k_max: (float) + Max k values included in fits, for both x and y directions. + max_number_patterns: int + Max number of orientations which can be included in a match. + single_phase: bool + Set to true to force result to output only the best-fit phase (minimum intensity residual). + allow_strain: bool, + Allow the simulated diffraction patterns to be distorted to improve the matches. + strain_iterations: int + Number of pattern position refinement iterations. + strain_max: float + Maximum strain fraction allowed - this value should be low, typically a few percent (~0.02). + include_false_positives: bool + Penalize patterns which generate false positive peaks. + weight_false_positives: float + Weight strength of false positive peaks. + progressbar: bool + Display progress. + + Returns + ----------- + """ # init results arrays @@ -836,6 +871,33 @@ def plot_phase_weights( ): """ Plot the individual phase weight maps and residuals. + + Parameters + ---------- + weight_range: (float, float) + Plotting weight range. + weight_normalize: bool + Normalize weights before plotting. + total_intensity_normalize: bool + Normalize the total intensity. + cmap: matplotlib.cm.cmap + Colormap to use for plots. + show_ticks: bool + Show ticks on plots. + show_axes: bool + Show axes. + layout: int + Layout type for figures. + figsize: (float,float) + Size of figure panel. + returnfig: bool + Return the figure and axes handles. + + Returns + ---------- + fig,ax: (optional) + Figure and axes handles. + """ # Normalization if required to total DF peak intensity @@ -931,6 +993,45 @@ def plot_phase_maps( ): """ Plot the individual phase weight maps and residuals. + + Parameters + ---------- + weight_threshold: float + Threshold for showing each phase. + weight_normalize: bool + Normalize weights before plotting. + total_intensity_normalize: bool + Normalize the total intensity. + plot_combine: bool + Combine all figures into a single plot. + crystal_inds_plot: (tuple of ints) + Which crystals to plot phase maps for. + phase_colors: np.array + (Nx3) shaped array giving the colors for each phase + show_ticks: bool + Show ticks on plots. + show_axes: bool + Show axes. + layout: int + Layout type for figures. + figsize: (float,float) + Size of figure panel. + return_phase_estimate: bool + Return the phase estimate array. + return_rgb_images: bool + Return the rgb images. + returnfig: bool + Return the figure and axes handles. + + Returns + ---------- + im_all: (np.array, optional) + images showing phase maps. + im_rgb, im_rgb_all: (np.array, optional) + rgb colored output images, possibly combined + fig,ax: (optional) + Figure and axes handles. + """ if phase_colors is None: @@ -1065,10 +1166,40 @@ def plot_dominant_phase( legend_add = True, legend_fraction = 0.2, print_fractions = False, + returnfig = True, ): """ Plot a combined figure showing the primary phase at each probe position. Mask by the reliability index (best match minus 2nd best match). + + Parameters + ---------- + use_correlation_scores: bool + Set to True to use correlation scores instead of reliabiltiy from intensity residuals. + reliability_range: (float, float) + Plotting intensity range + sigma: float + Smoothing in units of probe position. + phase_colors: np.array + (N,3) shaped array giving colors of all phases + ticks: bool + Show ticks on plots. + figsize: (float,float) + Size of output figure + legend_add: bool + Add legend to plot + legend_fraction: float + Fractional size of legend in plot. + print_fractions: bool + Print the estimated fraction of all phases. + returnfig: bool + Return the figure and axes handles. + + Returns + ---------- + fig,ax: (optional) + Figure and axes handles. + """ if phase_colors is None: @@ -1212,8 +1343,8 @@ def plot_dominant_phase( ax.set_yticks([]) - - return fig,ax + if returnfig: + return fig,ax # def plot_all_phase_maps(self, map_scale_values=None, index=0): From 316bded092e46e73f59186d67d4a1e6927f3cb68 Mon Sep 17 00:00:00 2001 From: Colin Date: Thu, 27 Jun 2024 15:53:45 -0700 Subject: [PATCH 55/64] Cleaning up, removing residual comments --- py4DSTEM/process/diffraction/crystal_phase.py | 293 ------------------ 1 file changed, 293 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_phase.py b/py4DSTEM/process/diffraction/crystal_phase.py index f05d304c1..750ea4670 100644 --- a/py4DSTEM/process/diffraction/crystal_phase.py +++ b/py4DSTEM/process/diffraction/crystal_phase.py @@ -1346,296 +1346,3 @@ def plot_dominant_phase( if returnfig: return fig,ax - - # def plot_all_phase_maps(self, map_scale_values=None, index=0): - # """ - # Visualize phase maps of dataset. - - # Args: - # map_scale_values (float): Value to scale correlations by - # """ - # phase_maps = [] - # if map_scale_values is None: - # map_scale_values = [1] * len(self.orientation_maps) - # corr_sum = np.sum( - # [ - # (self.orientation_maps[m].corr[:, :, index] * map_scale_values[m]) - # for m in range(len(self.orientation_maps)) - # ] - # ) - # for m in range(len(self.orientation_maps)): - # phase_maps.append(self.orientation_maps[m].corr[:, :, index] / corr_sum) - # show_image_grid(lambda i: phase_maps[i], 1, len(phase_maps), cmap="inferno") - # return - - # def plot_phase_map(self, index=0, cmap=None): - # corr_array = np.dstack( - # [maps.corr[:, :, index] for maps in self.orientation_maps] - # ) - # best_corr_score = np.max(corr_array, axis=2) - # best_match_phase = [ - # np.where(corr_array[:, :, p] == best_corr_score, True, False) - # for p in range(len(self.orientation_maps)) - # ] - - # if cmap is None: - # cm = plt.get_cmap("rainbow") - # cmap = [ - # cm(1.0 * i / len(self.orientation_maps)) - # for i in range(len(self.orientation_maps)) - # ] - - # fig, (ax) = plt.subplots(figsize=(6, 6)) - # ax.matshow( - # np.zeros((self.orientation_maps[0].num_x, self.orientation_maps[0].num_y)), - # cmap="gray", - # ) - # ax.axis("off") - - # for m in range(len(self.orientation_maps)): - # c0, c1 = (cmap[m][0] * 0.35, cmap[m][1] * 0.35, cmap[m][2] * 0.35, 1), cmap[ - # m - # ] - # cm = mpl.colors.LinearSegmentedColormap.from_list("cmap", [c0, c1], N=10) - # ax.matshow( - # np.ma.array( - # self.orientation_maps[m].corr[:, :, index], mask=best_match_phase[m] - # ), - # cmap=cm, - # ) - # plt.show() - - # return - - # Potentially introduce a way to check best match out of all orientations in phase plan and plug into model - # to quantify phase - - # def phase_plan( - # self, - # method, - # zone_axis_range: np.ndarray = np.array([[0, 1, 1], [1, 1, 1]]), - # angle_step_zone_axis: float = 2.0, - # angle_coarse_zone_axis: float = None, - # angle_refine_range: float = None, - # angle_step_in_plane: float = 2.0, - # accel_voltage: float = 300e3, - # intensity_power: float = 0.25, - # tol_peak_delete=None, - # tol_distance: float = 0.01, - # fiber_axis = None, - # fiber_angles = None, - # ): - # return - - # def quantify_phase( - # self, - # pointlistarray, - # tolerance_distance=0.08, - # method="nnls", - # intensity_power=0, - # mask_peaks=None, - # ): - # """ - # Quantification of the phase of a crystal based on the crystal instances and the pointlistarray. - - # Args: - # pointlisarray (pointlistarray): Pointlistarray to quantify phase of - # tolerance_distance (float): Distance allowed between a peak and match - # method (str): Numerical method used to quantify phase - # intensity_power (float): ... - # mask_peaks (list, optional): A pointer of which positions to mask peaks from - - # Details: - # """ - # if isinstance(pointlistarray, PointListArray): - # phase_weights = np.zeros( - # ( - # pointlistarray.shape[0], - # pointlistarray.shape[1], - # np.sum([map.num_matches for map in self.orientation_maps]), - # ) - # ) - # phase_residuals = np.zeros(pointlistarray.shape) - # for Rx, Ry in tqdmnd(pointlistarray.shape[0], pointlistarray.shape[1]): - # ( - # _, - # phase_weight, - # phase_residual, - # crystal_identity, - # ) = self.quantify_phase_pointlist( - # pointlistarray, - # position=[Rx, Ry], - # tolerance_distance=tolerance_distance, - # method=method, - # intensity_power=intensity_power, - # mask_peaks=mask_peaks, - # ) - # phase_weights[Rx, Ry, :] = phase_weight - # phase_residuals[Rx, Ry] = phase_residual - # self.phase_weights = phase_weights - # self.phase_residuals = phase_residuals - # self.crystal_identity = crystal_identity - # return - # else: - # return TypeError("pointlistarray must be of type pointlistarray.") - # return - - # def quantify_phase_pointlist( - # self, - # pointlistarray, - # position, - # method="nnls", - # tolerance_distance=0.08, - # intensity_power=0, - # mask_peaks=None, - # ): - # """ - # Args: - # pointlisarray (pointlistarray): Pointlistarray to quantify phase of - # position (tuple/list): Position of pointlist in pointlistarray - # tolerance_distance (float): Distance allowed between a peak and match - # method (str): Numerical method used to quantify phase - # intensity_power (float): ... - # mask_peaks (list, optional): A pointer of which positions to mask peaks from - - # Returns: - # pointlist_peak_intensity_matches (np.ndarray): Peak matches in the rows of array and the crystals in the columns - # phase_weights (np.ndarray): Weights of each phase - # phase_residuals (np.ndarray): Residuals - # crystal_identity (list): List of lists, where the each entry represents the position in the - # crystal and orientation match that is associated with the phase - # weights. for example, if the output was [[0,0], [0,1], [1,0], [0,1]], - # the first entry [0,0] in phase weights is associated with the first crystal - # the first match within that crystal. [0,1] is the first crystal and the - # second match within that crystal. - # """ - # # Things to add: - # # 1. Better cost for distance from peaks in pointlists - # # 2. Iterate through multiple tolerance_distance values to find best value. Cost function residuals, or something else? - - # pointlist = pointlistarray.get_pointlist(position[0], position[1]) - # pl_mask = np.where((pointlist["qx"] == 0) & (pointlist["qy"] == 0), 1, 0) - # pointlist.remove(pl_mask) - # # False Negatives (exp peak with no match in crystal instances) will appear here, already coded in - - # if intensity_power == 0: - # pl_intensities = np.ones(pointlist["intensity"].shape) - # else: - # pl_intensities = pointlist["intensity"] ** intensity_power - # # Prepare matches for modeling - # pointlist_peak_matches = [] - # crystal_identity = [] - - # for c in range(len(self.crystals)): - # for m in range(self.orientation_maps[c].num_matches): - # crystal_identity.append([c, m]) - # phase_peak_match_intensities = np.zeros((pointlist["intensity"].shape)) - # bragg_peaks_fit = self.crystals[c].generate_diffraction_pattern( - # self.orientation_maps[c].get_orientation(position[0], position[1]), - # ind_orientation=m, - # ) - # # Find the best match peak within tolerance_distance and add value in the right position - # for d in range(pointlist["qx"].shape[0]): - # distances = [] - # for p in range(bragg_peaks_fit["qx"].shape[0]): - # distances.append( - # np.sqrt( - # (pointlist["qx"][d] - bragg_peaks_fit["qx"][p]) ** 2 - # + (pointlist["qy"][d] - bragg_peaks_fit["qy"][p]) ** 2 - # ) - # ) - # ind = np.where(distances == np.min(distances))[0][0] - - # # Potentially for-loop over multiple values for 'tolerance_distance' to find best tolerance_distance value - # if distances[ind] <= tolerance_distance: - # ## Somewhere in this if statement is probably where better distances from the peak should be coded in - # if ( - # intensity_power == 0 - # ): # This could potentially be a different intensity_power arg - # phase_peak_match_intensities[d] = 1 ** ( - # (tolerance_distance - distances[ind]) - # / tolerance_distance - # ) - # else: - # phase_peak_match_intensities[d] = bragg_peaks_fit[ - # "intensity" - # ][ind] ** ( - # (tolerance_distance - distances[ind]) - # / tolerance_distance - # ) - # else: - # ## This is probably where the false positives (peaks in crystal but not in experiment) should be handled - # continue - - # pointlist_peak_matches.append(phase_peak_match_intensities) - # pointlist_peak_intensity_matches = np.dstack(pointlist_peak_matches) - # pointlist_peak_intensity_matches = ( - # pointlist_peak_intensity_matches.reshape( - # pl_intensities.shape[0], - # pointlist_peak_intensity_matches.shape[-1], - # ) - # ) - - # if len(pointlist["qx"]) > 0: - # if mask_peaks is not None: - # for i in range(len(mask_peaks)): - # if mask_peaks[i] == None: # noqa: E711 - # continue - # inds_mask = np.where( - # pointlist_peak_intensity_matches[:, mask_peaks[i]] != 0 - # )[0] - # for mask in range(len(inds_mask)): - # pointlist_peak_intensity_matches[inds_mask[mask], i] = 0 - - # if method == "nnls": - # phase_weights, phase_residuals = nnls( - # pointlist_peak_intensity_matches, pl_intensities - # ) - - # elif method == "lstsq": - # phase_weights, phase_residuals, rank, singluar_vals = lstsq( - # pointlist_peak_intensity_matches, pl_intensities, rcond=-1 - # ) - # phase_residuals = np.sum(phase_residuals) - # else: - # raise ValueError(method + " Not yet implemented. Try nnls or lstsq.") - # else: - # phase_weights = np.zeros((pointlist_peak_intensity_matches.shape[1],)) - # phase_residuals = np.NaN - # return ( - # pointlist_peak_intensity_matches, - # phase_weights, - # phase_residuals, - # crystal_identity, - # ) - - # def plot_peak_matches( - # self, - # pointlistarray, - # position, - # tolerance_distance, - # ind_orientation, - # pointlist_peak_intensity_matches, - # ): - # """ - # A method to view how the tolerance distance impacts the peak matches associated with - # the quantify_phase_pointlist method. - - # Args: - # pointlistarray, - # position, - # tolerance_distance - # pointlist_peak_intensity_matches - # """ - # pointlist = pointlistarray.get_pointlist(position[0],position[1]) - - # for m in range(pointlist_peak_intensity_matches.shape[1]): - # bragg_peaks_fit = self.crystals[m].generate_diffraction_pattern( - # self.orientation_maps[m].get_orientation(position[0], position[1]), - # ind_orientation = ind_orientation - # ) - # peak_inds = np.where(bragg_peaks_fit.data['intensity'] == pointlist_peak_intensity_matches[:,m]) - - # fig, (ax1, ax2) = plt.subplots(2,1,figsize = figsize) - # ax1 = plot_diffraction_pattern(pointlist,) - # return From 9ada2d86fb44d6e21165aff2b4499cb3083b2b23 Mon Sep 17 00:00:00 2001 From: Colin Date: Thu, 27 Jun 2024 15:54:15 -0700 Subject: [PATCH 56/64] black --- py4DSTEM/process/diffraction/crystal.py | 74 +- py4DSTEM/process/diffraction/crystal_ACOM.py | 125 +-- py4DSTEM/process/diffraction/crystal_phase.py | 946 +++++++++--------- py4DSTEM/process/diffraction/crystal_viz.py | 30 +- py4DSTEM/process/polar/polar_fits.py | 9 +- 5 files changed, 620 insertions(+), 564 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index 9aa436cbc..5cea7afb4 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -714,7 +714,7 @@ def generate_diffraction_pattern( tol_excitation_error_mult: float = 3, tol_intensity: float = 1e-5, k_max: Optional[float] = None, - precession_angle_degrees = None, + precession_angle_degrees=None, keep_qz=False, return_orientation_matrix=False, ): @@ -730,49 +730,49 @@ def generate_diffraction_pattern( orientation (Orientation) an Orientation class object - ind_orientation + ind_orientation If input is an Orientation class object with multiple orientations, this input can be used to select a specific orientation. - orientation_matrix (3,3) numpy.array + orientation_matrix (3,3) numpy.array orientation matrix, where columns represent projection directions. - zone_axis_lattice (3,) numpy.array + zone_axis_lattice (3,) numpy.array projection direction in lattice indices - proj_x_lattice (3,) numpy.array + proj_x_lattice (3,) numpy.array x-axis direction in lattice indices - zone_axis_cartesian (3,) numpy.array + zone_axis_cartesian (3,) numpy.array cartesian projection direction - proj_x_cartesian (3,) numpy.array + proj_x_cartesian (3,) numpy.array cartesian projection direction - foil_normal + foil_normal 3 element foil normal - set to None to use zone_axis - proj_x_axis (3,) numpy.array + proj_x_axis (3,) numpy.array 3 element vector defining image x axis (vertical) - accel_voltage (float) + accel_voltage (float) Accelerating voltage in Volts. If not specified, we check to see if crystal already has voltage specified. sigma_excitation_error (float) sigma value for envelope applied to s_g (excitation errors) in units of inverse Angstroms tol_excitation_error_mult (float) tolerance in units of sigma for s_g inclusion - tol_intensity (numpy float) + tol_intensity (numpy float) tolerance in intensity units for inclusion of diffraction spots - k_max (float) + k_max (float) Maximum scattering vector precession_angle_degrees (float) Precession angle for library calculation. Set to None for no precession. - keep_qz (bool) + keep_qz (bool) Flag to return out-of-plane diffraction vectors return_orientation_matrix (bool) Return the orientation matrix Returns ---------- - bragg_peaks (PointList) + bragg_peaks (PointList) list of all Bragg peaks with fields [qx, qy, intensity, h, k, l] - orientation_matrix (array, optional) - 3x3 orientation matrix - + orientation_matrix (array, optional) + 3x3 orientation matrix + """ if not (hasattr(self, "wavelength") and hasattr(self, "accel_voltage")): @@ -809,7 +809,7 @@ def generate_diffraction_pattern( if foil_normal is None: sg = self.excitation_errors( g, - precession_angle_degrees = precession_angle_degrees, + precession_angle_degrees=precession_angle_degrees, ) else: foil_normal = ( @@ -818,8 +818,8 @@ def generate_diffraction_pattern( ).ravel() sg = self.excitation_errors( g, - foil_normal = foil_normal, - precession_angle_degrees = precession_angle_degrees, + foil_normal=foil_normal, + precession_angle_degrees=precession_angle_degrees, ) # Threshold for inclusion in diffraction pattern @@ -827,7 +827,7 @@ def generate_diffraction_pattern( if precession_angle_degrees is None: keep = np.abs(sg) <= sg_max else: - keep = np.min(np.abs(sg),axis=1) <= sg_max + keep = np.min(np.abs(sg), axis=1) <= sg_max # Maximum scattering angle cutoff if k_max is not None: @@ -843,10 +843,8 @@ def generate_diffraction_pattern( ) else: g_int = self.struct_factors_int[keep] * np.mean( - np.exp( - (sg[keep] ** 2) / (-2 * sigma_excitation_error**2) - ), - axis = 1, + np.exp((sg[keep] ** 2) / (-2 * sigma_excitation_error**2)), + axis=1, ) hkl = self.hkl[:, keep] @@ -1394,29 +1392,31 @@ def excitation_errors( t = np.deg2rad(precession_angle_degrees) p = np.linspace( 0, - 2.0*np.pi, + 2.0 * np.pi, precession_steps, - endpoint = False, + endpoint=False, ) if foil_normal is None: - foil_normal = np.array((0.0,0.0,-1.0)) + foil_normal = np.array((0.0, 0.0, -1.0)) k = np.reshape( - (-1/self.wavelength) * np.vstack(( - np.sin(t)*np.cos(p), - np.sin(t)*np.sin(p), - np.cos(t)*np.ones(p.size), - )), - (3,1,p.size), + (-1 / self.wavelength) + * np.vstack( + ( + np.sin(t) * np.cos(p), + np.sin(t) * np.sin(p), + np.cos(t) * np.ones(p.size), + ) + ), + (3, 1, p.size), ) - term1 = np.sum( (g[:,:,None] + k) * foil_normal[:,None,None], axis=0) - term2 = np.sum( (g[:,:,None] + 2*k) * g[:,:,None], axis=0) + term1 = np.sum((g[:, :, None] + k) * foil_normal[:, None, None], axis=0) + term2 = np.sum((g[:, :, None] + 2 * k) * g[:, :, None], axis=0) sg = np.sqrt(term1**2 - term2) - term1 return sg - def calculate_bragg_peak_histogram( self, bragg_peaks, diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index 74e842817..ae71c20bc 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -28,10 +28,10 @@ def orientation_plan( accel_voltage: float = 300e3, corr_kernel_size: float = 0.08, sigma_excitation_error: float = 0.02, - precession_angle_degrees = None, + precession_angle_degrees=None, power_radial: float = 1.0, - power_intensity: float = 0.25, - power_intensity_experiment: float = 0.25, + power_intensity: float = 0.25, + power_intensity_experiment: float = 0.25, calculate_correlation_array=True, tol_peak_delete=None, tol_distance: float = 0.01, @@ -56,53 +56,53 @@ def orientation_plan( Setting to 'fiber' as a string will make a spherical cap around a given vector. Setting to 'auto' will use pymatgen to determine the point group symmetry of the structure and choose an appropriate zone_axis_range - angle_step_zone_axis (float): + angle_step_zone_axis (float): Approximate angular step size for zone axis search [degrees] - angle_coarse_zone_axis (float): + angle_coarse_zone_axis (float): Coarse step size for zone axis search [degrees]. Setting to None uses the same value as angle_step_zone_axis. - angle_refine_range (float): + angle_refine_range (float): Range of angles to use for zone axis refinement. Setting to None uses same value as angle_coarse_zone_axis. - angle_step_in_plane (float): + angle_step_in_plane (float): Approximate angular step size for in-plane rotation [degrees] - accel_voltage (float): + accel_voltage (float): Accelerating voltage for electrons [Volts] - corr_kernel_size (float): - Correlation kernel size length. The size of the overlap kernel between the + corr_kernel_size (float): + Correlation kernel size length. The size of the overlap kernel between the measured Bragg peaks and diffraction library Bragg peaks. [1/Angstroms] sigma_excitation_error (float): The out of plane excitation error tolerance. [1/Angstroms] precession_angle_degrees (float) Tilt angle of illuminaiton cone in degrees for precession electron diffraction (PED). - - power_radial (float): + + power_radial (float): Power for scaling the correlation intensity as a function of the peak radius - power_intensity (float): + power_intensity (float): Power for scaling the correlation intensity as a function of simulated peak intensity - power_intensity_experiment (float): + power_intensity_experiment (float): Power for scaling the correlation intensity as a function of experimental peak intensity - calculate_correlation_array (bool): + calculate_correlation_array (bool): Set to false to skip calculating the correlation array. This is useful when we only want the angular range / rotation matrices. - tol_peak_delete (float): + tol_peak_delete (float): Distance to delete peaks for multiple matches. Default is kernel_size * 0.5 - tol_distance (float): + tol_distance (float): Distance tolerance for radial shell assignment [1/Angstroms] - fiber_axis (float): + fiber_axis (float): (3,) vector specifying the fiber axis - fiber_angles (float): + fiber_angles (float): (2,) vector specifying angle range from fiber axis, and in-plane angular range [degrees] - cartesian_directions (bool): + cartesian_directions (bool): When set to true, all zone axes and projection directions are specified in Cartesian directions. - figsize (float): + figsize (float): (2,) vector giving the figure size - CUDA (bool): + CUDA (bool): Use CUDA for the Fourier operations. - progress_bar (bool): + progress_bar (bool): If false no progress bar is displayed, """ @@ -120,8 +120,10 @@ def orientation_plan( if precession_angle_degrees is None: self.orientation_precession_angle_degrees = None else: - self.orientation_precession_angle_degrees = np.asarray(precession_angle_degrees) - self.orientation_precession_angle = np.deg2rad(np.asarray(precession_angle_degrees)) + self.orientation_precession_angle_degrees = np.asarray(precession_angle_degrees) + self.orientation_precession_angle = np.deg2rad( + np.asarray(precession_angle_degrees) + ) if tol_peak_delete is None: self.orientation_tol_peak_delete = self.orientation_kernel_size * 0.5 else: @@ -775,37 +777,46 @@ def orientation_plan( # calculate intensity of spots if precession_angle_degrees is None: - Ig = np.exp(sg[keep]**2/(-2*sigma_excitation_error**2)) + Ig = np.exp(sg[keep] ** 2 / (-2 * sigma_excitation_error**2)) else: # precession extension - prec = np.cos(np.linspace(0,2*np.pi,90,endpoint=False)) - dsg = np.tan(self.orientation_precession_angle) * np.sum(g[:2,keep]**2,axis=0) - Ig = np.mean(np.exp((sg[keep,None] + dsg[:,None]*prec[None,:])**2 \ - / (-2*sigma_excitation_error**2)), axis = 1) + prec = np.cos(np.linspace(0, 2 * np.pi, 90, endpoint=False)) + dsg = np.tan(self.orientation_precession_angle) * np.sum( + g[:2, keep] ** 2, axis=0 + ) + Ig = np.mean( + np.exp( + (sg[keep, None] + dsg[:, None] * prec[None, :]) ** 2 + / (-2 * sigma_excitation_error**2) + ), + axis=1, + ) # in-plane rotation angle phi = np.arctan2(g[1, keep], g[0, keep]) phi_ind = phi / self.orientation_gamma[1] # step size of annular bins - phi_floor = np.floor(phi_ind).astype('int') + phi_floor = np.floor(phi_ind).astype("int") dphi = phi_ind - phi_floor # write intensities into orientation plan slice radial_inds = self.orientation_shell_index[keep] - self.orientation_ref[a0, radial_inds, phi_floor] += \ - (1-dphi) * \ - np.power(self.struct_factors_int[keep] * Ig, power_intensity) * \ - np.power(self.orientation_shell_radii[radial_inds], power_radial) - self.orientation_ref[a0, radial_inds, np.mod(phi_floor+1,self.orientation_in_plane_steps)] += \ - dphi * \ - np.power(self.struct_factors_int[keep] * Ig, power_intensity) * \ - np.power(self.orientation_shell_radii[radial_inds], power_radial) - + self.orientation_ref[a0, radial_inds, phi_floor] += ( + (1 - dphi) + * np.power(self.struct_factors_int[keep] * Ig, power_intensity) + * np.power(self.orientation_shell_radii[radial_inds], power_radial) + ) + self.orientation_ref[ + a0, radial_inds, np.mod(phi_floor + 1, self.orientation_in_plane_steps) + ] += ( + dphi + * np.power(self.struct_factors_int[keep] * Ig, power_intensity) + * np.power(self.orientation_shell_radii[radial_inds], power_radial) + ) # # Loop over all peaks # for a1 in np.arange(self.g_vec_all.shape[1]): # if keep[a1]: - # for a1 in np.arange(self.g_vec_all.shape[1]): # ind_radial = self.orientation_shell_index[a1] @@ -1048,21 +1059,23 @@ def match_single_pattern( self.orientation_power_intensity_experiment, ) * np.exp( - (dqr[sub, None] ** 2 - + ( - ( - np.mod( - self.orientation_gamma[None, :] - - qphi[sub, None] - + np.pi, - 2 * np.pi, + ( + dqr[sub, None] ** 2 + + ( + ( + np.mod( + self.orientation_gamma[None, :] + - qphi[sub, None] + + np.pi, + 2 * np.pi, + ) + - np.pi ) - - np.pi + * radius ) - * radius + ** 2 ) - ** 2) - / (-2*self.orientation_kernel_size**2) + / (-2 * self.orientation_kernel_size**2) ), axis=0, ) @@ -1096,7 +1109,6 @@ def match_single_pattern( # axis=0, # ) - # im_polar[ind_radial, :] = np.sum( # np.power(radius, self.orientation_power_radial) # * np.power( @@ -1127,7 +1139,6 @@ def match_single_pattern( # axis=0, # ) - # normalization # im_polar -= np.mean(im_polar) @@ -1628,8 +1639,8 @@ def match_single_pattern( bragg_peaks_fit = self.generate_diffraction_pattern( orientation, ind_orientation=match_ind, - sigma_excitation_error = self.orientation_sigma_excitation_error, - precession_angle_degrees = self.orientation_precession_angle_degrees, + sigma_excitation_error=self.orientation_sigma_excitation_error, + precession_angle_degrees=self.orientation_precession_angle_degrees, ) remove = np.zeros_like(qx, dtype="bool") diff --git a/py4DSTEM/process/diffraction/crystal_phase.py b/py4DSTEM/process/diffraction/crystal_phase.py index 750ea4670..3db718555 100644 --- a/py4DSTEM/process/diffraction/crystal_phase.py +++ b/py4DSTEM/process/diffraction/crystal_phase.py @@ -10,7 +10,6 @@ from py4DSTEM.process.diffraction.crystal_viz import plot_diffraction_pattern - class Crystal_Phase: """ A class storing multiple crystal structures, and associated diffraction data. @@ -21,9 +20,9 @@ class Crystal_Phase: def __init__( self, crystals, - crystal_names = None, - orientation_maps = None, - name = None, + crystal_names=None, + orientation_maps=None, + name=None, ): """ Args: @@ -39,7 +38,9 @@ def __init__( # List of orientation maps if orientation_maps is None: - self.orientation_maps = [crystals[ind].orientation_map for ind in range(self.num_crystals)] + self.orientation_maps = [ + crystals[ind].orientation_map for ind in range(self.num_crystals) + ] else: if len(self.crystals) != len(orientation_maps): raise ValueError( @@ -49,60 +50,64 @@ def __init__( # Names of all crystal phases if crystal_names is None: - self.crystal_names = ['crystal' + str(ind) for ind in range(self.num_crystals)] + self.crystal_names = [ + "crystal" + str(ind) for ind in range(self.num_crystals) + ] else: self.crystal_names = crystal_names # Name of the phase map if name is None: - self.name = 'phase map' + self.name = "phase map" else: self.name = name # Get some attributes from crystals self.k_max = np.zeros(self.num_crystals) - self.num_matches = np.zeros(self.num_crystals, dtype='int') - self.crystal_identity = np.zeros((0,2), dtype='int') + self.num_matches = np.zeros(self.num_crystals, dtype="int") + self.crystal_identity = np.zeros((0, 2), dtype="int") for a0 in range(self.num_crystals): self.k_max[a0] = self.crystals[a0].k_max self.num_matches[a0] = self.crystals[a0].orientation_map.num_matches for a1 in range(self.num_matches[a0]): - self.crystal_identity = np.append(self.crystal_identity,np.array((a0,a1),dtype='int')[None,:], axis=0) + self.crystal_identity = np.append( + self.crystal_identity, + np.array((a0, a1), dtype="int")[None, :], + axis=0, + ) self.num_fits = np.sum(self.num_matches) - - def quantify_single_pattern( self, pointlistarray: PointListArray, - xy_position = (0,0), - corr_kernel_size = 0.04, + xy_position=(0, 0), + corr_kernel_size=0.04, sigma_excitation_error: float = 0.02, - precession_angle_degrees = None, - power_intensity: float = 0.25, - power_intensity_experiment: float = 0.25, - k_max = None, - max_number_patterns = 2, - single_phase = False, - allow_strain = False, - strain_iterations = 3, - strain_max = 0.02, - include_false_positives = True, - weight_false_positives = 1.0, - weight_unmatched_peaks = 1.0, - plot_result = True, - plot_only_nonzero_phases = True, - plot_unmatched_peaks = False, - plot_correlation_radius = False, - scale_markers_experiment = 40, - scale_markers_calculated = 200, - crystal_inds_plot = None, - phase_colors = None, - figsize = (10,7), - verbose = True, - returnfig = False, - ): + precession_angle_degrees=None, + power_intensity: float = 0.25, + power_intensity_experiment: float = 0.25, + k_max=None, + max_number_patterns=2, + single_phase=False, + allow_strain=False, + strain_iterations=3, + strain_max=0.02, + include_false_positives=True, + weight_false_positives=1.0, + weight_unmatched_peaks=1.0, + plot_result=True, + plot_only_nonzero_phases=True, + plot_unmatched_peaks=False, + plot_correlation_radius=False, + scale_markers_experiment=40, + scale_markers_calculated=200, + crystal_inds_plot=None, + phase_colors=None, + figsize=(10, 7), + verbose=True, + returnfig=False, + ): """ Quantify the phase for a single diffraction pattern. @@ -110,21 +115,21 @@ def quantify_single_pattern( Parameters ---------- - + pointlistarray: (PointListArray) Full array of all calibrated experimental bragg peaks, with shape = (num_x,num_y) xy_position: (int,int) The (x,y) or (row,column) position to be quantified. - corr_kernel_size: (float) - Correlation kernel size length. The size of the overlap kernel between the + corr_kernel_size: (float) + Correlation kernel size length. The size of the overlap kernel between the measured Bragg peaks and diffraction library Bragg peaks. [1/Angstroms] sigma_excitation_error: (float) The out of plane excitation error tolerance. [1/Angstroms] precession_angle_degrees: (float) Tilt angle of illuminaiton cone in degrees for precession electron diffraction (PED). - power_intensity: (float) + power_intensity: (float) Power for scaling the correlation intensity as a function of simulated peak intensity. - power_intensity_experiment: (float): + power_intensity_experiment: (float): Power for scaling the correlation intensity as a function of experimental peak intensity. k_max: (float) Max k values included in fits, for both x and y directions. @@ -171,7 +176,7 @@ def quantify_single_pattern( Returns ------- phase_weights: (np.array) - Estimated relative fraction of each phase for all probe positions. + Estimated relative fraction of each phase for all probe positions. shape = (num_x, num_y, num_orientations) where num_orientations is the total number of all orientations for all phases. phase_residual: (np.array) @@ -193,45 +198,49 @@ def quantify_single_pattern( tol2 = 1e-6 # calibrations - center = pointlistarray.calstate['center'] - ellipse = pointlistarray.calstate['ellipse'] - pixel = pointlistarray.calstate['pixel'] - rotate = pointlistarray.calstate['rotate'] + center = pointlistarray.calstate["center"] + ellipse = pointlistarray.calstate["ellipse"] + pixel = pointlistarray.calstate["pixel"] + rotate = pointlistarray.calstate["rotate"] if center is False: - raise ValueError('Bragg peaks must be center calibration') + raise ValueError("Bragg peaks must be center calibration") if pixel is False: - raise ValueError('Bragg peaks must have pixel size calibration') + raise ValueError("Bragg peaks must have pixel size calibration") # TODO - potentially warn the user if ellipse / rotate calibration not available if phase_colors is None: - phase_colors = np.array(( - (1.0,0.0,0.0,1.0), - (0.0,0.8,1.0,1.0), - (0.0,0.6,0.0,1.0), - (1.0,0.0,1.0,1.0), - (0.0,0.2,1.0,1.0), - (1.0,0.8,0.0,1.0), - )) + phase_colors = np.array( + ( + (1.0, 0.0, 0.0, 1.0), + (0.0, 0.8, 1.0, 1.0), + (0.0, 0.6, 0.0, 1.0), + (1.0, 0.0, 1.0, 1.0), + (0.0, 0.2, 1.0, 1.0), + (1.0, 0.8, 0.0, 1.0), + ) + ) # Experimental values bragg_peaks = pointlistarray.get_vectors( xy_position[0], xy_position[1], - center = center, - ellipse = ellipse, - pixel = pixel, - rotate = rotate) + center=center, + ellipse=ellipse, + pixel=pixel, + rotate=rotate, + ) # bragg_peaks = pointlistarray.get_pointlist(xy_position[0],xy_position[1]).copy() if k_max is None: - keep = bragg_peaks.data["qx"]**2 + bragg_peaks.data["qy"]**2 > tol2 + keep = bragg_peaks.data["qx"] ** 2 + bragg_peaks.data["qy"] ** 2 > tol2 else: - keep = np.logical_and.reduce(( - bragg_peaks.data["qx"]**2 + bragg_peaks.data["qy"]**2 > tol2, - np.abs(bragg_peaks.data["qx"]) < k_max, - np.abs(bragg_peaks.data["qy"]) < k_max, - )) + keep = np.logical_and.reduce( + ( + bragg_peaks.data["qx"] ** 2 + bragg_peaks.data["qy"] ** 2 > tol2, + np.abs(bragg_peaks.data["qx"]) < k_max, + np.abs(bragg_peaks.data["qy"]) < k_max, + ) + ) - # ind_center_beam = np.argmin( # bragg_peaks.data["qx"]**2 + bragg_peaks.data["qy"]**2) # mask = np.ones_like(bragg_peaks.data["qx"], dtype='bool') @@ -245,9 +254,14 @@ def quantify_single_pattern( intensity = np.ones_like(qx) intensity0 = np.ones_like(qx0) else: - intensity = bragg_peaks.data["intensity"][keep]**power_intensity_experiment - intensity0 = bragg_peaks.data["intensity"][np.logical_not(keep)]**power_intensity_experiment - int_total = np.sum(intensity) + intensity = ( + bragg_peaks.data["intensity"][keep] ** power_intensity_experiment + ) + intensity0 = ( + bragg_peaks.data["intensity"][np.logical_not(keep)] + ** power_intensity_experiment + ) + int_total = np.sum(intensity) # init basis array if include_false_positives: @@ -256,9 +270,9 @@ def quantify_single_pattern( else: basis = np.zeros((intensity.shape[0], self.num_fits)) if allow_strain: - m_strains = np.zeros((self.num_fits,2,2)) - m_strains[:,0,0] = 1.0 - m_strains[:,1,1] = 1.0 + m_strains = np.zeros((self.num_fits, 2, 2)) + m_strains[:, 0, 0] = 1.0 + m_strains[:, 1, 1] = 1.0 # kernel radius squared radius_max_2 = corr_kernel_size**2 @@ -271,8 +285,8 @@ def quantify_single_pattern( # Generate point list data, match to experimental peaks for a0 in range(self.num_fits): - c = self.crystal_identity[a0,0] - m = self.crystal_identity[a0,1] + c = self.crystal_identity[a0, 0] + m = self.crystal_identity[a0, 1] # for c in range(self.num_crystals): # for m in range(self.num_matches[c]): # ind_match += 1 @@ -282,40 +296,46 @@ def quantify_single_pattern( self.crystals[c].orientation_map.get_orientation( xy_position[0], xy_position[1] ), - ind_orientation = m, - sigma_excitation_error = sigma_excitation_error, - precession_angle_degrees = precession_angle_degrees, + ind_orientation=m, + sigma_excitation_error=sigma_excitation_error, + precession_angle_degrees=precession_angle_degrees, ) if k_max is None: - del_peak = bragg_peaks_fit.data["qx"]**2 \ - + bragg_peaks_fit.data["qy"]**2 < tol2 + del_peak = ( + bragg_peaks_fit.data["qx"] ** 2 + bragg_peaks_fit.data["qy"] ** 2 + < tol2 + ) else: - del_peak = np.logical_or.reduce(( - bragg_peaks_fit.data["qx"]**2 \ - + bragg_peaks_fit.data["qy"]**2 < tol2, - np.abs(bragg_peaks_fit.data["qx"]) > k_max, - np.abs(bragg_peaks_fit.data["qy"]) > k_max, - )) + del_peak = np.logical_or.reduce( + ( + bragg_peaks_fit.data["qx"] ** 2 + + bragg_peaks_fit.data["qy"] ** 2 + < tol2, + np.abs(bragg_peaks_fit.data["qx"]) > k_max, + np.abs(bragg_peaks_fit.data["qy"]) > k_max, + ) + ) bragg_peaks_fit.remove(del_peak) # peak intensities if power_intensity == 0: int_fit = np.ones_like(bragg_peaks_fit.data["qx"]) else: - int_fit = bragg_peaks_fit.data['intensity']**power_intensity - + int_fit = bragg_peaks_fit.data["intensity"] ** power_intensity + # Pair peaks to experiment if plot_result: - matches = np.zeros((bragg_peaks_fit.data.shape[0]),dtype='bool') + matches = np.zeros((bragg_peaks_fit.data.shape[0]), dtype="bool") if allow_strain: for a1 in range(strain_iterations): # Initial peak pairing to find best-fit strain distortion - pair_sub = np.zeros(bragg_peaks_fit.data.shape[0],dtype='bool') - pair_inds = np.zeros(bragg_peaks_fit.data.shape[0],dtype='int') + pair_sub = np.zeros(bragg_peaks_fit.data.shape[0], dtype="bool") + pair_inds = np.zeros(bragg_peaks_fit.data.shape[0], dtype="int") for a1 in range(bragg_peaks_fit.data.shape[0]): - dist2 = (bragg_peaks_fit.data['qx'][a1] - qx)**2 \ - + (bragg_peaks_fit.data['qy'][a1] - qy)**2 + dist2 = (bragg_peaks_fit.data["qx"][a1] - qx) ** 2 + ( + bragg_peaks_fit.data["qy"][a1] - qy + ) ** 2 ind_min = np.argmin(dist2) val_min = dist2[ind_min] @@ -327,19 +347,32 @@ def quantify_single_pattern( # requires at least 4 peak pairs if np.sum(pair_sub) >= 4: # pair_obs = bragg_peaks_fit.data[['qx','qy']][pair_sub] - pair_basis = np.vstack(( - bragg_peaks_fit.data['qx'][pair_sub], - bragg_peaks_fit.data['qy'][pair_sub], - )).T - pair_obs = np.vstack(( - qx[pair_inds[pair_sub]], - qy[pair_inds[pair_sub]], - )).T + pair_basis = np.vstack( + ( + bragg_peaks_fit.data["qx"][pair_sub], + bragg_peaks_fit.data["qy"][pair_sub], + ) + ).T + pair_obs = np.vstack( + ( + qx[pair_inds[pair_sub]], + qy[pair_inds[pair_sub]], + ) + ).T # weights dists = np.sqrt( - (bragg_peaks_fit.data['qx'][pair_sub] - qx[pair_inds[pair_sub]])**2 + \ - (bragg_peaks_fit.data['qx'][pair_sub] - qx[pair_inds[pair_sub]])**2) + ( + bragg_peaks_fit.data["qx"][pair_sub] + - qx[pair_inds[pair_sub]] + ) + ** 2 + + ( + bragg_peaks_fit.data["qx"][pair_sub] + - qx[pair_inds[pair_sub]] + ) + ** 2 + ) weights = np.sqrt( int_fit[pair_sub] * intensity[pair_inds[pair_sub]] ) * (1 - dists / corr_kernel_size) @@ -347,9 +380,9 @@ def quantify_single_pattern( # strain tensor m_strain = np.linalg.lstsq( - pair_basis * weights[:,None], - pair_obs * weights[:,None], - rcond = None, + pair_basis * weights[:, None], + pair_obs * weights[:, None], + rcond=None, )[0] # Clamp strains to be within the user-specified limit @@ -361,34 +394,42 @@ def quantify_single_pattern( m_strains[a0] *= m_strain # Transformed peak positions - qx_copy = bragg_peaks_fit.data['qx'] - qy_copy = bragg_peaks_fit.data['qy'] - bragg_peaks_fit.data['qx'] = qx_copy*m_strain[0,0] + qy_copy*m_strain[1,0] - bragg_peaks_fit.data['qy'] = qx_copy*m_strain[0,1] + qy_copy*m_strain[1,1] + qx_copy = bragg_peaks_fit.data["qx"] + qy_copy = bragg_peaks_fit.data["qy"] + bragg_peaks_fit.data["qx"] = ( + qx_copy * m_strain[0, 0] + qy_copy * m_strain[1, 0] + ) + bragg_peaks_fit.data["qy"] = ( + qx_copy * m_strain[0, 1] + qy_copy * m_strain[1, 1] + ) # Loop over all peaks, pair experiment to library for a1 in range(bragg_peaks_fit.data.shape[0]): - dist2 = (bragg_peaks_fit.data['qx'][a1] - qx)**2 \ - + (bragg_peaks_fit.data['qy'][a1] - qy)**2 + dist2 = (bragg_peaks_fit.data["qx"][a1] - qx) ** 2 + ( + bragg_peaks_fit.data["qy"][a1] - qy + ) ** 2 ind_min = np.argmin(dist2) val_min = dist2[ind_min] if include_false_positives: - weight = np.clip(1 - np.sqrt(dist2[ind_min]) / corr_kernel_size,0,1) - basis[ind_min,a0] = int_fit[a1] * weight - unpaired_peaks.append([ - a0, - int_fit[a1] * (1 - weight), - ]) + weight = np.clip( + 1 - np.sqrt(dist2[ind_min]) / corr_kernel_size, 0, 1 + ) + basis[ind_min, a0] = int_fit[a1] * weight + unpaired_peaks.append( + [ + a0, + int_fit[a1] * (1 - weight), + ] + ) if weight > 1e-8 and plot_result: matches[a1] = True else: if val_min < radius_max_2: - basis[ind_min,a0] = int_fit[a1] + basis[ind_min, a0] = int_fit[a1] if plot_result: matches[a1] = True - # if val_min < radius_max_2: # # weight = 1 - np.sqrt(dist2[ind_min]) / corr_kernel_size # # weight = 1 + corr_distance_scale * \ @@ -402,22 +443,22 @@ def quantify_single_pattern( # unpaired_peaks.append([a0,int_fit[a1]]) if plot_result: - library_peaks.append(bragg_peaks_fit) + library_peaks.append(bragg_peaks_fit) library_int.append(int_fit) library_matches.append(matches) # If needed, augment basis and observations with false positives if include_false_positives: - basis_aug = np.zeros((len(unpaired_peaks),self.num_fits)) + basis_aug = np.zeros((len(unpaired_peaks), self.num_fits)) for a0 in range(len(unpaired_peaks)): - basis_aug[a0,unpaired_peaks[a0][0]] = unpaired_peaks[a0][1] + basis_aug[a0, unpaired_peaks[a0][0]] = unpaired_peaks[a0][1] basis = np.vstack((basis, basis_aug * weight_false_positives)) obs = np.hstack((intensity, np.zeros(len(unpaired_peaks)))) else: obs = intensity - + # Solve for phase weight coefficients try: phase_weights = np.zeros(self.num_fits) @@ -428,16 +469,16 @@ def quantify_single_pattern( crystal_res = np.zeros(self.num_crystals) for a0 in range(self.num_crystals): - inds_solve = self.crystal_identity[:,0] == a0 + inds_solve = self.crystal_identity[:, 0] == a0 search = True while search is True: - basis_solve = basis[:,inds_solve] + basis_solve = basis[:, inds_solve] obs_solve = obs.copy() if weight_unmatched_peaks > 1.0: - sub_unmatched = np.sum(basis_solve,axis=1)<1e-8 + sub_unmatched = np.sum(basis_solve, axis=1) < 1e-8 obs_solve[sub_unmatched] *= weight_unmatched_peaks phase_weights_cand, phase_residual_cand = nnls( @@ -445,7 +486,10 @@ def quantify_single_pattern( obs_solve, ) - if np.count_nonzero(phase_weights_cand > 0.0) <= max_number_patterns: + if ( + np.count_nonzero(phase_weights_cand > 0.0) + <= max_number_patterns + ): phase_weights[inds_solve] = phase_weights_cand crystal_res[a0] = phase_residual_cand search = False @@ -457,9 +501,7 @@ def quantify_single_pattern( # ind_best_fit = np.argmax(phase_weights) phase_residual = crystal_res[ind_best_fit] - sub = np.logical_not( - self.crystal_identity[:,0] == ind_best_fit - ) + sub = np.logical_not(self.crystal_identity[:, 0] == ind_best_fit) phase_weights[sub] = 0.0 # Estimate reliability as difference between best fit and 2nd best fit @@ -468,15 +510,18 @@ def quantify_single_pattern( else: # Allow all crystals and orientation matches in the pattern - inds_solve = np.ones(self.num_fits,dtype='bool') + inds_solve = np.ones(self.num_fits, dtype="bool") search = True while search is True: phase_weights_cand, phase_residual_cand = nnls( - basis[:,inds_solve], + basis[:, inds_solve], obs, ) - if np.count_nonzero(phase_weights_cand > 0.0) <= max_number_patterns: + if ( + np.count_nonzero(phase_weights_cand > 0.0) + <= max_number_patterns + ): phase_weights[inds_solve] = phase_weights_cand phase_residual = phase_residual_cand search = False @@ -484,9 +529,8 @@ def quantify_single_pattern( inds = np.where(inds_solve)[0] inds_solve[inds[np.argmin(phase_weights_cand)]] = False - # Estimate the phase reliability - inds_solve = np.ones(self.num_fits,dtype='bool') + inds_solve = np.ones(self.num_fits, dtype="bool") inds_solve[phase_weights > 1e-8] = False if np.all(inds_solve == False): @@ -495,10 +539,13 @@ def quantify_single_pattern( search = True while search is True: phase_weights_cand, phase_residual_cand = nnls( - basis[:,inds_solve], + basis[:, inds_solve], obs, ) - if np.count_nonzero(phase_weights_cand > 0.0) <= max_number_patterns: + if ( + np.count_nonzero(phase_weights_cand > 0.0) + <= max_number_patterns + ): phase_residual_2nd = phase_residual_cand search = False else: @@ -506,7 +553,7 @@ def quantify_single_pattern( inds_solve[inds[np.argmin(phase_weights_cand)]] = False phase_weights_cand, phase_residual_cand = nnls( - basis[:,inds_solve], + basis[:, inds_solve], obs, ) phase_reliability = phase_residual_2nd - phase_residual @@ -516,34 +563,25 @@ def quantify_single_pattern( phase_residual = np.sqrt(np.sum(intensity**2)) phase_reliability = 0.0 - if verbose: ind_max = np.argmax(phase_weights) # print() - print('\033[1m' + 'phase_weight or_ind name' + '\033[0m') + print("\033[1m" + "phase_weight or_ind name" + "\033[0m") # print() for a0 in range(self.num_fits): - c = self.crystal_identity[a0,0] - m = self.crystal_identity[a0,1] - line = '{:>12} {:>8} {:<12}'.format( - f'{phase_weights[a0]:.2f}', - m, - self.crystal_names[c] - ) + c = self.crystal_identity[a0, 0] + m = self.crystal_identity[a0, 1] + line = "{:>12} {:>8} {:<12}".format( + f"{phase_weights[a0]:.2f}", m, self.crystal_names[c] + ) if a0 == ind_max: - print('\033[1m' + line + '\033[0m') + print("\033[1m" + line + "\033[0m") else: print(line) - print('----------------------------') - line = '{:>12} {:>15}'.format( - f'{sum(phase_weights):.2f}', - 'fit total' - ) - print('\033[1m' + line + '\033[0m') - line = '{:>12} {:>15}'.format( - f'{phase_residual:.2f}', - 'fit residual' - ) + print("----------------------------") + line = "{:>12} {:>15}".format(f"{sum(phase_weights):.2f}", "fit total") + print("\033[1m" + line + "\033[0m") + line = "{:>12} {:>15}".format(f"{phase_residual:.2f}", "fit residual") print(line) # Plotting @@ -558,39 +596,40 @@ def quantify_single_pattern( if plot_correlation_radius: # plot the experimental radii - t = np.linspace(0,2*np.pi,91,endpoint=True) + t = np.linspace(0, 2 * np.pi, 91, endpoint=True) ct = np.cos(t) * corr_kernel_size st = np.sin(t) * corr_kernel_size for a0 in range(qx.shape[0]): ax.plot( qy[a0] + st, qx[a0] + ct, - color = 'k', - linewidth = 1, - ) + color="k", + linewidth=1, + ) # plot the experimental peaks ax.scatter( qy0, qx0, # s = scale_markers_experiment * intensity0, - s = scale_markers_experiment * bragg_peaks.data["intensity"][np.logical_not(keep)], - marker = "o", - facecolor = [0.7, 0.7, 0.7], - ) + s=scale_markers_experiment + * bragg_peaks.data["intensity"][np.logical_not(keep)], + marker="o", + facecolor=[0.7, 0.7, 0.7], + ) ax.scatter( qy, qx, # s = scale_markers_experiment * intensity, - s = scale_markers_experiment * bragg_peaks.data["intensity"][keep], - marker = "o", - facecolor = [0.7, 0.7, 0.7], - ) + s=scale_markers_experiment * bragg_peaks.data["intensity"][keep], + marker="o", + facecolor=[0.7, 0.7, 0.7], + ) # legend if k_max is None: k_max = np.max(self.k_max) - dx_leg = -0.05*k_max - dy_leg = 0.04*k_max + dx_leg = -0.05 * k_max + dy_leg = 0.04 * k_max text_params = { "va": "center", "ha": "left", @@ -601,34 +640,25 @@ def quantify_single_pattern( } if plot_correlation_radius: ax_leg.plot( - 0 + st*0.5, - -dx_leg + ct*0.5, - color = 'k', - linewidth = 1, - ) + 0 + st * 0.5, + -dx_leg + ct * 0.5, + color="k", + linewidth=1, + ) ax_leg.scatter( 0, 0, - s = 200, - marker = "o", - facecolor = [0.7, 0.7, 0.7], - ) - ax_leg.text( - dy_leg, - 0, - 'Experimental peaks', - **text_params) + s=200, + marker="o", + facecolor=[0.7, 0.7, 0.7], + ) + ax_leg.text(dy_leg, 0, "Experimental peaks", **text_params) if plot_correlation_radius: - ax_leg.text( - dy_leg, - -dx_leg, - 'Correlation radius', - **text_params) - + ax_leg.text(dy_leg, -dx_leg, "Correlation radius", **text_params) # plot calculated diffraction patterns uvals = phase_colors.copy() - uvals[:,3] = 0.3 + uvals[:, 3] = 0.3 # uvals = np.array(( # (1.0,0.0,0.0,0.2), # (0.0,0.8,1.0,0.2), @@ -637,25 +667,35 @@ def quantify_single_pattern( # (0.0,0.2,1.0,0.2), # (1.0,0.8,0.0,0.2), # )) - mvals = ['v','^','<','>','d','s',] + mvals = [ + "v", + "^", + "<", + ">", + "d", + "s", + ] count_leg = 0 for a0 in range(self.num_fits): - c = self.crystal_identity[a0,0] - m = self.crystal_identity[a0,1] + c = self.crystal_identity[a0, 0] + m = self.crystal_identity[a0, 1] - if crystal_inds_plot == None or np.min(np.abs(c - crystal_inds_plot)) == 0: + if ( + crystal_inds_plot == None + or np.min(np.abs(c - crystal_inds_plot)) == 0 + ): - qx_fit = library_peaks[a0].data['qx'] - qy_fit = library_peaks[a0].data['qy'] + qx_fit = library_peaks[a0].data["qx"] + qy_fit = library_peaks[a0].data["qy"] if allow_strain: m_strain = m_strains[a0] # Transformed peak positions qx_copy = qx_fit.copy() qy_copy = qy_fit.copy() - qx_fit = qx_copy*m_strain[0,0] + qy_copy*m_strain[1,0] - qy_fit = qx_copy*m_strain[0,1] + qy_copy*m_strain[1,1] + qx_fit = qx_copy * m_strain[0, 0] + qy_copy * m_strain[1, 0] + qy_fit = qx_copy * m_strain[0, 1] + qy_copy * m_strain[1, 1] int_fit = library_int[a0] matches_fit = library_matches[a0] @@ -666,33 +706,35 @@ def quantify_single_pattern( ax.scatter( qy_fit[matches_fit], qx_fit[matches_fit], - s = scale_markers_calculated * int_fit[matches_fit], - marker = mvals[c], - facecolor = phase_colors[c,:], - ) + s=scale_markers_calculated * int_fit[matches_fit], + marker=mvals[c], + facecolor=phase_colors[c, :], + ) if plot_unmatched_peaks: ax.scatter( qy_fit[np.logical_not(matches_fit)], qx_fit[np.logical_not(matches_fit)], - s = scale_markers_calculated * int_fit[np.logical_not(matches_fit)], - marker = mvals[c], - facecolor = phase_colors[c,:], - ) + s=scale_markers_calculated + * int_fit[np.logical_not(matches_fit)], + marker=mvals[c], + facecolor=phase_colors[c, :], + ) # legend if m == 0: ax_leg.text( dy_leg, - (count_leg+1)*dx_leg, + (count_leg + 1) * dx_leg, self.crystal_names[c], - **text_params) + **text_params, + ) ax_leg.scatter( 0, - (count_leg+1) * dx_leg, - s = 200, - marker = mvals[c], - facecolor = phase_colors[c,:], - ) + (count_leg + 1) * dx_leg, + s=200, + marker=mvals[c], + facecolor=phase_colors[c, :], + ) count_leg += 1 # else: # ax.scatter( @@ -727,14 +769,12 @@ def quantify_single_pattern( # # facecolors = (1,1,1,0.5), # ) - - # appearance ax.set_xlim((-k_max, k_max)) ax.set_ylim((k_max, -k_max)) - ax_leg.set_xlim((-0.1*k_max, 0.4*k_max)) - ax_leg.set_ylim((-0.5*k_max, 0.5*k_max)) + ax_leg.set_xlim((-0.1 * k_max, 0.4 * k_max)) + ax_leg.set_ylim((-0.5 * k_max, 0.5 * k_max)) ax_leg.set_axis_off() if returnfig: @@ -745,21 +785,21 @@ def quantify_single_pattern( def quantify_phase( self, pointlistarray: PointListArray, - corr_kernel_size = 0.04, - sigma_excitation_error = 0.02, - precession_angle_degrees = None, - power_intensity: float = 0.25, - power_intensity_experiment: float = 0.25, - k_max = None, - max_number_patterns = 2, - single_phase = False, - allow_strain = True, - strain_iterations = 3, - strain_max = 0.02, - include_false_positives = True, - weight_false_positives = 1.0, - progress_bar = True, - ): + corr_kernel_size=0.04, + sigma_excitation_error=0.02, + precession_angle_degrees=None, + power_intensity: float = 0.25, + power_intensity_experiment: float = 0.25, + k_max=None, + max_number_patterns=2, + single_phase=False, + allow_strain=True, + strain_iterations=3, + strain_max=0.02, + include_false_positives=True, + weight_false_positives=1.0, + progress_bar=True, + ): """ Quantify phase of all diffraction patterns. @@ -767,16 +807,16 @@ def quantify_phase( ---------- pointlistarray: (PointListArray) Full array of all calibrated experimental bragg peaks, with shape = (num_x,num_y) - corr_kernel_size: (float) - Correlation kernel size length. The size of the overlap kernel between the + corr_kernel_size: (float) + Correlation kernel size length. The size of the overlap kernel between the measured Bragg peaks and diffraction library Bragg peaks. [1/Angstroms] sigma_excitation_error: (float) The out of plane excitation error tolerance. [1/Angstroms] precession_angle_degrees: (float) Tilt angle of illuminaiton cone in degrees for precession electron diffraction (PED). - power_intensity: (float) + power_intensity: (float) Power for scaling the correlation intensity as a function of simulated peak intensity. - power_intensity_experiment: (float): + power_intensity_experiment: (float): Power for scaling the correlation intensity as a function of experimental peak intensity. k_max: (float) Max k values included in fits, for both x and y directions. @@ -803,72 +843,79 @@ def quantify_phase( """ # init results arrays - self.phase_weights = np.zeros(( - pointlistarray.shape[0], - pointlistarray.shape[1], - self.num_fits, - )) - self.phase_residuals = np.zeros(( - pointlistarray.shape[0], - pointlistarray.shape[1], - )) - self.phase_reliability = np.zeros(( - pointlistarray.shape[0], - pointlistarray.shape[1], - )) - self.int_total = np.zeros(( - pointlistarray.shape[0], - pointlistarray.shape[1], - )) + self.phase_weights = np.zeros( + ( + pointlistarray.shape[0], + pointlistarray.shape[1], + self.num_fits, + ) + ) + self.phase_residuals = np.zeros( + ( + pointlistarray.shape[0], + pointlistarray.shape[1], + ) + ) + self.phase_reliability = np.zeros( + ( + pointlistarray.shape[0], + pointlistarray.shape[1], + ) + ) + self.int_total = np.zeros( + ( + pointlistarray.shape[0], + pointlistarray.shape[1], + ) + ) self.single_phase = single_phase - + for rx, ry in tqdmnd( - *pointlistarray.shape, - desc="Quantifying Phase", - unit=" PointList", - disable=not progress_bar, - ): + *pointlistarray.shape, + desc="Quantifying Phase", + unit=" PointList", + disable=not progress_bar, + ): # calculate phase weights - phase_weights, phase_residual, phase_reliability, int_peaks = self.quantify_single_pattern( - pointlistarray = pointlistarray, - xy_position = (rx,ry), - corr_kernel_size = corr_kernel_size, - sigma_excitation_error = sigma_excitation_error, - - precession_angle_degrees = precession_angle_degrees, - power_intensity = power_intensity, - power_intensity_experiment = power_intensity_experiment, - k_max = k_max, - - max_number_patterns = max_number_patterns, - single_phase = single_phase, - allow_strain = allow_strain, - strain_iterations = strain_iterations, - strain_max = strain_max, - include_false_positives = include_false_positives, - weight_false_positives = weight_false_positives, - plot_result = False, - verbose = False, - returnfig = False, + phase_weights, phase_residual, phase_reliability, int_peaks = ( + self.quantify_single_pattern( + pointlistarray=pointlistarray, + xy_position=(rx, ry), + corr_kernel_size=corr_kernel_size, + sigma_excitation_error=sigma_excitation_error, + precession_angle_degrees=precession_angle_degrees, + power_intensity=power_intensity, + power_intensity_experiment=power_intensity_experiment, + k_max=k_max, + max_number_patterns=max_number_patterns, + single_phase=single_phase, + allow_strain=allow_strain, + strain_iterations=strain_iterations, + strain_max=strain_max, + include_false_positives=include_false_positives, + weight_false_positives=weight_false_positives, + plot_result=False, + verbose=False, + returnfig=False, ) - self.phase_weights[rx,ry] = phase_weights - self.phase_residuals[rx,ry] = phase_residual - self.phase_reliability[rx,ry] = phase_reliability - self.int_total[rx,ry] = int_peaks - + ) + self.phase_weights[rx, ry] = phase_weights + self.phase_residuals[rx, ry] = phase_residual + self.phase_reliability[rx, ry] = phase_reliability + self.int_total[rx, ry] = int_peaks def plot_phase_weights( self, - weight_range = (0.0,1.0), - weight_normalize = False, - total_intensity_normalize = True, - cmap = 'gray', - show_ticks = False, - show_axes = True, - layout = 0, - figsize = (6,6), - returnfig = False, - ): + weight_range=(0.0, 1.0), + weight_normalize=False, + total_intensity_normalize=True, + cmap="gray", + show_ticks=False, + show_axes=True, + layout=0, + figsize=(6, 6), + returnfig=False, + ): """ Plot the individual phase weight maps and residuals. @@ -906,43 +953,45 @@ def plot_phase_weights( if total_intensity_normalize: sub = self.int_total > 0.0 for a0 in range(self.num_fits): - phase_weights[:,:,a0][sub] /= self.int_total[sub] + phase_weights[:, :, a0][sub] /= self.int_total[sub] phase_residuals[sub] /= self.int_total[sub] # intensity range for plotting if weight_normalize: - scale = np.median(np.max(phase_weights,axis=2)) + scale = np.median(np.max(phase_weights, axis=2)) else: scale = 1 weight_range = np.array(weight_range) * scale # plotting if layout == 0: - fig,ax = plt.subplots( + fig, ax = plt.subplots( 1, self.num_crystals + 1, - figsize=(figsize[0],(self.num_fits+1)*figsize[1])) + figsize=(figsize[0], (self.num_fits + 1) * figsize[1]), + ) elif layout == 1: - fig,ax = plt.subplots( + fig, ax = plt.subplots( self.num_crystals + 1, 1, - figsize=(figsize[0],(self.num_fits+1)*figsize[1])) + figsize=(figsize[0], (self.num_fits + 1) * figsize[1]), + ) for a0 in range(self.num_crystals): - sub = self.crystal_identity[:,0] == a0 - im = np.sum(phase_weights[:,:,sub],axis=2) + sub = self.crystal_identity[:, 0] == a0 + im = np.sum(phase_weights[:, :, sub], axis=2) im = np.clip( - (im - weight_range[0]) / (weight_range[1] - weight_range[0]), - 0,1) + (im - weight_range[0]) / (weight_range[1] - weight_range[0]), 0, 1 + ) ax[a0].imshow( im, - vmin = 0, - vmax = 1, - cmap = cmap, + vmin=0, + vmax=1, + cmap=cmap, ) ax[a0].set_title( self.crystal_names[a0], - fontsize = 16, + fontsize=16, ) if not show_ticks: ax[a0].set_xticks([]) @@ -952,18 +1001,19 @@ def plot_phase_weights( # plot residuals im = np.clip( - (phase_residuals - weight_range[0]) \ - / (weight_range[1] - weight_range[0]), - 0,1) + (phase_residuals - weight_range[0]) / (weight_range[1] - weight_range[0]), + 0, + 1, + ) ax[self.num_crystals].imshow( im, - vmin = 0, - vmax = 1, - cmap = cmap, + vmin=0, + vmax=1, + cmap=cmap, ) ax[self.num_crystals].set_title( - 'Residuals', - fontsize = 16, + "Residuals", + fontsize=16, ) if not show_ticks: ax[self.num_crystals].set_xticks([]) @@ -974,23 +1024,22 @@ def plot_phase_weights( if returnfig: return fig, ax - def plot_phase_maps( self, - weight_threshold = 0.5, - weight_normalize = True, - total_intensity_normalize = True, - plot_combine = False, - crystal_inds_plot = None, - phase_colors = None, - show_ticks = False, - show_axes = True, - layout = 0, - figsize = (6,6), - return_phase_estimate = False, - return_rgb_images = False, - returnfig = False, - ): + weight_threshold=0.5, + weight_normalize=True, + total_intensity_normalize=True, + plot_combine=False, + crystal_inds_plot=None, + phase_colors=None, + show_ticks=False, + show_axes=True, + layout=0, + figsize=(6, 6), + return_phase_estimate=False, + return_rgb_images=False, + returnfig=False, + ): """ Plot the individual phase weight maps and residuals. @@ -1035,69 +1084,77 @@ def plot_phase_maps( """ if phase_colors is None: - phase_colors = np.array(( - (1.0,0.0,0.0), - (0.0,0.8,1.0), - (0.0,0.8,0.0), - (1.0,0.0,1.0), - (0.0,0.4,1.0), - (1.0,0.8,0.0), - )) + phase_colors = np.array( + ( + (1.0, 0.0, 0.0), + (0.0, 0.8, 1.0), + (0.0, 0.8, 0.0), + (1.0, 0.0, 1.0), + (0.0, 0.4, 1.0), + (1.0, 0.8, 0.0), + ) + ) phase_weights = self.phase_weights.copy() if total_intensity_normalize: sub = self.int_total > 0.0 for a0 in range(self.num_fits): - phase_weights[:,:,a0][sub] /= self.int_total[sub] + phase_weights[:, :, a0][sub] /= self.int_total[sub] # intensity range for plotting if weight_normalize: - scale = np.median(np.max(phase_weights,axis=2)) + scale = np.median(np.max(phase_weights, axis=2)) else: scale = 1 weight_threshold = weight_threshold * scale # init - im_all = np.zeros(( - self.num_crystals, - self.phase_weights.shape[0], - self.phase_weights.shape[1])) - im_rgb_all = np.zeros(( - self.num_crystals, - self.phase_weights.shape[0], - self.phase_weights.shape[1], - 3)) + im_all = np.zeros( + ( + self.num_crystals, + self.phase_weights.shape[0], + self.phase_weights.shape[1], + ) + ) + im_rgb_all = np.zeros( + ( + self.num_crystals, + self.phase_weights.shape[0], + self.phase_weights.shape[1], + 3, + ) + ) # phase weights over threshold for a0 in range(self.num_crystals): - sub = self.crystal_identity[:,0] == a0 - im = np.sum(phase_weights[:,:,sub],axis=2) + sub = self.crystal_identity[:, 0] == a0 + im = np.sum(phase_weights[:, :, sub], axis=2) im_all[a0] = np.maximum(im - weight_threshold, 0) # estimate compositions - im_sum = np.sum(im_all, axis = 0) + im_sum = np.sum(im_all, axis=0) sub = im_sum > 0.0 for a0 in range(self.num_crystals): im_all[a0][sub] /= im_sum[sub] for a1 in range(3): - im_rgb_all[a0,:,:,a1] = im_all[a0] * phase_colors[a0,a1] + im_rgb_all[a0, :, :, a1] = im_all[a0] * phase_colors[a0, a1] if plot_combine: if crystal_inds_plot is None: - im_rgb = np.sum(im_rgb_all, axis = 0) + im_rgb = np.sum(im_rgb_all, axis=0) else: - im_rgb = np.sum(im_rgb_all[np.array(crystal_inds_plot)], axis = 0) + im_rgb = np.sum(im_rgb_all[np.array(crystal_inds_plot)], axis=0) - im_rgb = np.clip(im_rgb,0,1) + im_rgb = np.clip(im_rgb, 0, 1) - fig,ax = plt.subplots(1,1,figsize=figsize) + fig, ax = plt.subplots(1, 1, figsize=figsize) ax.imshow( im_rgb, ) ax.set_title( - 'Phase Maps', - fontsize = 16, + "Phase Maps", + fontsize=16, ) if not show_ticks: ax.set_xticks([]) @@ -1108,24 +1165,26 @@ def plot_phase_maps( else: # plotting if layout == 0: - fig,ax = plt.subplots( + fig, ax = plt.subplots( 1, self.num_crystals, - figsize=(figsize[0],(self.num_fits+1)*figsize[1])) + figsize=(figsize[0], (self.num_fits + 1) * figsize[1]), + ) elif layout == 1: - fig,ax = plt.subplots( + fig, ax = plt.subplots( self.num_crystals, 1, - figsize=(figsize[0],(self.num_fits+1)*figsize[1])) - + figsize=(figsize[0], (self.num_fits + 1) * figsize[1]), + ) + for a0 in range(self.num_crystals): - + ax[a0].imshow( im_rgb_all[a0], ) ax[a0].set_title( self.crystal_names[a0], - fontsize = 16, + fontsize=16, ) if not show_ticks: ax[a0].set_xticks([]) @@ -1154,20 +1213,19 @@ def plot_phase_maps( if returnfig: return fig, ax - def plot_dominant_phase( self, - use_correlation_scores = False, - reliability_range = (0.0,1.0), - sigma = 0.0, - phase_colors = None, - ticks = True, - figsize = (6,6), - legend_add = True, - legend_fraction = 0.2, - print_fractions = False, - returnfig = True, - ): + use_correlation_scores=False, + reliability_range=(0.0, 1.0), + sigma=0.0, + phase_colors=None, + ticks=True, + figsize=(6, 6), + legend_add=True, + legend_fraction=0.2, + print_fractions=False, + returnfig=True, + ): """ Plot a combined figure showing the primary phase at each probe position. Mask by the reliability index (best match minus 2nd best match). @@ -1201,148 +1259,142 @@ def plot_dominant_phase( Figure and axes handles. """ - + if phase_colors is None: - phase_colors = np.array([ - [1.0,0.9,0.6], - [1,0,0], - [0,0.7,0], - [0,0.7,1], - [1,0,1], - ]) - + phase_colors = np.array( + [ + [1.0, 0.9, 0.6], + [1, 0, 0], + [0, 0.7, 0], + [0, 0.7, 1], + [1, 0, 1], + ] + ) # init arrays scan_shape = self.phase_weights.shape[:2] phase_map = np.zeros(scan_shape) phase_corr = np.zeros(scan_shape) phase_corr_2nd = np.zeros(scan_shape) - phase_sig = np.zeros((self.num_crystals,scan_shape[0],scan_shape[1])) + phase_sig = np.zeros((self.num_crystals, scan_shape[0], scan_shape[1])) if use_correlation_scores: # Calculate scores from highest correlation match for a0 in range(self.num_crystals): phase_sig[a0] = np.maximum( phase_sig[a0], - np.max(self.crystals[a0].orientation_map.corr,axis=2), + np.max(self.crystals[a0].orientation_map.corr, axis=2), ) else: # sum up phase weights by crystal type for a0 in range(self.num_fits): - ind = self.crystal_identity[a0,0] - phase_sig[ind] += self.phase_weights[:,:,a0] + ind = self.crystal_identity[a0, 0] + phase_sig[ind] += self.phase_weights[:, :, a0] # smoothing of the outputs if sigma > 0.0: for a0 in range(self.num_crystals): phase_sig[a0] = gaussian_filter( phase_sig[a0], - sigma = sigma, - mode = 'nearest', - ) + sigma=sigma, + mode="nearest", + ) # find highest correlation score for each crystal and match index for a0 in range(self.num_crystals): - sub = phase_sig[a0] > phase_corr + sub = phase_sig[a0] > phase_corr phase_map[sub] = a0 phase_corr[sub] = phase_sig[a0][sub] - + if self.single_phase: phase_scale = np.clip( - (self.phase_reliability - reliability_range[0]) / (reliability_range[1] - reliability_range[0]), + (self.phase_reliability - reliability_range[0]) + / (reliability_range[1] - reliability_range[0]), 0, - 1) + 1, + ) else: # find the second correlation score for each crystal and match index for a0 in range(self.num_crystals): corr = phase_sig[a0].copy() - corr[phase_map==a0] = 0.0 + corr[phase_map == a0] = 0.0 sub = corr > phase_corr_2nd phase_corr_2nd[sub] = corr[sub] - + # Estimate the reliability phase_rel = phase_corr - phase_corr_2nd phase_scale = np.clip( - (phase_rel - reliability_range[0]) / (reliability_range[1] - reliability_range[0]), + (phase_rel - reliability_range[0]) + / (reliability_range[1] - reliability_range[0]), 0, - 1) + 1, + ) # Print the total area of fraction of each phase if print_fractions: phase_mask = phase_scale >= 0.5 phase_total = np.sum(phase_mask) - print('Phase Fractions') - print('---------------') + print("Phase Fractions") + print("---------------") for a0 in range(self.num_crystals): phase_frac = np.sum((phase_map == a0) * phase_mask) / phase_total - print(self.crystal_names[a0] + ' - ' + f'{phase_frac*100:.4f}' + '%') - + print(self.crystal_names[a0] + " - " + f"{phase_frac*100:.4f}" + "%") - self.phase_rgb = np.zeros((scan_shape[0],scan_shape[1],3)) + self.phase_rgb = np.zeros((scan_shape[0], scan_shape[1], 3)) for a0 in range(self.num_crystals): - sub = phase_map==a0 + sub = phase_map == a0 for a1 in range(3): - self.phase_rgb[:,:,a1][sub] = phase_colors[a0,a1] * phase_scale[sub] + self.phase_rgb[:, :, a1][sub] = phase_colors[a0, a1] * phase_scale[sub] # normalize # self.phase_rgb = np.clip( # (self.phase_rgb - rel_range[0]) / (rel_range[1] - rel_range[0]), # 0,1) - - - + fig = plt.figure(figsize=figsize) if legend_add: width = 1 - ax = fig.add_axes((0,legend_fraction,1,1-legend_fraction)) - ax_leg = fig.add_axes((0,0,1,legend_fraction)) + ax = fig.add_axes((0, legend_fraction, 1, 1 - legend_fraction)) + ax_leg = fig.add_axes((0, 0, 1, legend_fraction)) for a0 in range(self.num_crystals): ax_leg.scatter( - a0*width, + a0 * width, 0, - s = 200, - marker = 's', - edgecolor = (0,0,0,1), - facecolor = phase_colors[a0], + s=200, + marker="s", + edgecolor=(0, 0, 0, 1), + facecolor=phase_colors[a0], ) ax_leg.text( - a0*width+0.1, + a0 * width + 0.1, 0, self.crystal_names[a0], - fontsize = 16, - verticalalignment = 'center', + fontsize=16, + verticalalignment="center", + ) + ax_leg.axis("off") + ax_leg.set_xlim( + ( + width * -0.5, + width * (self.num_crystals + 0.5), ) - ax_leg.axis('off') - ax_leg.set_xlim(( - width * -0.5, - width * (self.num_crystals+0.5), - )) + ) else: - ax = fig.add_axes((0,0,1,1)) + ax = fig.add_axes((0, 0, 1, 1)) ax.imshow( self.phase_rgb, - # vmin = 0, - # vmax = 5, - # self.phase_rgb, - # phase_corr - phase_corr_2nd, - # cmap = 'turbo', - # vmin = 0, - # vmax = 3, - # cmap = 'gray', ) if ticks is False: ax.set_xticks([]) ax.set_yticks([]) - if returnfig: - return fig,ax - + return fig, ax diff --git a/py4DSTEM/process/diffraction/crystal_viz.py b/py4DSTEM/process/diffraction/crystal_viz.py index 649f26faa..56d4520ee 100644 --- a/py4DSTEM/process/diffraction/crystal_viz.py +++ b/py4DSTEM/process/diffraction/crystal_viz.py @@ -848,7 +848,7 @@ def plot_orientation_plan( bragg_peaks = self.generate_diffraction_pattern( orientation_matrix=self.orientation_rotation_matrices[index_plot, :], sigma_excitation_error=self.orientation_kernel_size / 3, - precession_angle_degrees = self.orientation_precession_angle_degrees, + precession_angle_degrees=self.orientation_precession_angle_degrees, ) plot_diffraction_pattern( @@ -949,19 +949,16 @@ def plot_diffraction_pattern( scale_markers_compare: Optional[float] = None, power_markers: float = 1, power_markers_compare: float = 1, - - color = (0.0, 0.0, 0.0), - color_compare = None, - facecolor = None, - facecolor_compare = (0.0, 0.7, 1.0), - edgecolor = None, - edgecolor_compare = None, - linewidth = 1, - linewidth_compare = 1, - - marker = '+', - marker_compare = 'o', - + color=(0.0, 0.0, 0.0), + color_compare=None, + facecolor=None, + facecolor_compare=(0.0, 0.7, 1.0), + edgecolor=None, + edgecolor_compare=None, + linewidth=1, + linewidth_compare=1, + marker="+", + marker_compare="o", plot_range_kx_ky: Optional[Union[list, tuple, np.ndarray]] = None, add_labels: bool = True, shift_labels: float = 0.08, @@ -1004,9 +1001,7 @@ def plot_diffraction_pattern( if power_markers == 2: marker_size = scale_markers * bragg_peaks.data["intensity"] else: - marker_size = scale_markers * ( - bragg_peaks.data["intensity"] ** power_markers - ) + marker_size = scale_markers * (bragg_peaks.data["intensity"] ** power_markers) # Apply marker size limits to primary plot marker_size = np.clip(marker_size, min_marker_size, max_marker_size) @@ -1061,7 +1056,6 @@ def plot_diffraction_pattern( ax.get_xaxis().set_visible(False) ax.get_yaxis().set_visible(False) - if plot_range_kx_ky is not None: plot_range_kx_ky = np.array(plot_range_kx_ky) if plot_range_kx_ky.ndim == 0: diff --git a/py4DSTEM/process/polar/polar_fits.py b/py4DSTEM/process/polar/polar_fits.py index ea83cc866..34630dbc2 100644 --- a/py4DSTEM/process/polar/polar_fits.py +++ b/py4DSTEM/process/polar/polar_fits.py @@ -17,8 +17,8 @@ def fit_amorphous_ring( fit_all_images=False, maxfev=None, robust=False, - robust_steps = 3, - robust_thresh = 1.0, + robust_steps=3, + robust_thresh=1.0, verbose=False, plot_result=True, plot_log_scale=False, @@ -229,7 +229,7 @@ def fit_amorphous_ring( if maxfev is None: coefs = curve_fit( amorphous_model, - basis[:,sub_fit], + basis[:, sub_fit], vals[sub_fit] / int_mean, p0=coefs, xtol=1e-8, @@ -238,7 +238,7 @@ def fit_amorphous_ring( else: coefs = curve_fit( amorphous_model, - basis[:,sub_fit], + basis[:, sub_fit], vals[sub_fit] / int_mean, p0=coefs, xtol=1e-8, @@ -250,7 +250,6 @@ def fit_amorphous_ring( # Scale intensity coefficients coefs[5:8] *= int_mean - # Perform the fit on each individual diffration pattern if fit_all_images: coefs_all = np.zeros((datacube.shape[0], datacube.shape[1], coefs.size)) From 88994879694c43680926ea7d9bddb649edf75412 Mon Sep 17 00:00:00 2001 From: cophus Date: Wed, 17 Jul 2024 20:18:04 -0700 Subject: [PATCH 57/64] black formatting, updated variable name --- .../diffraction/WK_scattering_factors.py | 4 +- py4DSTEM/process/diffraction/crystal_ACOM.py | 1 - py4DSTEM/process/diffraction/crystal_phase.py | 68 ++++++++++--------- py4DSTEM/process/diffraction/crystal_viz.py | 3 +- py4DSTEM/process/fit/fit.py | 3 +- .../magnetic_ptychographic_tomography.py | 28 ++++---- .../process/phase/magnetic_ptychography.py | 28 ++++---- py4DSTEM/process/phase/parallax.py | 24 +++---- .../process/phase/ptychographic_methods.py | 4 +- .../process/phase/ptychographic_tomography.py | 28 ++++---- py4DSTEM/process/phase/utils.py | 4 +- py4DSTEM/process/utils/utils.py | 18 ++++- 12 files changed, 117 insertions(+), 96 deletions(-) diff --git a/py4DSTEM/process/diffraction/WK_scattering_factors.py b/py4DSTEM/process/diffraction/WK_scattering_factors.py index eb964de96..70110a977 100644 --- a/py4DSTEM/process/diffraction/WK_scattering_factors.py +++ b/py4DSTEM/process/diffraction/WK_scattering_factors.py @@ -221,7 +221,9 @@ def RI1(BI, BJ, G): ri1[sub] = np.pi * (BI * np.log((BI + BJ) / BI) + BJ * np.log((BI + BJ) / BJ)) sub = np.logical_and(eps <= 0.1, G > 0.0) - temp = 0.5 * BI**2 * np.log(BI / (BI + BJ)) + 0.5 * BJ**2 * np.log(BJ / (BI + BJ)) + temp = 0.5 * BI**2 * np.log(BI / (BI + BJ)) + 0.5 * BJ**2 * np.log( + BJ / (BI + BJ) + ) temp += 0.75 * (BI**2 + BJ**2) - 0.25 * (BI + BJ) ** 2 temp -= 0.5 * (BI - BJ) ** 2 ri1[sub] += np.pi * G[sub] ** 2 * temp diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index ae71c20bc..f7a4a20db 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -1052,7 +1052,6 @@ def match_single_pattern( sub = dqr < self.orientation_kernel_size if np.any(sub): - im_polar[ind_radial, :] = np.sum( np.power( np.maximum(intensity[sub, None], 0.0), diff --git a/py4DSTEM/process/diffraction/crystal_phase.py b/py4DSTEM/process/diffraction/crystal_phase.py index 3db718555..c161304f9 100644 --- a/py4DSTEM/process/diffraction/crystal_phase.py +++ b/py4DSTEM/process/diffraction/crystal_phase.py @@ -10,7 +10,7 @@ from py4DSTEM.process.diffraction.crystal_viz import plot_diffraction_pattern -class Crystal_Phase: +class CrystalPhase: """ A class storing multiple crystal structures, and associated diffraction data. Must be initialized after matching orientations to a pointlistarray??? @@ -194,8 +194,8 @@ def quantify_single_pattern( """ - # tolerance - tol2 = 1e-6 + # tolerance for separating the origin peak. + tolerance_origin_2 = 1e-6 # calibrations center = pointlistarray.calstate["center"] @@ -231,11 +231,15 @@ def quantify_single_pattern( ) # bragg_peaks = pointlistarray.get_pointlist(xy_position[0],xy_position[1]).copy() if k_max is None: - keep = bragg_peaks.data["qx"] ** 2 + bragg_peaks.data["qy"] ** 2 > tol2 + keep = ( + bragg_peaks.data["qx"] ** 2 + bragg_peaks.data["qy"] ** 2 + > tolerance_origin_2 + ) else: keep = np.logical_and.reduce( ( - bragg_peaks.data["qx"] ** 2 + bragg_peaks.data["qy"] ** 2 > tol2, + bragg_peaks.data["qx"] ** 2 + bragg_peaks.data["qy"] ** 2 + > tolerance_origin_2, np.abs(bragg_peaks.data["qx"]) < k_max, np.abs(bragg_peaks.data["qy"]) < k_max, ) @@ -303,14 +307,14 @@ def quantify_single_pattern( if k_max is None: del_peak = ( bragg_peaks_fit.data["qx"] ** 2 + bragg_peaks_fit.data["qy"] ** 2 - < tol2 + < tolerance_origin_2 ) else: del_peak = np.logical_or.reduce( ( bragg_peaks_fit.data["qx"] ** 2 + bragg_peaks_fit.data["qy"] ** 2 - < tol2, + < tolerance_origin_2, np.abs(bragg_peaks_fit.data["qx"]) > k_max, np.abs(bragg_peaks_fit.data["qy"]) > k_max, ) @@ -473,7 +477,6 @@ def quantify_single_pattern( search = True while search is True: - basis_solve = basis[:, inds_solve] obs_solve = obs.copy() @@ -685,7 +688,6 @@ def quantify_single_pattern( crystal_inds_plot == None or np.min(np.abs(c - crystal_inds_plot)) == 0 ): - qx_fit = library_peaks[a0].data["qx"] qy_fit = library_peaks[a0].data["qy"] @@ -701,7 +703,6 @@ def quantify_single_pattern( matches_fit = library_matches[a0] if plot_only_nonzero_phases is False or phase_weights[a0] > 0: - # if np.mod(m,2) == 0: ax.scatter( qy_fit[matches_fit], @@ -877,27 +878,30 @@ def quantify_phase( disable=not progress_bar, ): # calculate phase weights - phase_weights, phase_residual, phase_reliability, int_peaks = ( - self.quantify_single_pattern( - pointlistarray=pointlistarray, - xy_position=(rx, ry), - corr_kernel_size=corr_kernel_size, - sigma_excitation_error=sigma_excitation_error, - precession_angle_degrees=precession_angle_degrees, - power_intensity=power_intensity, - power_intensity_experiment=power_intensity_experiment, - k_max=k_max, - max_number_patterns=max_number_patterns, - single_phase=single_phase, - allow_strain=allow_strain, - strain_iterations=strain_iterations, - strain_max=strain_max, - include_false_positives=include_false_positives, - weight_false_positives=weight_false_positives, - plot_result=False, - verbose=False, - returnfig=False, - ) + ( + phase_weights, + phase_residual, + phase_reliability, + int_peaks, + ) = self.quantify_single_pattern( + pointlistarray=pointlistarray, + xy_position=(rx, ry), + corr_kernel_size=corr_kernel_size, + sigma_excitation_error=sigma_excitation_error, + precession_angle_degrees=precession_angle_degrees, + power_intensity=power_intensity, + power_intensity_experiment=power_intensity_experiment, + k_max=k_max, + max_number_patterns=max_number_patterns, + single_phase=single_phase, + allow_strain=allow_strain, + strain_iterations=strain_iterations, + strain_max=strain_max, + include_false_positives=include_false_positives, + weight_false_positives=weight_false_positives, + plot_result=False, + verbose=False, + returnfig=False, ) self.phase_weights[rx, ry] = phase_weights self.phase_residuals[rx, ry] = phase_residual @@ -1178,7 +1182,6 @@ def plot_phase_maps( ) for a0 in range(self.num_crystals): - ax[a0].imshow( im_rgb_all[a0], ) @@ -1315,7 +1318,6 @@ def plot_dominant_phase( ) else: - # find the second correlation score for each crystal and match index for a0 in range(self.num_crystals): corr = phase_sig[a0].copy() diff --git a/py4DSTEM/process/diffraction/crystal_viz.py b/py4DSTEM/process/diffraction/crystal_viz.py index 56d4520ee..1663a6198 100644 --- a/py4DSTEM/process/diffraction/crystal_viz.py +++ b/py4DSTEM/process/diffraction/crystal_viz.py @@ -454,7 +454,8 @@ def plot_scattering_intensity( int_sf_plot = calc_1D_profile( k, self.g_vec_leng, - (self.struct_factors_int**int_power_scale) * (self.g_vec_leng**k_power_scale), + (self.struct_factors_int**int_power_scale) + * (self.g_vec_leng**k_power_scale), remove_origin=True, k_broadening=k_broadening, int_scale=int_scale, diff --git a/py4DSTEM/process/fit/fit.py b/py4DSTEM/process/fit/fit.py index 5c2d56a3c..9973ff79f 100644 --- a/py4DSTEM/process/fit/fit.py +++ b/py4DSTEM/process/fit/fit.py @@ -169,7 +169,8 @@ def polar_gaussian_2D( # t2 = np.min(np.vstack([t,1-t])) t2 = np.square(t - mu_t) return ( - I0 * np.exp(-(t2 / (2 * sigma_t**2) + (q - mu_q) ** 2 / (2 * sigma_q**2))) + C + I0 * np.exp(-(t2 / (2 * sigma_t**2) + (q - mu_q) ** 2 / (2 * sigma_q**2))) + + C ) diff --git a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py index be6332c74..9ffcc6212 100644 --- a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py +++ b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py @@ -1196,20 +1196,20 @@ def reconstruct( # position correction if not fix_positions and a0 > 0: - self._positions_px_all[batch_indices] = ( - self._position_correction( - object_sliced, - vectorized_patch_indices_row, - vectorized_patch_indices_col, - shifted_probes, - overlap, - amplitudes_device, - positions_px, - positions_px_initial, - positions_step_size, - max_position_update_distance, - max_position_total_distance, - ) + self._positions_px_all[ + batch_indices + ] = self._position_correction( + object_sliced, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + overlap, + amplitudes_device, + positions_px, + positions_px_initial, + positions_step_size, + max_position_update_distance, + max_position_total_distance, ) measurement_error += batch_error diff --git a/py4DSTEM/process/phase/magnetic_ptychography.py b/py4DSTEM/process/phase/magnetic_ptychography.py index 19f306188..0949a02eb 100644 --- a/py4DSTEM/process/phase/magnetic_ptychography.py +++ b/py4DSTEM/process/phase/magnetic_ptychography.py @@ -1497,20 +1497,20 @@ def reconstruct( # position correction if not fix_positions and a0 > 0: - self._positions_px_all[batch_indices] = ( - self._position_correction( - self._object, - vectorized_patch_indices_row, - vectorized_patch_indices_col, - shifted_probes, - overlap, - amplitudes_device, - positions_px, - positions_px_initial, - positions_step_size, - max_position_update_distance, - max_position_total_distance, - ) + self._positions_px_all[ + batch_indices + ] = self._position_correction( + self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + overlap, + amplitudes_device, + positions_px, + positions_px_initial, + positions_step_size, + max_position_update_distance, + max_position_total_distance, ) measurement_error += batch_error diff --git a/py4DSTEM/process/phase/parallax.py b/py4DSTEM/process/phase/parallax.py index e12d6a133..3d8c09635 100644 --- a/py4DSTEM/process/phase/parallax.py +++ b/py4DSTEM/process/phase/parallax.py @@ -2202,16 +2202,16 @@ def score_CTF(coefs): measured_shifts_sx = xp.zeros( self._region_of_interest_shape, dtype=xp.float32 ) - measured_shifts_sx[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( - self._xy_shifts_Ang[:, 0] - ) + measured_shifts_sx[ + self._xy_inds[:, 0], self._xy_inds[:, 1] + ] = self._xy_shifts_Ang[:, 0] measured_shifts_sy = xp.zeros( self._region_of_interest_shape, dtype=xp.float32 ) - measured_shifts_sy[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( - self._xy_shifts_Ang[:, 1] - ) + measured_shifts_sy[ + self._xy_inds[:, 0], self._xy_inds[:, 1] + ] = self._xy_shifts_Ang[:, 1] fitted_shifts = ( xp.tensordot(gradients, xp.array(self._aberrations_coefs), axes=1) @@ -2222,16 +2222,16 @@ def score_CTF(coefs): fitted_shifts_sx = xp.zeros( self._region_of_interest_shape, dtype=xp.float32 ) - fitted_shifts_sx[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( - fitted_shifts[:, 0] - ) + fitted_shifts_sx[ + self._xy_inds[:, 0], self._xy_inds[:, 1] + ] = fitted_shifts[:, 0] fitted_shifts_sy = xp.zeros( self._region_of_interest_shape, dtype=xp.float32 ) - fitted_shifts_sy[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( - fitted_shifts[:, 1] - ) + fitted_shifts_sy[ + self._xy_inds[:, 0], self._xy_inds[:, 1] + ] = fitted_shifts[:, 1] max_shift = xp.max( xp.array( diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py index 283ddb1ba..cbc0b2fde 100644 --- a/py4DSTEM/process/phase/ptychographic_methods.py +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -356,7 +356,9 @@ def _precompute_propagator_arrays( propagators[i] = xp.exp( 1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz) ) - propagators[i] *= xp.exp(1.0j * (-(ky**2)[None] * np.pi * wavelength * dz)) + propagators[i] *= xp.exp( + 1.0j * (-(ky**2)[None] * np.pi * wavelength * dz) + ) if theta_x is not None: propagators[i] *= xp.exp( diff --git a/py4DSTEM/process/phase/ptychographic_tomography.py b/py4DSTEM/process/phase/ptychographic_tomography.py index 3639096dc..e1dc33abb 100644 --- a/py4DSTEM/process/phase/ptychographic_tomography.py +++ b/py4DSTEM/process/phase/ptychographic_tomography.py @@ -1088,20 +1088,20 @@ def reconstruct( # position correction if not fix_positions: - self._positions_px_all[batch_indices] = ( - self._position_correction( - object_sliced, - vectorized_patch_indices_row, - vectorized_patch_indices_col, - shifted_probes, - overlap, - amplitudes_device, - positions_px, - positions_px_initial, - positions_step_size, - max_position_update_distance, - max_position_total_distance, - ) + self._positions_px_all[ + batch_indices + ] = self._position_correction( + object_sliced, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + overlap, + amplitudes_device, + positions_px, + positions_px_initial, + positions_step_size, + max_position_update_distance, + max_position_total_distance, ) measurement_error += batch_error diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index 5742ff7e7..1b12da6a0 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -203,7 +203,9 @@ def evaluate_gaussian_envelope( self, alpha: Union[float, np.ndarray] ) -> Union[float, np.ndarray]: xp = self._xp - return xp.exp(-0.5 * self._gaussian_spread**2 * alpha**2 / self._wavelength**2) + return xp.exp( + -0.5 * self._gaussian_spread**2 * alpha**2 / self._wavelength**2 + ) def evaluate_spatial_envelope( self, alpha: Union[float, np.ndarray], phi: Union[float, np.ndarray] diff --git a/py4DSTEM/process/utils/utils.py b/py4DSTEM/process/utils/utils.py index ddeeb2c36..60da616d1 100644 --- a/py4DSTEM/process/utils/utils.py +++ b/py4DSTEM/process/utils/utils.py @@ -93,7 +93,12 @@ def electron_wavelength_angstrom(E_eV): c = 299792458 h = 6.62607 * 10**-34 - lam = h / ma.sqrt(2 * m * e * E_eV) / ma.sqrt(1 + e * E_eV / 2 / m / c**2) * 10**10 + lam = ( + h + / ma.sqrt(2 * m * e * E_eV) + / ma.sqrt(1 + e * E_eV / 2 / m / c**2) + * 10**10 + ) return lam @@ -102,8 +107,15 @@ def electron_interaction_parameter(E_eV): e = 1.602177 * 10**-19 c = 299792458 h = 6.62607 * 10**-34 - lam = h / ma.sqrt(2 * m * e * E_eV) / ma.sqrt(1 + e * E_eV / 2 / m / c**2) * 10**10 - sigma = (2 * np.pi / lam / E_eV) * (m * c**2 + e * E_eV) / (2 * m * c**2 + e * E_eV) + lam = ( + h + / ma.sqrt(2 * m * e * E_eV) + / ma.sqrt(1 + e * E_eV / 2 / m / c**2) + * 10**10 + ) + sigma = ( + (2 * np.pi / lam / E_eV) * (m * c**2 + e * E_eV) / (2 * m * c**2 + e * E_eV) + ) return sigma From 0ecf427c287b217d635b6200e436c3e61782c22c Mon Sep 17 00:00:00 2001 From: cophus Date: Wed, 17 Jul 2024 20:20:39 -0700 Subject: [PATCH 58/64] updating black version --- .../diffraction/WK_scattering_factors.py | 4 +-- py4DSTEM/process/diffraction/crystal_viz.py | 3 +- py4DSTEM/process/fit/fit.py | 3 +- .../magnetic_ptychographic_tomography.py | 28 +++++++++---------- .../process/phase/magnetic_ptychography.py | 28 +++++++++---------- py4DSTEM/process/phase/parallax.py | 24 ++++++++-------- .../process/phase/ptychographic_methods.py | 4 +-- .../process/phase/ptychographic_tomography.py | 28 +++++++++---------- py4DSTEM/process/phase/utils.py | 4 +-- py4DSTEM/process/utils/utils.py | 18 ++---------- 10 files changed, 62 insertions(+), 82 deletions(-) diff --git a/py4DSTEM/process/diffraction/WK_scattering_factors.py b/py4DSTEM/process/diffraction/WK_scattering_factors.py index 70110a977..eb964de96 100644 --- a/py4DSTEM/process/diffraction/WK_scattering_factors.py +++ b/py4DSTEM/process/diffraction/WK_scattering_factors.py @@ -221,9 +221,7 @@ def RI1(BI, BJ, G): ri1[sub] = np.pi * (BI * np.log((BI + BJ) / BI) + BJ * np.log((BI + BJ) / BJ)) sub = np.logical_and(eps <= 0.1, G > 0.0) - temp = 0.5 * BI**2 * np.log(BI / (BI + BJ)) + 0.5 * BJ**2 * np.log( - BJ / (BI + BJ) - ) + temp = 0.5 * BI**2 * np.log(BI / (BI + BJ)) + 0.5 * BJ**2 * np.log(BJ / (BI + BJ)) temp += 0.75 * (BI**2 + BJ**2) - 0.25 * (BI + BJ) ** 2 temp -= 0.5 * (BI - BJ) ** 2 ri1[sub] += np.pi * G[sub] ** 2 * temp diff --git a/py4DSTEM/process/diffraction/crystal_viz.py b/py4DSTEM/process/diffraction/crystal_viz.py index 1663a6198..56d4520ee 100644 --- a/py4DSTEM/process/diffraction/crystal_viz.py +++ b/py4DSTEM/process/diffraction/crystal_viz.py @@ -454,8 +454,7 @@ def plot_scattering_intensity( int_sf_plot = calc_1D_profile( k, self.g_vec_leng, - (self.struct_factors_int**int_power_scale) - * (self.g_vec_leng**k_power_scale), + (self.struct_factors_int**int_power_scale) * (self.g_vec_leng**k_power_scale), remove_origin=True, k_broadening=k_broadening, int_scale=int_scale, diff --git a/py4DSTEM/process/fit/fit.py b/py4DSTEM/process/fit/fit.py index 9973ff79f..5c2d56a3c 100644 --- a/py4DSTEM/process/fit/fit.py +++ b/py4DSTEM/process/fit/fit.py @@ -169,8 +169,7 @@ def polar_gaussian_2D( # t2 = np.min(np.vstack([t,1-t])) t2 = np.square(t - mu_t) return ( - I0 * np.exp(-(t2 / (2 * sigma_t**2) + (q - mu_q) ** 2 / (2 * sigma_q**2))) - + C + I0 * np.exp(-(t2 / (2 * sigma_t**2) + (q - mu_q) ** 2 / (2 * sigma_q**2))) + C ) diff --git a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py index 9ffcc6212..be6332c74 100644 --- a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py +++ b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py @@ -1196,20 +1196,20 @@ def reconstruct( # position correction if not fix_positions and a0 > 0: - self._positions_px_all[ - batch_indices - ] = self._position_correction( - object_sliced, - vectorized_patch_indices_row, - vectorized_patch_indices_col, - shifted_probes, - overlap, - amplitudes_device, - positions_px, - positions_px_initial, - positions_step_size, - max_position_update_distance, - max_position_total_distance, + self._positions_px_all[batch_indices] = ( + self._position_correction( + object_sliced, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + overlap, + amplitudes_device, + positions_px, + positions_px_initial, + positions_step_size, + max_position_update_distance, + max_position_total_distance, + ) ) measurement_error += batch_error diff --git a/py4DSTEM/process/phase/magnetic_ptychography.py b/py4DSTEM/process/phase/magnetic_ptychography.py index 0949a02eb..19f306188 100644 --- a/py4DSTEM/process/phase/magnetic_ptychography.py +++ b/py4DSTEM/process/phase/magnetic_ptychography.py @@ -1497,20 +1497,20 @@ def reconstruct( # position correction if not fix_positions and a0 > 0: - self._positions_px_all[ - batch_indices - ] = self._position_correction( - self._object, - vectorized_patch_indices_row, - vectorized_patch_indices_col, - shifted_probes, - overlap, - amplitudes_device, - positions_px, - positions_px_initial, - positions_step_size, - max_position_update_distance, - max_position_total_distance, + self._positions_px_all[batch_indices] = ( + self._position_correction( + self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + overlap, + amplitudes_device, + positions_px, + positions_px_initial, + positions_step_size, + max_position_update_distance, + max_position_total_distance, + ) ) measurement_error += batch_error diff --git a/py4DSTEM/process/phase/parallax.py b/py4DSTEM/process/phase/parallax.py index 3d8c09635..e12d6a133 100644 --- a/py4DSTEM/process/phase/parallax.py +++ b/py4DSTEM/process/phase/parallax.py @@ -2202,16 +2202,16 @@ def score_CTF(coefs): measured_shifts_sx = xp.zeros( self._region_of_interest_shape, dtype=xp.float32 ) - measured_shifts_sx[ - self._xy_inds[:, 0], self._xy_inds[:, 1] - ] = self._xy_shifts_Ang[:, 0] + measured_shifts_sx[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( + self._xy_shifts_Ang[:, 0] + ) measured_shifts_sy = xp.zeros( self._region_of_interest_shape, dtype=xp.float32 ) - measured_shifts_sy[ - self._xy_inds[:, 0], self._xy_inds[:, 1] - ] = self._xy_shifts_Ang[:, 1] + measured_shifts_sy[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( + self._xy_shifts_Ang[:, 1] + ) fitted_shifts = ( xp.tensordot(gradients, xp.array(self._aberrations_coefs), axes=1) @@ -2222,16 +2222,16 @@ def score_CTF(coefs): fitted_shifts_sx = xp.zeros( self._region_of_interest_shape, dtype=xp.float32 ) - fitted_shifts_sx[ - self._xy_inds[:, 0], self._xy_inds[:, 1] - ] = fitted_shifts[:, 0] + fitted_shifts_sx[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( + fitted_shifts[:, 0] + ) fitted_shifts_sy = xp.zeros( self._region_of_interest_shape, dtype=xp.float32 ) - fitted_shifts_sy[ - self._xy_inds[:, 0], self._xy_inds[:, 1] - ] = fitted_shifts[:, 1] + fitted_shifts_sy[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( + fitted_shifts[:, 1] + ) max_shift = xp.max( xp.array( diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py index cbc0b2fde..283ddb1ba 100644 --- a/py4DSTEM/process/phase/ptychographic_methods.py +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -356,9 +356,7 @@ def _precompute_propagator_arrays( propagators[i] = xp.exp( 1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz) ) - propagators[i] *= xp.exp( - 1.0j * (-(ky**2)[None] * np.pi * wavelength * dz) - ) + propagators[i] *= xp.exp(1.0j * (-(ky**2)[None] * np.pi * wavelength * dz)) if theta_x is not None: propagators[i] *= xp.exp( diff --git a/py4DSTEM/process/phase/ptychographic_tomography.py b/py4DSTEM/process/phase/ptychographic_tomography.py index e1dc33abb..3639096dc 100644 --- a/py4DSTEM/process/phase/ptychographic_tomography.py +++ b/py4DSTEM/process/phase/ptychographic_tomography.py @@ -1088,20 +1088,20 @@ def reconstruct( # position correction if not fix_positions: - self._positions_px_all[ - batch_indices - ] = self._position_correction( - object_sliced, - vectorized_patch_indices_row, - vectorized_patch_indices_col, - shifted_probes, - overlap, - amplitudes_device, - positions_px, - positions_px_initial, - positions_step_size, - max_position_update_distance, - max_position_total_distance, + self._positions_px_all[batch_indices] = ( + self._position_correction( + object_sliced, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + overlap, + amplitudes_device, + positions_px, + positions_px_initial, + positions_step_size, + max_position_update_distance, + max_position_total_distance, + ) ) measurement_error += batch_error diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index 1b12da6a0..5742ff7e7 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -203,9 +203,7 @@ def evaluate_gaussian_envelope( self, alpha: Union[float, np.ndarray] ) -> Union[float, np.ndarray]: xp = self._xp - return xp.exp( - -0.5 * self._gaussian_spread**2 * alpha**2 / self._wavelength**2 - ) + return xp.exp(-0.5 * self._gaussian_spread**2 * alpha**2 / self._wavelength**2) def evaluate_spatial_envelope( self, alpha: Union[float, np.ndarray], phi: Union[float, np.ndarray] diff --git a/py4DSTEM/process/utils/utils.py b/py4DSTEM/process/utils/utils.py index 60da616d1..ddeeb2c36 100644 --- a/py4DSTEM/process/utils/utils.py +++ b/py4DSTEM/process/utils/utils.py @@ -93,12 +93,7 @@ def electron_wavelength_angstrom(E_eV): c = 299792458 h = 6.62607 * 10**-34 - lam = ( - h - / ma.sqrt(2 * m * e * E_eV) - / ma.sqrt(1 + e * E_eV / 2 / m / c**2) - * 10**10 - ) + lam = h / ma.sqrt(2 * m * e * E_eV) / ma.sqrt(1 + e * E_eV / 2 / m / c**2) * 10**10 return lam @@ -107,15 +102,8 @@ def electron_interaction_parameter(E_eV): e = 1.602177 * 10**-19 c = 299792458 h = 6.62607 * 10**-34 - lam = ( - h - / ma.sqrt(2 * m * e * E_eV) - / ma.sqrt(1 + e * E_eV / 2 / m / c**2) - * 10**10 - ) - sigma = ( - (2 * np.pi / lam / E_eV) * (m * c**2 + e * E_eV) / (2 * m * c**2 + e * E_eV) - ) + lam = h / ma.sqrt(2 * m * e * E_eV) / ma.sqrt(1 + e * E_eV / 2 / m / c**2) * 10**10 + sigma = (2 * np.pi / lam / E_eV) * (m * c**2 + e * E_eV) / (2 * m * c**2 + e * E_eV) return sigma From 7c2eb0419e25eaee2d5186038b6188aa6a27e87f Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Thu, 18 Jul 2024 13:27:20 -0700 Subject: [PATCH 59/64] switch to centered alignment bins by default --- py4DSTEM/process/phase/parallax.py | 5 ++++- py4DSTEM/process/phase/parameter_optimize.py | 16 +++++++++------- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/py4DSTEM/process/phase/parallax.py b/py4DSTEM/process/phase/parallax.py index e12d6a133..fa86f4dc2 100644 --- a/py4DSTEM/process/phase/parallax.py +++ b/py4DSTEM/process/phase/parallax.py @@ -743,6 +743,7 @@ def reconstruct( min_alignment_bin: int = 1, num_iter_at_min_bin: int = 2, alignment_bin_values: list = None, + centered_alignment_bins: bool = True, cross_correlation_upsample_factor: int = 8, regularizer_matrix_size: Tuple[int, int] = (1, 1), regularize_shifts: bool = False, @@ -879,6 +880,8 @@ def reconstruct( (bin_vals, np.repeat(bin_vals[-1], num_iter_at_min_bin - 1)) ) + bin_shift = 0 if centered_alignment_bins else 0.5 + if plot_aligned_bf: num_plots = bin_vals.shape[0] nrows = int(np.sqrt(num_plots)) @@ -913,7 +916,7 @@ def reconstruct( G_ref = xp.fft.fft2(self._recon_BF) # Segment the virtual images with current binning values - xy_inds = xp.round(xy_center / bin_vals[a0] + 0.5).astype("int") + xy_inds = xp.round(xy_center / bin_vals[a0] + bin_shift).astype("int") xy_vals = np.unique( asnumpy(xy_inds), axis=0 ) # axis is not yet supported in cupy diff --git a/py4DSTEM/process/phase/parameter_optimize.py b/py4DSTEM/process/phase/parameter_optimize.py index fc1e84f11..e51d3cd2c 100644 --- a/py4DSTEM/process/phase/parameter_optimize.py +++ b/py4DSTEM/process/phase/parameter_optimize.py @@ -458,13 +458,15 @@ def _split_static_and_optimization_vars(self, argdict): return static_args, optimization_args def _get_scan_positions(self, affine_transform, dataset): - R_pixel_size = dataset.calibration.get_R_pixel_size() - x, y = ( - np.arange(dataset.R_Nx) * R_pixel_size, - np.arange(dataset.R_Ny) * R_pixel_size, - ) - x, y = np.meshgrid(x, y, indexing="ij") - scan_positions = np.stack((x.ravel(), y.ravel()), axis=1) + scan_positions = self._init_static_args.pop("initial_scan_positions",None) + if scan_positions is None: + R_pixel_size = dataset.calibration.get_R_pixel_size() + x, y = ( + np.arange(dataset.R_Nx) * R_pixel_size, + np.arange(dataset.R_Ny) * R_pixel_size, + ) + x, y = np.meshgrid(x, y, indexing="ij") + scan_positions = np.stack((x.ravel(), y.ravel()), axis=1) scan_positions = scan_positions @ affine_transform.asarray() return scan_positions From 394c0daa12f2ef9b45edb76366354bb8e8a0813d Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Thu, 18 Jul 2024 13:28:59 -0700 Subject: [PATCH 60/64] forgot to black small BO change --- py4DSTEM/process/phase/parameter_optimize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/process/phase/parameter_optimize.py b/py4DSTEM/process/phase/parameter_optimize.py index e51d3cd2c..149e8143b 100644 --- a/py4DSTEM/process/phase/parameter_optimize.py +++ b/py4DSTEM/process/phase/parameter_optimize.py @@ -458,7 +458,7 @@ def _split_static_and_optimization_vars(self, argdict): return static_args, optimization_args def _get_scan_positions(self, affine_transform, dataset): - scan_positions = self._init_static_args.pop("initial_scan_positions",None) + scan_positions = self._init_static_args.pop("initial_scan_positions", None) if scan_positions is None: R_pixel_size = dataset.calibration.get_R_pixel_size() x, y = ( From 564999ced7732db821b8d628e466be52c89d684e Mon Sep 17 00:00:00 2001 From: cophus Date: Thu, 18 Jul 2024 15:07:42 -0700 Subject: [PATCH 61/64] testing fix for drawing legend --- py4DSTEM/process/diffraction/crystal_viz.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_viz.py b/py4DSTEM/process/diffraction/crystal_viz.py index 56d4520ee..ab4b259f1 100644 --- a/py4DSTEM/process/diffraction/crystal_viz.py +++ b/py4DSTEM/process/diffraction/crystal_viz.py @@ -1416,8 +1416,8 @@ def plot_orientation_maps( # Triangulate faces p = self.orientation_vecs[:, (1, 0, 2)] tri = mtri.Triangulation( - self.orientation_inds[:, 1] - self.orientation_inds[:, 0] * 1e-3, - self.orientation_inds[:, 0] - self.orientation_inds[:, 1] * 1e-3, + self.orientation_inds[:, 1].astype('float') - self.orientation_inds[:, 0].astype('float') * 1e-3, + self.orientation_inds[:, 0].astype('float') - self.orientation_inds[:, 1].astype('float') * 1e-3, ) # convert rgb values of pixels to faces rgb_faces = ( From fece355625672eb90dd2a8e987fd668d46b91be2 Mon Sep 17 00:00:00 2001 From: cophus Date: Thu, 18 Jul 2024 15:21:35 -0700 Subject: [PATCH 62/64] Revert "testing fix for drawing legend" This reverts commit 564999ced7732db821b8d628e466be52c89d684e. --- py4DSTEM/process/diffraction/crystal_viz.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_viz.py b/py4DSTEM/process/diffraction/crystal_viz.py index ab4b259f1..56d4520ee 100644 --- a/py4DSTEM/process/diffraction/crystal_viz.py +++ b/py4DSTEM/process/diffraction/crystal_viz.py @@ -1416,8 +1416,8 @@ def plot_orientation_maps( # Triangulate faces p = self.orientation_vecs[:, (1, 0, 2)] tri = mtri.Triangulation( - self.orientation_inds[:, 1].astype('float') - self.orientation_inds[:, 0].astype('float') * 1e-3, - self.orientation_inds[:, 0].astype('float') - self.orientation_inds[:, 1].astype('float') * 1e-3, + self.orientation_inds[:, 1] - self.orientation_inds[:, 0] * 1e-3, + self.orientation_inds[:, 0] - self.orientation_inds[:, 1] * 1e-3, ) # convert rgb values of pixels to faces rgb_faces = ( From a70970588fd1322f17feefb52ecff7b0f8458462 Mon Sep 17 00:00:00 2001 From: cophus Date: Thu, 18 Jul 2024 15:38:54 -0700 Subject: [PATCH 63/64] adding option to hide legend --- py4DSTEM/process/diffraction/crystal_viz.py | 437 ++++++++++---------- 1 file changed, 221 insertions(+), 216 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_viz.py b/py4DSTEM/process/diffraction/crystal_viz.py index 56d4520ee..88cc16853 100644 --- a/py4DSTEM/process/diffraction/crystal_viz.py +++ b/py4DSTEM/process/diffraction/crystal_viz.py @@ -538,6 +538,7 @@ def plot_orientation_zones( proj_dir_cartesian: Optional[Union[list, tuple, np.ndarray]] = None, tol_den=10, marker_size: float = 20, + plot_zone_axis_labels: bool = True, plot_limit: Union[list, tuple, np.ndarray] = np.array([-1.1, 1.1]), figsize: Union[list, tuple, np.ndarray] = (8, 8), returnfig: bool = False, @@ -553,6 +554,7 @@ def plot_orientation_zones( dir_proj (float): projection direction, either [elev azim] or normal vector Default is mean vector of self.orientation_zone_axis_range rows marker_size (float): size of markers + plot_zone_axis_labels (bool): plot the zone axis labels plot_limit (float): x y z plot limits, default is [0, 1.05] figsize (2 element float): size scaling of figure axes returnfig (bool): set to True to return figure and axes handles @@ -727,47 +729,47 @@ def plot_orientation_zones( zorder=10, ) - text_scale_pos = 1.2 - text_params = { - "va": "center", - "family": "sans-serif", - "fontweight": "normal", - "color": "k", - "size": 16, - } - # 'ha': 'center', - - ax.text( - self.orientation_vecs[inds[0], 1] * text_scale_pos, - self.orientation_vecs[inds[0], 0] * text_scale_pos, - self.orientation_vecs[inds[0], 2] * text_scale_pos, - label_0, - None, - zorder=11, - ha="center", - **text_params, - ) - if self.orientation_full is False and self.orientation_half is False: - ax.text( - self.orientation_vecs[inds[1], 1] * text_scale_pos, - self.orientation_vecs[inds[1], 0] * text_scale_pos, - self.orientation_vecs[inds[1], 2] * text_scale_pos, - label_1, - None, - zorder=12, - ha="center", - **text_params, - ) + if plot_zone_axis_labels: + text_scale_pos = 1.2 + text_params = { + "va": "center", + "family": "sans-serif", + "fontweight": "normal", + "color": "k", + "size": 16, + } + # 'ha': 'center', ax.text( - self.orientation_vecs[inds[2], 1] * text_scale_pos, - self.orientation_vecs[inds[2], 0] * text_scale_pos, - self.orientation_vecs[inds[2], 2] * text_scale_pos, - label_2, + self.orientation_vecs[inds[0], 1] * text_scale_pos, + self.orientation_vecs[inds[0], 0] * text_scale_pos, + self.orientation_vecs[inds[0], 2] * text_scale_pos, + label_0, None, - zorder=13, + zorder=11, ha="center", **text_params, ) + if self.orientation_full is False and self.orientation_half is False: + ax.text( + self.orientation_vecs[inds[1], 1] * text_scale_pos, + self.orientation_vecs[inds[1], 0] * text_scale_pos, + self.orientation_vecs[inds[1], 2] * text_scale_pos, + label_1, + None, + zorder=12, + ha="center", + **text_params, + ) + ax.text( + self.orientation_vecs[inds[2], 1] * text_scale_pos, + self.orientation_vecs[inds[2], 0] * text_scale_pos, + self.orientation_vecs[inds[2], 2] * text_scale_pos, + label_2, + None, + zorder=13, + ha="center", + **text_params, + ) # ax.scatter( # xs=self.g_vec_all[0,:], @@ -1118,6 +1120,7 @@ def plot_orientation_maps( dir_in_plane_degrees: float = 0.0, corr_range: np.ndarray = np.array([0, 5]), corr_normalize: bool = True, + show_legend: bool = True, scale_legend: bool = None, figsize: Union[list, tuple, np.ndarray] = (16, 5), figbound: Union[list, tuple, np.ndarray] = (0.01, 0.005), @@ -1139,6 +1142,7 @@ def plot_orientation_maps( dir_in_plane_degrees (float): In-plane angle to plot in degrees. Default is 0 / x-axis / vertical down. corr_range (np.ndarray): Correlation intensity range for the plot corr_normalize (bool): If true, set mean correlation to 1. + show_legend (bool): Show the legend scale_legend (float): 2 elements, x and y scaling of legend panel figsize (array): 2 elements defining figure size figbound (array): 2 elements defining figure boundary @@ -1413,189 +1417,190 @@ def plot_orientation_maps( ax_x.axis("off") ax_z.axis("off") - # Triangulate faces - p = self.orientation_vecs[:, (1, 0, 2)] - tri = mtri.Triangulation( - self.orientation_inds[:, 1] - self.orientation_inds[:, 0] * 1e-3, - self.orientation_inds[:, 0] - self.orientation_inds[:, 1] * 1e-3, - ) - # convert rgb values of pixels to faces - rgb_faces = ( - rgb_legend[tri.triangles[:, 0], :] - + rgb_legend[tri.triangles[:, 1], :] - + rgb_legend[tri.triangles[:, 2], :] - ) / 3 - # Add triangulated surface plot to axes - pc = art3d.Poly3DCollection( - p[tri.triangles], - facecolors=rgb_faces, - alpha=1, - ) - pc.set_antialiased(False) - ax_l.add_collection(pc) - - if plot_limit is None: - plot_limit = np.array( - [ - [np.min(p[:, 0]), np.min(p[:, 1]), np.min(p[:, 2])], - [np.max(p[:, 0]), np.max(p[:, 1]), np.max(p[:, 2])], - ] + if show_legend: + # Triangulate faces + p = self.orientation_vecs[:, (1, 0, 2)] + tri = mtri.Triangulation( + self.orientation_inds[:, 1] - self.orientation_inds[:, 0] * 1e-3, + self.orientation_inds[:, 0] - self.orientation_inds[:, 1] * 1e-3, ) - # plot_limit = (plot_limit - np.mean(plot_limit, axis=0)) * 1.5 + np.mean( - # plot_limit, axis=0 - # ) - plot_limit[:, 0] = ( - plot_limit[:, 0] - np.mean(plot_limit[:, 0]) - ) * 1.5 + np.mean(plot_limit[:, 0]) - plot_limit[:, 1] = ( - plot_limit[:, 2] - np.mean(plot_limit[:, 1]) - ) * 1.5 + np.mean(plot_limit[:, 1]) - plot_limit[:, 2] = ( - plot_limit[:, 1] - np.mean(plot_limit[:, 2]) - ) * 1.1 + np.mean(plot_limit[:, 2]) - - # ax_l.view_init(elev=el, azim=az) - # Appearance - ax_l.invert_yaxis() - if swap_axes_xy_limits: - ax_l.axes.set_xlim3d(left=plot_limit[0, 0], right=plot_limit[1, 0]) - ax_l.axes.set_ylim3d(bottom=plot_limit[0, 1], top=plot_limit[1, 1]) - ax_l.axes.set_zlim3d(bottom=plot_limit[0, 2], top=plot_limit[1, 2]) - else: - ax_l.axes.set_xlim3d(left=plot_limit[0, 1], right=plot_limit[1, 1]) - ax_l.axes.set_ylim3d(bottom=plot_limit[0, 0], top=plot_limit[1, 0]) - ax_l.axes.set_zlim3d(bottom=plot_limit[0, 2], top=plot_limit[1, 2]) - axisEqual3D(ax_l) - if camera_dist is not None: - ax_l.dist = camera_dist - ax_l.axis("off") - - # Add text labels - text_scale_pos = 0.1 - text_params = { - "va": "center", - "family": "sans-serif", - "fontweight": "normal", - "color": "k", - "size": 14, - } - format_labels = "{0:.2g}" - vec = self.orientation_vecs[inds_legend[0], :] - cam_dir - vec = vec / np.linalg.norm(vec) - if np.abs(self.cell[5] - 120.0) > 1e-6: - ax_l.text( - self.orientation_vecs[inds_legend[0], 1] + vec[1] * text_scale_pos, - self.orientation_vecs[inds_legend[0], 0] + vec[0] * text_scale_pos, - self.orientation_vecs[inds_legend[0], 2] + vec[2] * text_scale_pos, - "[" - + format_labels.format(label_0[0]) - + " " - + format_labels.format(label_0[1]) - + " " - + format_labels.format(label_0[2]) - + "]", - None, - zorder=11, - ha="center", - **text_params, - ) - else: - ax_l.text( - self.orientation_vecs[inds_legend[0], 1] + vec[1] * text_scale_pos, - self.orientation_vecs[inds_legend[0], 0] + vec[0] * text_scale_pos, - self.orientation_vecs[inds_legend[0], 2] + vec[2] * text_scale_pos, - "[" - + format_labels.format(label_0[0]) - + " " - + format_labels.format(label_0[1]) - + " " - + format_labels.format(label_0[2]) - + " " - + format_labels.format(label_0[3]) - + "]", - None, - zorder=11, - ha="center", - **text_params, - ) - vec = self.orientation_vecs[inds_legend[1], :] - cam_dir - vec = vec / np.linalg.norm(vec) - if np.abs(self.cell[5] - 120.0) > 1e-6: - ax_l.text( - self.orientation_vecs[inds_legend[1], 1] + vec[1] * text_scale_pos, - self.orientation_vecs[inds_legend[1], 0] + vec[0] * text_scale_pos, - self.orientation_vecs[inds_legend[1], 2] + vec[2] * text_scale_pos, - "[" - + format_labels.format(label_1[0]) - + " " - + format_labels.format(label_1[1]) - + " " - + format_labels.format(label_1[2]) - + "]", - None, - zorder=12, - ha=ha_1, - **text_params, - ) - else: - ax_l.text( - self.orientation_vecs[inds_legend[1], 1] + vec[1] * text_scale_pos, - self.orientation_vecs[inds_legend[1], 0] + vec[0] * text_scale_pos, - self.orientation_vecs[inds_legend[1], 2] + vec[2] * text_scale_pos, - "[" - + format_labels.format(label_1[0]) - + " " - + format_labels.format(label_1[1]) - + " " - + format_labels.format(label_1[2]) - + " " - + format_labels.format(label_1[3]) - + "]", - None, - zorder=12, - ha=ha_1, - **text_params, - ) - vec = self.orientation_vecs[inds_legend[2], :] - cam_dir - vec = vec / np.linalg.norm(vec) - if np.abs(self.cell[5] - 120.0) > 1e-6: - ax_l.text( - self.orientation_vecs[inds_legend[2], 1] + vec[1] * text_scale_pos, - self.orientation_vecs[inds_legend[2], 0] + vec[0] * text_scale_pos, - self.orientation_vecs[inds_legend[2], 2] + vec[2] * text_scale_pos, - "[" - + format_labels.format(label_2[0]) - + " " - + format_labels.format(label_2[1]) - + " " - + format_labels.format(label_2[2]) - + "]", - None, - zorder=13, - ha=ha_2, - **text_params, - ) - else: - ax_l.text( - self.orientation_vecs[inds_legend[2], 1] + vec[1] * text_scale_pos, - self.orientation_vecs[inds_legend[2], 0] + vec[0] * text_scale_pos, - self.orientation_vecs[inds_legend[2], 2] + vec[2] * text_scale_pos, - "[" - + format_labels.format(label_2[0]) - + " " - + format_labels.format(label_2[1]) - + " " - + format_labels.format(label_2[2]) - + " " - + format_labels.format(label_2[3]) - + "]", - None, - zorder=13, - ha=ha_2, - **text_params, + # convert rgb values of pixels to faces + rgb_faces = ( + rgb_legend[tri.triangles[:, 0], :] + + rgb_legend[tri.triangles[:, 1], :] + + rgb_legend[tri.triangles[:, 2], :] + ) / 3 + # Add triangulated surface plot to axes + pc = art3d.Poly3DCollection( + p[tri.triangles], + facecolors=rgb_faces, + alpha=1, ) + pc.set_antialiased(False) + ax_l.add_collection(pc) + + if plot_limit is None: + plot_limit = np.array( + [ + [np.min(p[:, 0]), np.min(p[:, 1]), np.min(p[:, 2])], + [np.max(p[:, 0]), np.max(p[:, 1]), np.max(p[:, 2])], + ] + ) + # plot_limit = (plot_limit - np.mean(plot_limit, axis=0)) * 1.5 + np.mean( + # plot_limit, axis=0 + # ) + plot_limit[:, 0] = ( + plot_limit[:, 0] - np.mean(plot_limit[:, 0]) + ) * 1.5 + np.mean(plot_limit[:, 0]) + plot_limit[:, 1] = ( + plot_limit[:, 2] - np.mean(plot_limit[:, 1]) + ) * 1.5 + np.mean(plot_limit[:, 1]) + plot_limit[:, 2] = ( + plot_limit[:, 1] - np.mean(plot_limit[:, 2]) + ) * 1.1 + np.mean(plot_limit[:, 2]) + + # ax_l.view_init(elev=el, azim=az) + # Appearance + ax_l.invert_yaxis() + if swap_axes_xy_limits: + ax_l.axes.set_xlim3d(left=plot_limit[0, 0], right=plot_limit[1, 0]) + ax_l.axes.set_ylim3d(bottom=plot_limit[0, 1], top=plot_limit[1, 1]) + ax_l.axes.set_zlim3d(bottom=plot_limit[0, 2], top=plot_limit[1, 2]) + else: + ax_l.axes.set_xlim3d(left=plot_limit[0, 1], right=plot_limit[1, 1]) + ax_l.axes.set_ylim3d(bottom=plot_limit[0, 0], top=plot_limit[1, 0]) + ax_l.axes.set_zlim3d(bottom=plot_limit[0, 2], top=plot_limit[1, 2]) + axisEqual3D(ax_l) + if camera_dist is not None: + ax_l.dist = camera_dist + ax_l.axis("off") + + # Add text labels + text_scale_pos = 0.1 + text_params = { + "va": "center", + "family": "sans-serif", + "fontweight": "normal", + "color": "k", + "size": 14, + } + format_labels = "{0:.2g}" + vec = self.orientation_vecs[inds_legend[0], :] - cam_dir + vec = vec / np.linalg.norm(vec) + if np.abs(self.cell[5] - 120.0) > 1e-6: + ax_l.text( + self.orientation_vecs[inds_legend[0], 1] + vec[1] * text_scale_pos, + self.orientation_vecs[inds_legend[0], 0] + vec[0] * text_scale_pos, + self.orientation_vecs[inds_legend[0], 2] + vec[2] * text_scale_pos, + "[" + + format_labels.format(label_0[0]) + + " " + + format_labels.format(label_0[1]) + + " " + + format_labels.format(label_0[2]) + + "]", + None, + zorder=11, + ha="center", + **text_params, + ) + else: + ax_l.text( + self.orientation_vecs[inds_legend[0], 1] + vec[1] * text_scale_pos, + self.orientation_vecs[inds_legend[0], 0] + vec[0] * text_scale_pos, + self.orientation_vecs[inds_legend[0], 2] + vec[2] * text_scale_pos, + "[" + + format_labels.format(label_0[0]) + + " " + + format_labels.format(label_0[1]) + + " " + + format_labels.format(label_0[2]) + + " " + + format_labels.format(label_0[3]) + + "]", + None, + zorder=11, + ha="center", + **text_params, + ) + vec = self.orientation_vecs[inds_legend[1], :] - cam_dir + vec = vec / np.linalg.norm(vec) + if np.abs(self.cell[5] - 120.0) > 1e-6: + ax_l.text( + self.orientation_vecs[inds_legend[1], 1] + vec[1] * text_scale_pos, + self.orientation_vecs[inds_legend[1], 0] + vec[0] * text_scale_pos, + self.orientation_vecs[inds_legend[1], 2] + vec[2] * text_scale_pos, + "[" + + format_labels.format(label_1[0]) + + " " + + format_labels.format(label_1[1]) + + " " + + format_labels.format(label_1[2]) + + "]", + None, + zorder=12, + ha=ha_1, + **text_params, + ) + else: + ax_l.text( + self.orientation_vecs[inds_legend[1], 1] + vec[1] * text_scale_pos, + self.orientation_vecs[inds_legend[1], 0] + vec[0] * text_scale_pos, + self.orientation_vecs[inds_legend[1], 2] + vec[2] * text_scale_pos, + "[" + + format_labels.format(label_1[0]) + + " " + + format_labels.format(label_1[1]) + + " " + + format_labels.format(label_1[2]) + + " " + + format_labels.format(label_1[3]) + + "]", + None, + zorder=12, + ha=ha_1, + **text_params, + ) + vec = self.orientation_vecs[inds_legend[2], :] - cam_dir + vec = vec / np.linalg.norm(vec) + if np.abs(self.cell[5] - 120.0) > 1e-6: + ax_l.text( + self.orientation_vecs[inds_legend[2], 1] + vec[1] * text_scale_pos, + self.orientation_vecs[inds_legend[2], 0] + vec[0] * text_scale_pos, + self.orientation_vecs[inds_legend[2], 2] + vec[2] * text_scale_pos, + "[" + + format_labels.format(label_2[0]) + + " " + + format_labels.format(label_2[1]) + + " " + + format_labels.format(label_2[2]) + + "]", + None, + zorder=13, + ha=ha_2, + **text_params, + ) + else: + ax_l.text( + self.orientation_vecs[inds_legend[2], 1] + vec[1] * text_scale_pos, + self.orientation_vecs[inds_legend[2], 0] + vec[0] * text_scale_pos, + self.orientation_vecs[inds_legend[2], 2] + vec[2] * text_scale_pos, + "[" + + format_labels.format(label_2[0]) + + " " + + format_labels.format(label_2[1]) + + " " + + format_labels.format(label_2[2]) + + " " + + format_labels.format(label_2[3]) + + "]", + None, + zorder=13, + ha=ha_2, + **text_params, + ) - plt.show() + plt.show() images_orientation = np.zeros((orientation_map.num_x, orientation_map.num_y, 3, 2)) if self.pymatgen_available: From f7bde15b2dc97b9a619ff572e8bb3f7e6bf5be13 Mon Sep 17 00:00:00 2001 From: cophus Date: Thu, 18 Jul 2024 15:59:38 -0700 Subject: [PATCH 64/64] option to hide legend axes --- py4DSTEM/process/diffraction/crystal_viz.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/py4DSTEM/process/diffraction/crystal_viz.py b/py4DSTEM/process/diffraction/crystal_viz.py index 88cc16853..402f6f964 100644 --- a/py4DSTEM/process/diffraction/crystal_viz.py +++ b/py4DSTEM/process/diffraction/crystal_viz.py @@ -1601,6 +1601,8 @@ def plot_orientation_maps( ) plt.show() + else: + ax_l.set_axis_off() images_orientation = np.zeros((orientation_map.num_x, orientation_map.num_y, 3, 2)) if self.pymatgen_available: