Skip to content

Commit

Permalink
Switched multiprocessing to argument form: reconstruct() now takes num_jobs and threads_per_job instead of a worker_pool instance
Browse files Browse the repository at this point in the history
  • Loading branch information
gvarnavi committed Oct 20, 2024
1 parent 71941b9 commit 5eba668
Showing 1 changed file with 137 additions and 80 deletions.
217 changes: 137 additions & 80 deletions py4DSTEM/process/phase/direct_ptychography.py
Original file line number Diff line number Diff line change
Expand Up @@ -1516,7 +1516,8 @@ def _reconstruct_single_frequency(
def reconstruct(
self,
trotter_sign="-",
worker_pool=None,
num_jobs=None,
threads_per_job=None,
progress_bar: bool = True,
device: str = None,
clear_fft_cache: bool = None,
Expand All @@ -1528,6 +1529,11 @@ def reconstruct(
--------
trotter_sign: str, optional
Sign of single-side trotter to use. One of '+','-'.
num_jobs: int, optional
Number of processes to use. If None, then as many processes as CPUs on
the system will be spawned.
threads_per_job: int, optional
Number of threads each process may use. If not None, num_jobs is divided by this value
and thread pools inside each worker are limited to it, to avoid over-subscribing the CPUs.
worker_pool: WorkerPool, optional
If not None, reconstruction is dispatched to mpire WorkerPool instance.
progress_bar: bool, optional
Expand Down Expand Up @@ -1556,10 +1562,6 @@ def reconstruct(
self.copy_attributes_to_device(attrs, device)
self.set_device(device, clear_fft_cache)

if worker_pool is not None:
if self._device == "gpu":
raise NotImplementedError()

sx, sy = self._grid_scan_shape
psi = xp.empty((sx, sy), dtype=xp.complex64)
probe_conj = xp.conj(self._fourier_probe)
Expand All @@ -1586,7 +1588,7 @@ def reconstruct(

# main loop

if worker_pool is None:
if num_jobs == 1:
for ind_x, ind_y in tqdmnd(
sx,
sy,
Expand All @@ -1608,6 +1610,18 @@ def reconstruct(
xp=xp,
)
else:

if self._device == "gpu":
raise NotImplementedError()

from mpire import WorkerPool, cpu_count
from threadpoolctl import threadpool_limits

num_jobs = num_jobs or cpu_count()

if threads_per_job is not None:
num_jobs = num_jobs // threads_per_job

map_inputs = [
{
"intensities_FFT": self._intensities_FFT[ind_x, ind_y],
Expand All @@ -1619,21 +1633,24 @@ def reconstruct(
]

def wrapper_function(**kwargs):
return self._reconstruct_single_frequency(
**kwargs,
Kx=Kx,
Ky=Ky,
probe=self._fourier_probe,
probe_conj=probe_conj,
aperture=aperture,
probe_kwargs=probe_kwargs,
trotter_sign=trotter_sign,
xp=xp,
with threadpool_limits(limits=threads_per_job):
return self._reconstruct_single_frequency(
**kwargs,
Kx=Kx,
Ky=Ky,
probe=self._fourier_probe,
probe_conj=probe_conj,
aperture=aperture,
probe_kwargs=probe_kwargs,
trotter_sign=trotter_sign,
xp=xp,
)

with WorkerPool(n_jobs=num_jobs) as pool:
flat_results = pool.map(
wrapper_function, map_inputs, progress_bar=progress_bar
)

flat_results = worker_pool.map(
wrapper_function, map_inputs, progress_bar=progress_bar
)
for (ind_x, ind_y), res in zip(np.ndindex((sx, sy)), flat_results):
psi[ind_x, ind_y] = res

Expand Down Expand Up @@ -1713,7 +1730,8 @@ def _reconstruct_single_frequency(
def reconstruct(
self,
trotter_sign="-",
worker_pool=None,
num_jobs=None,
threads_per_job=None,
virtual_detector_masks: Sequence[np.ndarray] = None,
progress_bar: bool = True,
polar_parameters: Mapping[str, float] = None,
Expand All @@ -1726,8 +1744,13 @@ def reconstruct(
Parameters
--------
worker_pool: WorkerPool
If not None, reconstruction is dispatched to mpire WorkerPool instance.
trotter_sign: str, optional
Sign of single-side trotter to use. One of '+','-'.
num_jobs: int, optional
Number of processes to use. Default is None, which spawns as many processes as CPUs on
the system.
threads_per_job: int, optional
Number of threads each process may use. If not None, num_jobs is divided by this value
and thread pools inside each worker are limited to it, to avoid over-subscribing the CPUs.
virtual_detector_masks: np.ndarray
List of corner-centered boolean masks for binning forward model trotters,
to allow comparison with arbitrary geometry detector datasets. TO-DO
Expand All @@ -1750,10 +1773,6 @@ def reconstruct(
**kwargs,
)

if worker_pool is not None:
if self._device == "gpu":
raise NotImplementedError()

sx, sy = self._grid_scan_shape
psi = xp.empty((sx, sy), dtype=xp.complex64)
probe_conj = xp.conj(self._fourier_probe)
Expand All @@ -1776,7 +1795,7 @@ def reconstruct(

# main loop

if worker_pool is None:
if num_jobs == 1:
for ind_x, ind_y in tqdmnd(
sx,
sy,
Expand All @@ -1798,6 +1817,16 @@ def reconstruct(
xp=xp,
)
else:
if self._device == "gpu":
raise NotImplementedError()

from mpire import WorkerPool, cpu_count
from threadpoolctl import threadpool_limits

num_jobs = num_jobs or cpu_count()
if threads_per_job is not None:
num_jobs = num_jobs // threads_per_job

map_inputs = [
{
"intensities_FFT": self._intensities_FFT[ind_x, ind_y],
Expand All @@ -1809,21 +1838,24 @@ def reconstruct(
]

def wrapper_function(**kwargs):
return self._reconstruct_single_frequency(
**kwargs,
Kx=Kx,
Ky=Ky,
probe=self._fourier_probe,
probe_conj=probe_conj,
probe_kwargs=probe_kwargs,
trotter_sign=trotter_sign,
virtual_detector_masks=virtual_detector_masks,
xp=xp,
with threadpool_limits(limits=threads_per_job):
return self._reconstruct_single_frequency(
**kwargs,
Kx=Kx,
Ky=Ky,
probe=self._fourier_probe,
probe_conj=probe_conj,
probe_kwargs=probe_kwargs,
trotter_sign=trotter_sign,
virtual_detector_masks=virtual_detector_masks,
xp=xp,
)

with WorkerPool(n_jobs=num_jobs) as pool:
flat_results = pool.map(
wrapper_function, map_inputs, progress_bar=progress_bar
)

flat_results = worker_pool.map(
wrapper_function, map_inputs, progress_bar=progress_bar
)
for (ind_x, ind_y), res in zip(np.ndindex((sx, sy)), flat_results):
psi[ind_x, ind_y] = res

Expand Down Expand Up @@ -1907,7 +1939,8 @@ def _reconstruct_single_frequency(
def reconstruct(
self,
trotter_sign="-",
worker_pool=None,
num_jobs=None,
threads_per_job=None,
virtual_detector_masks: Sequence[np.ndarray] = None,
progress_bar: bool = True,
polar_parameters: Mapping[str, float] = None,
Expand All @@ -1920,8 +1953,13 @@ def reconstruct(
Parameters
--------
worker_pool: WorkerPool
If not None, reconstruction is dispatched to mpire WorkerPool instance.
trotter_sign: str, optional
Sign of single-side trotter to use. One of '+','-'.
num_jobs: int, optional
Number of processes to use. Default is None, which spawns as many processes as CPUs on
the system.
threads_per_job: int, optional
Number of threads each process may use. If not None, num_jobs is divided by this value
and thread pools inside each worker are limited to it, to avoid over-subscribing the CPUs.
virtual_detector_masks: np.ndarray
List of corner-centered boolean masks for binning forward model trotters,
to allow comparison with arbitrary geometry detector datasets. TO-DO
Expand All @@ -1944,10 +1982,6 @@ def reconstruct(
**kwargs,
)

if worker_pool is not None:
if self._device == "gpu":
raise NotImplementedError()

sx, sy = self._grid_scan_shape
psi = xp.empty((sx, sy), dtype=xp.complex64)
probe_conj = xp.conj(self._fourier_probe)
Expand Down Expand Up @@ -1976,7 +2010,7 @@ def reconstruct(

# main loop

if worker_pool is None:
if num_jobs == 1:
for ind_x, ind_y in tqdmnd(
sx,
sy,
Expand All @@ -1999,6 +2033,16 @@ def reconstruct(
xp=xp,
)
else:
if self._device == "gpu":
raise NotImplementedError()

from mpire import WorkerPool, cpu_count
from threadpoolctl import threadpool_limits

num_jobs = num_jobs or cpu_count()
if threads_per_job is not None:
num_jobs = num_jobs // threads_per_job

map_inputs = [
{
"intensities_FFT": self._intensities_FFT[ind_x, ind_y],
Expand All @@ -2010,22 +2054,25 @@ def reconstruct(
]

def wrapper_function(**kwargs):
return self._reconstruct_single_frequency(
**kwargs,
Kx=Kx,
Ky=Ky,
probe=self._fourier_probe,
probe_conj=probe_conj,
probe_normalization=probe_normalization,
probe_kwargs=probe_kwargs,
trotter_sign=trotter_sign,
virtual_detector_masks=virtual_detector_masks,
xp=xp,
with threadpool_limits(limits=threads_per_job):
return self._reconstruct_single_frequency(
**kwargs,
Kx=Kx,
Ky=Ky,
probe=self._fourier_probe,
probe_conj=probe_conj,
probe_normalization=probe_normalization,
probe_kwargs=probe_kwargs,
trotter_sign=trotter_sign,
virtual_detector_masks=virtual_detector_masks,
xp=xp,
)

with WorkerPool(n_jobs=num_jobs) as pool:
flat_results = pool.map(
wrapper_function, map_inputs, progress_bar=progress_bar
)

flat_results = worker_pool.map(
wrapper_function, map_inputs, progress_bar=progress_bar
)
for (ind_x, ind_y), res in zip(np.ndindex((sx, sy)), flat_results):
psi[ind_x, ind_y] = res

Expand Down Expand Up @@ -2092,7 +2139,8 @@ def reconstruct(
self,
relative_wiener_epsilon,
trotter_sign="-",
worker_pool=None,
num_jobs=None,
threads_per_job=None,
virtual_detector_masks: Sequence[np.ndarray] = None,
progress_bar: bool = True,
polar_parameters: Mapping[str, float] = None,
Expand Down Expand Up @@ -2131,10 +2179,6 @@ def reconstruct(
**kwargs,
)

if worker_pool is not None:
if self._device == "gpu":
raise NotImplementedError()

sx, sy = self._grid_scan_shape
psi = xp.empty((sx, sy), dtype=xp.complex64)
wdd_probe_0 = xp.fft.ifft2(self._fourier_probe * self._fourier_probe.conj())
Expand All @@ -2160,7 +2204,7 @@ def reconstruct(

# main loop

if worker_pool is None:
if num_jobs == 1:
for ind_x, ind_y in tqdmnd(
sx,
sy,
Expand All @@ -2181,6 +2225,16 @@ def reconstruct(
xp=xp,
)
else:
if self._device == "gpu":
raise NotImplementedError()

from mpire import WorkerPool, cpu_count
from threadpoolctl import threadpool_limits

num_jobs = num_jobs or cpu_count()
if threads_per_job is not None:
num_jobs = num_jobs // threads_per_job

map_inputs = [
{
"intensities_FFT": self._intensities_FFT[ind_x, ind_y],
Expand All @@ -2192,20 +2246,23 @@ def reconstruct(
]

def wrapper_function(**kwargs):
return self._reconstruct_single_frequency(
**kwargs,
Kx=Kx,
Ky=Ky,
probe=self._fourier_probe,
epsilon=epsilon,
probe_kwargs=probe_kwargs,
trotter_sign=trotter_sign,
xp=xp,
with threadpool_limits(limits=threads_per_job):
return self._reconstruct_single_frequency(
**kwargs,
Kx=Kx,
Ky=Ky,
probe=self._fourier_probe,
epsilon=epsilon,
probe_kwargs=probe_kwargs,
trotter_sign=trotter_sign,
xp=xp,
)

with WorkerPool(n_jobs=num_jobs) as pool:
flat_results = pool.map(
wrapper_function, map_inputs, progress_bar=progress_bar
)

flat_results = worker_pool.map(
wrapper_function, map_inputs, progress_bar=progress_bar
)
for (ind_x, ind_y), res in zip(np.ndindex((sx, sy)), flat_results):
psi[ind_x, ind_y] = res

Expand Down

0 comments on commit 5eba668

Please sign in to comment.