Skip to content

Commit

Permalink
slim_preprocess bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
gvarnavi committed Nov 14, 2024
1 parent 91d3bb5 commit 6467b16
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 18 deletions.
2 changes: 1 addition & 1 deletion py4DSTEM/process/phase/ptychographic_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -1609,7 +1609,7 @@ def slim_preprocess(
self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe))

self._object_initial = self._object.copy()
self._object_initial_type = self._object_type
self._object_type_initial = self._object_type

self._positions_initial = self.positions

Expand Down
25 changes: 8 additions & 17 deletions py4DSTEM/process/phase/ptychographic_tomography.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,6 @@ def slim_preprocess(
list_of_probe_arrays,
object_array,
list_of_positions_px,
main_tilt_axis: str = "vertical",
reciprocal_sampling=None,
angular_sampling=None,
store_initial_arrays=True,
Expand Down Expand Up @@ -271,15 +270,16 @@ def slim_preprocess(
num_probes_per_measurement = [0] + [amp.shape[0] for amp in list_of_amplitudes]
num_probes_per_measurement = np.array(num_probes_per_measurement)

self._mean_diffraction_intensity = []
self._probes_all = []
self._mean_diffraction_intensity = []
self._num_diffraction_patterns = num_probes_per_measurement.sum()

self._cum_probes_per_measurement = np.cumsum(num_probes_per_measurement)
self._positions_px_all = np.empty((self._num_diffraction_patterns, 2))

self._amplitudes_shape = np.array(list_of_amplitudes[0][0].shape)
self._amplitudes = xp_storage.empty(
(self._num_diffraction_patterns,) + self._amplitudes_shape
(self._num_diffraction_patterns,) + tuple(self._amplitudes_shape)
)
self._object = xp.asarray(object_array, dtype=xp.float32)

Expand All @@ -303,10 +303,10 @@ def slim_preprocess(
self._positions_px_all[idx_start:idx_end] = pos
self._mean_diffraction_intensity.append((amps**2).sum((-1, -2)).mean(0))

self._probes_all.append(probe)
self._probes_all.append(probe.copy())
if store_initial_arrays:
self._probe_initial = probe.copy()
self._probe_initial_aperture = xp.abs(xp.fft.fft2(probe))
self._probes_all_initial.append(probe.copy())
self._probes_all_initial_aperture.append(xp.abs(xp.fft.fft2(probe)))

# specify sampling
if angular_sampling is None and reciprocal_sampling is None:
Expand Down Expand Up @@ -339,16 +339,7 @@ def slim_preprocess(
self._num_voxels = self._object.shape[0]
self._object_fov_mask_inverse = np.full(self._object_shape, False)

# Precomputed propagator arrays
if main_tilt_axis == "vertical":
thickness = self._object_shape[1] * self.sampling[1]
elif main_tilt_axis == "horizontal":
thickness = self._object_shape[0] * self.sampling[0]
else:
thickness_h = self._object_shape[1] * self.sampling[1]
thickness_v = self._object_shape[0] * self.sampling[0]
thickness = max(thickness_h, thickness_v)

thickness = self._num_voxels * self.sampling[0]
self._slice_thicknesses = np.tile(
thickness / self._num_slices, self._num_slices - 1
)
Expand All @@ -367,7 +358,7 @@ def slim_preprocess(
# necessary restarting attributes
if store_initial_arrays:
self._object_initial = self._object.copy()
self._object_initial_type = self._object_type
self._object_type_initial = self._object_type

self._positions_px_initial_all = self._positions_px_all.copy()
self._positions_initial_all = self._positions_px_initial_all.copy()
Expand Down

0 comments on commit 6467b16

Please sign in to comment.