diff --git a/py4DSTEM/process/phase/__init__.py b/py4DSTEM/process/phase/__init__.py index fda327b53..fb064c7a0 100644 --- a/py4DSTEM/process/phase/__init__.py +++ b/py4DSTEM/process/phase/__init__.py @@ -3,7 +3,7 @@ _emd_hook = True from py4DSTEM import is_package_lite -from py4DSTEM.process.phase.direct_ptychography import OBFPtychography, PhaseCompensatedSSBPtychography, SSBPtychography, WDDPtychography +from py4DSTEM.process.phase.direct_ptychography import OBF, SSB, WDD from py4DSTEM.process.phase.dpc import DPC from py4DSTEM.process.phase.magnetic_ptychographic_tomography import MagneticPtychographicTomography from py4DSTEM.process.phase.magnetic_ptychography import MagneticPtychography diff --git a/py4DSTEM/process/phase/direct_ptychography.py b/py4DSTEM/process/phase/direct_ptychography.py index 610e0bed7..2a8200c67 100644 --- a/py4DSTEM/process/phase/direct_ptychography.py +++ b/py4DSTEM/process/phase/direct_ptychography.py @@ -1443,20 +1443,12 @@ def sampling(self): ) -class SSBPtychography( +class SSB( DirectPtychography, ): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - if sum(self._polar_parameters.values()) != 0.0: - warnings.warn( - "Note aberrations will effectively be ignored in this class", - UserWarning, - ) - - def aberration_fit(self, *args, **kwargs): - raise NotImplementedError() def _reconstruct_single_frequency( self, @@ -1470,14 +1462,16 @@ def _reconstruct_single_frequency( aperture, probe_kwargs, trotter_sign, + phase_compensation: bool = True, + virtual_detector_masks: Sequence[np.ndarray] = None, xp=np, ): """ """ + threshold = 1e-3 G = xp.asarray(intensities_FFT) if Qx == 0.0 and Qy == 0.0: return xp.abs(G).sum() else: - Kx_plus_Qx = Kx + Qx Ky_plus_Qy = Ky + Qy @@ -1486,7 +1480,7 @@ def _reconstruct_single_frequency( force_spatial_frequencies=(Kx_plus_Qx, Ky_plus_Qy), ) alpha_plus, phi_plus = cmplx_probe_plus.get_scattering_angles() - aperture_plus = cmplx_probe_plus.evaluate_aperture(alpha_plus, phi_plus) > 0 + aperture_plus = cmplx_probe_plus.evaluate_aperture(alpha_plus, phi_plus) Kx_minus_Qx = Kx - Qx Ky_minus_Qy = Ky - Qy @@ -1496,240 +1490,59 @@ def _reconstruct_single_frequency( force_spatial_frequencies=(Kx_minus_Qx, Ky_minus_Qy), ) alpha_minus, phi_minus = cmplx_probe_minus.get_scattering_angles() - aperture_minus = ( - cmplx_probe_minus.evaluate_aperture(alpha_minus, phi_minus) > 0 - ) + aperture_minus = cmplx_probe_minus.evaluate_aperture(alpha_minus, phi_minus) - if trotter_sign == "+": - aperture_plus_solo = xp.logical_and( - xp.logical_and(aperture, aperture_plus), ~aperture_minus + if phase_compensation: + aberrations_plus = cmplx_probe_plus.evaluate_aberrations( + alpha_plus, phi_plus ) - return G[aperture_plus_solo].sum().conj() - elif trotter_sign == "-": - aperture_minus_solo = xp.logical_and( - xp.logical_and(aperture, aperture_minus), ~aperture_plus + aberrations_minus = cmplx_probe_minus.evaluate_aberrations( + alpha_minus, phi_minus ) - return G[aperture_minus_solo].sum() - else: - raise ValueError() - - def reconstruct( - self, - trotter_sign="-", - num_jobs=None, - threads_per_job=None, - progress_bar: bool = True, - device: str = None, - clear_fft_cache: bool = None, - ): - """ - Ptychographic reconstruction main method. - - Parameters - -------- - trotter_sign: str, optional - Sign of single-side trotter to use. One of '+','-'. - num_jobs: int, optional - Number of processes to use. If None, then as many processes as CPUs on - the system will be spawned. - threads_per_job: int, optional - Number of threads to use to avoid over-subscribing when using multiple processors. - worker_pool: WorkerPool, optional - If not None, reconstruction is dispatched to mpire WorkerPool instance. - progress_bar: bool, optional - If True, reconstruction progress is displayed - - Returns - -------- - self: DirectPtychography - Self to accommodate chaining - """ - - xp = self._xp - asnumpy = self._asnumpy - # handle device/storage - if device == "gpu": - warnings.warn( - "Note this class is not very well optimized on gpu.", - UserWarning, - ) + probe_plus = aperture_plus * aberrations_plus + probe_minus = aperture_minus * aberrations_minus + gamma = probe_conj * probe_minus - probe * probe_plus.conj() - if device is not None: - attrs = [ - "_fourier_probe_initial", - ] - self.copy_attributes_to_device(attrs, device) - self.set_device(device, clear_fft_cache) - - sx, sy = self._grid_scan_shape - psi = xp.empty((sx, sy), dtype=xp.complex64) - probe_conj = xp.conj(self._fourier_probe) - probe_kwargs = { - "energy": self._energy, - "gpts": self._intensities_shape, - "sampling": self.sampling, - "semiangle_cutoff": self._semiangle_cutoff, - "vacuum_probe_intensity": self._vacuum_probe_intensity, - "rolloff": self._rolloff, - "device": self._device, - } - - Kx, Ky = self._spatial_frequencies - Qx, Qy = self._scan_frequencies - - cmplx_probe = ComplexProbe( - **probe_kwargs, - force_spatial_frequencies=(Kx, Ky), - ) - - alpha, phi = cmplx_probe.get_scattering_angles() - aperture = cmplx_probe.evaluate_aperture(alpha, phi) > 0 - - # main loop - - if num_jobs == 1: - for ind_x, ind_y in tqdmnd( - sx, - sy, - desc="Reconstructing object", - unit="freq.", - disable=not progress_bar, - ): - psi[ind_x, ind_y] = self._reconstruct_single_frequency( - self._intensities_FFT[ind_x, ind_y], - Qx[ind_x, ind_y], - Qy[ind_x, ind_y], - Kx, - Ky, - self._fourier_probe, - probe_conj, - aperture, - probe_kwargs, - trotter_sign=trotter_sign, - xp=xp, - ) - else: - - if self._device == "gpu": - raise NotImplementedError() - - from mpire import WorkerPool, cpu_count - from threadpoolctl import threadpool_limits - - num_jobs = num_jobs or cpu_count() - - if threads_per_job is not None: - num_jobs = num_jobs // threads_per_job - - map_inputs = [ - { - "intensities_FFT": self._intensities_FFT[ind_x, ind_y], - "Qx": Qx[ind_x, ind_y], - "Qy": Qy[ind_x, ind_y], - } - for ind_x in range(sx) - for ind_y in range(sy) - ] - - def wrapper_function(**kwargs): - with threadpool_limits(limits=threads_per_job): - return self._reconstruct_single_frequency( - **kwargs, - Kx=Kx, - Ky=Ky, - probe=self._fourier_probe, - probe_conj=probe_conj, - aperture=aperture, - probe_kwargs=probe_kwargs, - trotter_sign=trotter_sign, - xp=xp, + if virtual_detector_masks is not None: + gamma = mask_array_using_virtual_detectors( + gamma, virtual_detector_masks, in_place=True ) - with WorkerPool(n_jobs=num_jobs) as pool: - flat_results = pool.map( - wrapper_function, map_inputs, progress_bar=progress_bar - ) - - for (ind_x, ind_y), res in zip(np.ndindex((sx, sy)), flat_results): - psi[ind_x, ind_y] = res - - self._object = xp.fft.ifft2(psi) / self._mean_diffraction_intensity - - # store result - self.object = asnumpy(self._object) - self.clear_device_mem(self._device, self._clear_fft_cache) - - return self - - -class PhaseCompensatedSSBPtychography( - DirectPtychography, -): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def _reconstruct_single_frequency( - self, - intensities_FFT, - Qx, - Qy, - Kx, - Ky, - probe, - probe_conj, - probe_kwargs, - trotter_sign, - virtual_detector_masks: Sequence[np.ndarray] = None, - xp=np, - ): - """ """ - threshold = 1e-3 - - G = xp.asarray(intensities_FFT) - if Qx == 0.0 and Qy == 0.0: - return xp.abs(G).sum() - else: - Kx_plus_Qx = Kx + Qx - Ky_plus_Qy = Ky + Qy - - probe_plus = ComplexProbe( - **probe_kwargs, - force_spatial_frequencies=(Kx_plus_Qx, Ky_plus_Qy), - )._evaluate_ctf() - - Kx_minus_Qx = Kx - Qx - Ky_minus_Qy = Ky - Qy + gamma_abs = xp.abs(gamma) + gamma_ind = gamma_abs > threshold + normalization = gamma_abs[gamma_ind] - probe_minus = ComplexProbe( - **probe_kwargs, - force_spatial_frequencies=(Kx_minus_Qx, Ky_minus_Qy), - )._evaluate_ctf() - - gamma = probe_conj * probe_minus - probe * probe_plus.conj() - - if virtual_detector_masks is not None: - gamma = mask_array_using_virtual_detectors( - gamma, virtual_detector_masks, in_place=True - ) + if trotter_sign == "+": + numerator = -G[gamma_ind].conj() * gamma[gamma_ind] + elif trotter_sign == "-": + numerator = G[gamma_ind] * gamma[gamma_ind].conj() + else: + raise ValueError() - gamma_abs = xp.abs(gamma) - gamma_ind = gamma_abs > threshold + return (numerator / normalization).sum() - normalization = gamma_abs[gamma_ind] - if trotter_sign == "+": - numerator = -G[gamma_ind].conj() * gamma[gamma_ind] - elif trotter_sign == "-": - numerator = G[gamma_ind] * gamma[gamma_ind].conj() else: - raise ValueError() + aperture_plus = aperture_plus > threshold + aperture_minus = aperture_minus > threshold - return (numerator / normalization).sum() + if trotter_sign == "+": + aperture_solo = xp.logical_and( + xp.logical_and(aperture, aperture_plus), ~aperture_minus + ) + return G[aperture_solo].sum().conj() + elif trotter_sign == "-": + aperture_solo = xp.logical_and( + xp.logical_and(aperture, aperture_minus), ~aperture_plus + ) + return G[aperture_solo].sum() + else: + raise ValueError() def reconstruct( self, trotter_sign="-", + phase_compensation=True, num_jobs=None, threads_per_job=None, virtual_detector_masks: Sequence[np.ndarray] = None, @@ -1746,6 +1559,8 @@ def reconstruct( -------- trotter_sign: str, optional Sign of single-side trotter to use. One of '+','-'. + phase_compensation: bool, optional + If True, the measured phase is compensated using a complex virtual detector. Recommnended. num_jobs: int, optional Number of processes to use. Default is None, which spawns as many processes as CPUs on the system. @@ -1776,6 +1591,8 @@ def reconstruct( sx, sy = self._grid_scan_shape psi = xp.empty((sx, sy), dtype=xp.complex64) probe_conj = xp.conj(self._fourier_probe) + aperture = xp.abs(self._fourier_probe) > 1e-3 + probe_kwargs = { "energy": self._energy, "gpts": self._intensities_shape, @@ -1811,8 +1628,10 @@ def reconstruct( Ky, self._fourier_probe, probe_conj, + aperture, probe_kwargs, trotter_sign=trotter_sign, + phase_compensation=phase_compensation, virtual_detector_masks=virtual_detector_masks, xp=xp, ) @@ -1845,8 +1664,10 @@ def wrapper_function(**kwargs): Ky=Ky, probe=self._fourier_probe, probe_conj=probe_conj, + aperture=aperture, probe_kwargs=probe_kwargs, trotter_sign=trotter_sign, + phase_compensation=phase_compensation, virtual_detector_masks=virtual_detector_masks, xp=xp, ) @@ -1868,7 +1689,7 @@ def wrapper_function(**kwargs): return self -class OBFPtychography( +class OBF( DirectPtychography, ): @@ -2085,7 +1906,7 @@ def wrapper_function(**kwargs): return self -class WDDPtychography( +class WDD( DirectPtychography, ):