Skip to content

Commit

Permalink
cleaned up classes, removed trailing Ptychography
Browse files Browse the repository at this point in the history
  • Loading branch information
gvarnavi committed Oct 20, 2024
1 parent 5eba668 commit 60f9e94
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 232 deletions.
2 changes: 1 addition & 1 deletion py4DSTEM/process/phase/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
_emd_hook = True

from py4DSTEM import is_package_lite
from py4DSTEM.process.phase.direct_ptychography import OBFPtychography, PhaseCompensatedSSBPtychography, SSBPtychography, WDDPtychography
from py4DSTEM.process.phase.direct_ptychography import OBF, SSB, WDD
from py4DSTEM.process.phase.dpc import DPC
from py4DSTEM.process.phase.magnetic_ptychographic_tomography import MagneticPtychographicTomography
from py4DSTEM.process.phase.magnetic_ptychography import MagneticPtychography
Expand Down
283 changes: 52 additions & 231 deletions py4DSTEM/process/phase/direct_ptychography.py
Original file line number Diff line number Diff line change
Expand Up @@ -1443,20 +1443,12 @@ def sampling(self):
)


class SSBPtychography(
class SSB(
DirectPtychography,
):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if sum(self._polar_parameters.values()) != 0.0:
warnings.warn(
"Note aberrations will effectively be ignored in this class",
UserWarning,
)

def aberration_fit(self, *args, **kwargs):
raise NotImplementedError()

def _reconstruct_single_frequency(
self,
Expand All @@ -1470,14 +1462,16 @@ def _reconstruct_single_frequency(
aperture,
probe_kwargs,
trotter_sign,
phase_compensation: bool = True,
virtual_detector_masks: Sequence[np.ndarray] = None,
xp=np,
):
""" """
threshold = 1e-3
G = xp.asarray(intensities_FFT)
if Qx == 0.0 and Qy == 0.0:
return xp.abs(G).sum()
else:

Kx_plus_Qx = Kx + Qx
Ky_plus_Qy = Ky + Qy

Expand All @@ -1486,7 +1480,7 @@ def _reconstruct_single_frequency(
force_spatial_frequencies=(Kx_plus_Qx, Ky_plus_Qy),
)
alpha_plus, phi_plus = cmplx_probe_plus.get_scattering_angles()
aperture_plus = cmplx_probe_plus.evaluate_aperture(alpha_plus, phi_plus) > 0
aperture_plus = cmplx_probe_plus.evaluate_aperture(alpha_plus, phi_plus)

Kx_minus_Qx = Kx - Qx
Ky_minus_Qy = Ky - Qy
Expand All @@ -1496,240 +1490,59 @@ def _reconstruct_single_frequency(
force_spatial_frequencies=(Kx_minus_Qx, Ky_minus_Qy),
)
alpha_minus, phi_minus = cmplx_probe_minus.get_scattering_angles()
aperture_minus = (
cmplx_probe_minus.evaluate_aperture(alpha_minus, phi_minus) > 0
)
aperture_minus = cmplx_probe_minus.evaluate_aperture(alpha_minus, phi_minus)

if trotter_sign == "+":
aperture_plus_solo = xp.logical_and(
xp.logical_and(aperture, aperture_plus), ~aperture_minus
if phase_compensation:
aberrations_plus = cmplx_probe_plus.evaluate_aberrations(
alpha_plus, phi_plus
)
return G[aperture_plus_solo].sum().conj()
elif trotter_sign == "-":
aperture_minus_solo = xp.logical_and(
xp.logical_and(aperture, aperture_minus), ~aperture_plus
aberrations_minus = cmplx_probe_minus.evaluate_aberrations(
alpha_minus, phi_minus
)
return G[aperture_minus_solo].sum()
else:
raise ValueError()

def reconstruct(
self,
trotter_sign="-",
num_jobs=None,
threads_per_job=None,
progress_bar: bool = True,
device: str = None,
clear_fft_cache: bool = None,
):
"""
Ptychographic reconstruction main method.
Parameters
--------
trotter_sign: str, optional
Sign of single-side trotter to use. One of '+','-'.
num_jobs: int, optional
Number of processes to use. If None, then as many processes as CPUs on
the system will be spawned.
threads_per_job: int, optional
Number of threads to use to avoid over-subscribing when using multiple processors.
worker_pool: WorkerPool, optional
If not None, reconstruction is dispatched to mpire WorkerPool instance.
progress_bar: bool, optional
If True, reconstruction progress is displayed
Returns
--------
self: DirectPtychography
Self to accommodate chaining
"""

xp = self._xp
asnumpy = self._asnumpy

# handle device/storage
if device == "gpu":
warnings.warn(
"Note this class is not very well optimized on gpu.",
UserWarning,
)
probe_plus = aperture_plus * aberrations_plus
probe_minus = aperture_minus * aberrations_minus
gamma = probe_conj * probe_minus - probe * probe_plus.conj()

if device is not None:
attrs = [
"_fourier_probe_initial",
]
self.copy_attributes_to_device(attrs, device)
self.set_device(device, clear_fft_cache)

sx, sy = self._grid_scan_shape
psi = xp.empty((sx, sy), dtype=xp.complex64)
probe_conj = xp.conj(self._fourier_probe)
probe_kwargs = {
"energy": self._energy,
"gpts": self._intensities_shape,
"sampling": self.sampling,
"semiangle_cutoff": self._semiangle_cutoff,
"vacuum_probe_intensity": self._vacuum_probe_intensity,
"rolloff": self._rolloff,
"device": self._device,
}

Kx, Ky = self._spatial_frequencies
Qx, Qy = self._scan_frequencies

cmplx_probe = ComplexProbe(
**probe_kwargs,
force_spatial_frequencies=(Kx, Ky),
)

alpha, phi = cmplx_probe.get_scattering_angles()
aperture = cmplx_probe.evaluate_aperture(alpha, phi) > 0

# main loop

if num_jobs == 1:
for ind_x, ind_y in tqdmnd(
sx,
sy,
desc="Reconstructing object",
unit="freq.",
disable=not progress_bar,
):
psi[ind_x, ind_y] = self._reconstruct_single_frequency(
self._intensities_FFT[ind_x, ind_y],
Qx[ind_x, ind_y],
Qy[ind_x, ind_y],
Kx,
Ky,
self._fourier_probe,
probe_conj,
aperture,
probe_kwargs,
trotter_sign=trotter_sign,
xp=xp,
)
else:

if self._device == "gpu":
raise NotImplementedError()

from mpire import WorkerPool, cpu_count
from threadpoolctl import threadpool_limits

num_jobs = num_jobs or cpu_count()

if threads_per_job is not None:
num_jobs = num_jobs // threads_per_job

map_inputs = [
{
"intensities_FFT": self._intensities_FFT[ind_x, ind_y],
"Qx": Qx[ind_x, ind_y],
"Qy": Qy[ind_x, ind_y],
}
for ind_x in range(sx)
for ind_y in range(sy)
]

def wrapper_function(**kwargs):
with threadpool_limits(limits=threads_per_job):
return self._reconstruct_single_frequency(
**kwargs,
Kx=Kx,
Ky=Ky,
probe=self._fourier_probe,
probe_conj=probe_conj,
aperture=aperture,
probe_kwargs=probe_kwargs,
trotter_sign=trotter_sign,
xp=xp,
if virtual_detector_masks is not None:
gamma = mask_array_using_virtual_detectors(
gamma, virtual_detector_masks, in_place=True
)

with WorkerPool(n_jobs=num_jobs) as pool:
flat_results = pool.map(
wrapper_function, map_inputs, progress_bar=progress_bar
)

for (ind_x, ind_y), res in zip(np.ndindex((sx, sy)), flat_results):
psi[ind_x, ind_y] = res

self._object = xp.fft.ifft2(psi) / self._mean_diffraction_intensity

# store result
self.object = asnumpy(self._object)
self.clear_device_mem(self._device, self._clear_fft_cache)

return self


class PhaseCompensatedSSBPtychography(
DirectPtychography,
):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def _reconstruct_single_frequency(
self,
intensities_FFT,
Qx,
Qy,
Kx,
Ky,
probe,
probe_conj,
probe_kwargs,
trotter_sign,
virtual_detector_masks: Sequence[np.ndarray] = None,
xp=np,
):
""" """
threshold = 1e-3

G = xp.asarray(intensities_FFT)
if Qx == 0.0 and Qy == 0.0:
return xp.abs(G).sum()
else:
Kx_plus_Qx = Kx + Qx
Ky_plus_Qy = Ky + Qy

probe_plus = ComplexProbe(
**probe_kwargs,
force_spatial_frequencies=(Kx_plus_Qx, Ky_plus_Qy),
)._evaluate_ctf()

Kx_minus_Qx = Kx - Qx
Ky_minus_Qy = Ky - Qy
gamma_abs = xp.abs(gamma)
gamma_ind = gamma_abs > threshold
normalization = gamma_abs[gamma_ind]

probe_minus = ComplexProbe(
**probe_kwargs,
force_spatial_frequencies=(Kx_minus_Qx, Ky_minus_Qy),
)._evaluate_ctf()

gamma = probe_conj * probe_minus - probe * probe_plus.conj()

if virtual_detector_masks is not None:
gamma = mask_array_using_virtual_detectors(
gamma, virtual_detector_masks, in_place=True
)
if trotter_sign == "+":
numerator = -G[gamma_ind].conj() * gamma[gamma_ind]
elif trotter_sign == "-":
numerator = G[gamma_ind] * gamma[gamma_ind].conj()
else:
raise ValueError()

gamma_abs = xp.abs(gamma)
gamma_ind = gamma_abs > threshold
return (numerator / normalization).sum()

normalization = gamma_abs[gamma_ind]
if trotter_sign == "+":
numerator = -G[gamma_ind].conj() * gamma[gamma_ind]
elif trotter_sign == "-":
numerator = G[gamma_ind] * gamma[gamma_ind].conj()
else:
raise ValueError()
aperture_plus = aperture_plus > threshold
aperture_minus = aperture_minus > threshold

return (numerator / normalization).sum()
if trotter_sign == "+":
aperture_solo = xp.logical_and(
xp.logical_and(aperture, aperture_plus), ~aperture_minus
)
return G[aperture_solo].sum().conj()
elif trotter_sign == "-":
aperture_solo = xp.logical_and(
xp.logical_and(aperture, aperture_minus), ~aperture_plus
)
return G[aperture_solo].sum()
else:
raise ValueError()

def reconstruct(
self,
trotter_sign="-",
phase_compensation=True,
num_jobs=None,
threads_per_job=None,
virtual_detector_masks: Sequence[np.ndarray] = None,
Expand All @@ -1746,6 +1559,8 @@ def reconstruct(
--------
trotter_sign: str, optional
Sign of single-side trotter to use. One of '+','-'.
phase_compensation: bool, optional
If True, the measured phase is compensated using a complex virtual detector. Recommnended.
num_jobs: int, optional
Number of processes to use. Default is None, which spawns as many processes as CPUs on
the system.
Expand Down Expand Up @@ -1776,6 +1591,8 @@ def reconstruct(
sx, sy = self._grid_scan_shape
psi = xp.empty((sx, sy), dtype=xp.complex64)
probe_conj = xp.conj(self._fourier_probe)
aperture = xp.abs(self._fourier_probe) > 1e-3

probe_kwargs = {
"energy": self._energy,
"gpts": self._intensities_shape,
Expand Down Expand Up @@ -1811,8 +1628,10 @@ def reconstruct(
Ky,
self._fourier_probe,
probe_conj,
aperture,
probe_kwargs,
trotter_sign=trotter_sign,
phase_compensation=phase_compensation,
virtual_detector_masks=virtual_detector_masks,
xp=xp,
)
Expand Down Expand Up @@ -1845,8 +1664,10 @@ def wrapper_function(**kwargs):
Ky=Ky,
probe=self._fourier_probe,
probe_conj=probe_conj,
aperture=aperture,
probe_kwargs=probe_kwargs,
trotter_sign=trotter_sign,
phase_compensation=phase_compensation,
virtual_detector_masks=virtual_detector_masks,
xp=xp,
)
Expand All @@ -1868,7 +1689,7 @@ def wrapper_function(**kwargs):
return self


class OBFPtychography(
class OBF(
DirectPtychography,
):

Expand Down Expand Up @@ -2085,7 +1906,7 @@ def wrapper_function(**kwargs):
return self


class WDDPtychography(
class WDD(
DirectPtychography,
):

Expand Down

0 comments on commit 60f9e94

Please sign in to comment.