Skip to content

Commit

Permalink
first pass at mpire
Browse files Browse the repository at this point in the history
  • Loading branch information
smribet committed Dec 7, 2024
1 parent 8989b29 commit b1f6141
Showing 1 changed file with 119 additions and 22 deletions.
141 changes: 119 additions & 22 deletions py4DSTEM/tomography/tomography.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,28 @@
from typing import Sequence, Tuple, Union

import numpy as np

import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

from emdfile import tqdmnd

from py4DSTEM import show
import numpy as np
from py4DSTEM.datacube import DataCube
from py4DSTEM.preprocess.utils import bin2D
from py4DSTEM.process.calibration import fit_origin, get_origin
from py4DSTEM.process.diffraction import Crystal
from py4DSTEM.process.phase.utils import copy_to_device
from py4DSTEM.utils import fourier_resample
from py4DSTEM.visualize import return_scaled_histogram_ordering

from scipy.ndimage import rotate, zoom
from scipy.spatial.transform import Rotation as R

from mpire import WorkerPool, cpu_count
from threadpoolctl import threadpool_limits


try:
import cupy as cp

Expand Down Expand Up @@ -291,6 +297,9 @@ def reconstruct(
zero_edges: bool = True,
baseline_thresh: float = 0.9,
diffraction_gaussian_filter: float = 0,
distributed=False,
num_jobs=None,
threads_per_job=1,
):
"""
Main loop for reconstruct
Expand Down Expand Up @@ -342,28 +351,79 @@ def reconstruct(
self._diffraction_patterns_projected[a1_shuffle], device
)

for a2 in range(self._object_shape_6D[0]):
object_sliced = self._forward(
x_index=a2,
tilt_deg=self._tilt_deg[a1_shuffle],
num_points=num_points,
)
if distributed is False:
for a2 in range(self._object_shape_6D[0]):
x_index, yy, zz, update_r_summed = self._reconstruct(
a2=a2,
a1_shuffle=a1_shuffle,
num_points=num_points,
diffraction_patterns_projected=diffraction_patterns_projected,
error_iteration=error_iteration,
step_size=step_size,
)

update, error = self._calculate_update(
object_sliced=object_sliced,
diffraction_patterns_projected=diffraction_patterns_projected,
x_index=a2,
datacube_number=a1_shuffle,
)
self._object[x_index, yy, zz] += update_r_summed

error_iteration += error
elif distributed is True and self._device == "cpu":
# obj = self._object
num_jobs = num_jobs or cpu_count() // threads_per_job

update *= step_size
self._back(
num_points=num_points,
x_index=a2,
update=update,
)
def f(args):
with threadpool_limits(limits=threads_per_job):
return self._reconstruct(**args)

# hopefully the data entries remain as views until dispatch time...
inputs = [
(
{
"a2": a2,
"a1_shuffle": a1_shuffle,
"num_points": num_points,
"diffraction_patterns_projected": diffraction_patterns_projected,
"error_iteration": error_iteration,
"step_size": step_size,
},
)
for a2 in range(self._object_shape_6D[0])
]

with WorkerPool(
n_jobs=num_jobs,
) as pool:
results = pool.map(
f,
inputs,
progress_bar=False,
)

for a2 in range(self._object_shape_6D[0]):
x_index, yy, zz, update_r_summed = results[a2]
self._object[x_index, yy, zz] += update_r_summed

else:
raise ValueError(("distributed not implemented for put"))

# object_sliced = self._forward(
# x_index=a2,
# tilt_deg=self._tilt_deg[a1_shuffle],
# num_points=num_points,
# )

# update, error = self._calculate_update(
# object_sliced=object_sliced,
# diffraction_patterns_projected=diffraction_patterns_projected,
# x_index=a2,
# datacube_number=a1_shuffle,
# )

# error_iteration += error

# update *= step_size
# self._back(
# num_points=num_points,
# x_index=a2,
# update=update,
# )

self._constraints(
zero_edges=zero_edges,
Expand All @@ -378,6 +438,40 @@ def reconstruct(

return self

def _reconstruct(
self,
a2,
a1_shuffle,
num_points,
diffraction_patterns_projected,
error_iteration,
step_size,
):
object_sliced = self._forward(
x_index=a2,
tilt_deg=self._tilt_deg[a1_shuffle],
num_points=num_points,
)

update, error = self._calculate_update(
object_sliced=object_sliced,
diffraction_patterns_projected=diffraction_patterns_projected,
x_index=a2,
datacube_number=a1_shuffle,
)

error_iteration += error

update *= step_size
(x_index, yy, zz, update_r_summed) = self._back(
num_points=num_points,
x_index=a2,
update=update,
)

return x_index, yy, zz, update_r_summed
# obj[x_index, yy, zz] += update_r_summed

def _prepare_datacube(
self,
datacube_number,
Expand Down Expand Up @@ -1395,7 +1489,8 @@ def _back(
yy = copy_to_device(yy, storage)
zz = copy_to_device(zz, storage)

self._object[x_index, yy, zz] += copy_to_device(update_r_summed, storage)
# self._object[x_index, yy, zz] += copy_to_device(update_r_summed, storage)
return x_index, yy, zz, copy_to_device(update_r_summed, storage)

def _constraints(
self,
Expand Down Expand Up @@ -1446,7 +1541,9 @@ def _constraints(
s = self._object.shape

obj_6D = self.object_6D
obj_6D = gaussian_filter(obj_6D, diffraction_gaussian_filter, axes = (-1,-2,-3))
obj_6D = gaussian_filter(
obj_6D, diffraction_gaussian_filter, axes=(-1, -2, -3)
)

self._object = obj_6D.reshape(s)

Expand Down

0 comments on commit b1f6141

Please sign in to comment.