Skip to content

Commit

Permalink
fixed one bug in reconstruct...
Browse files Browse the repository at this point in the history
  • Loading branch information
smribet committed Aug 31, 2024
1 parent 019c296 commit f9880ba
Showing 1 changed file with 23 additions and 11 deletions.
34 changes: 23 additions & 11 deletions py4DSTEM/tomography/tomography.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,16 +766,17 @@ def _reshape_4D_array_to_2D(self, data, qx0_fit=None, qy0_fit=None):
]
return diffraction_patterns_reshaped

def _reshape_2D_array_to_4D(self, data):
def _reshape_2D_array_to_4D(self, data, xy_shape = None):
"""
reshape ravelled diffraction 2D-data to 4D-data
Parameters
----------
data: np.ndarrray
2D datacube data to be reshapped
xy_shape: 2-tuple
if None, takes 6D object shape
Returns
--------
data_reshaped: np.ndarray
Expand All @@ -784,12 +785,20 @@ def _reshape_2D_array_to_4D(self, data):
"""
xp = self._xp

s = (
self._object_shape_6D[0],
self._object_shape_6D[1],
self._object_shape_6D[-1],
self._object_shape_6D[-1],
)
if xy_shape is None:
s = (
self._object_shape_6D[0],
self._object_shape_6D[1],
self._object_shape_6D[-1],
self._object_shape_6D[-1],
)
else:
s = (
xy_shape[0],
xy_shape[1],
self._object_shape_6D[-1],
self._object_shape_6D[-1],
)
a = xp.argsort(self._ind_diffraction_ravel[self._circular_mask_ravel])
i = xp.empty_like(a)
i[a] = xp.arange(a.size)
Expand Down Expand Up @@ -1018,6 +1027,8 @@ def _forward(
line_y_diff = xp.arange(-(s[-1] - 1) / 2, s[-1] / 2) * length / s[-1]
line_z_diff = line_y_diff * xp.tan(tilt) + (s[-1] - 1) / 2
line_y_diff += (s[-1] - 1) / 2
# line_y_diff = np.fft.fftfreq(s[-1], 1 / s[-1]) * xp.cos(tilt) + (s[-1]-1)/2
# line_z_diff = np.fft.fftfreq(s[-1], 1 / s[-1]) * xp.sin(tilt) + (s[-1]-1)/2

yF_diff = xp.floor(line_y_diff).astype("int")
zF_diff = xp.floor(line_z_diff).astype("int")
Expand Down Expand Up @@ -1243,7 +1254,7 @@ def _back(
),
4 * num_points,
axis=0,
) / (4 * num_points)
) / (num_points)

real_index = xp.ravel_multi_index(
(self._ind0.ravel(), self._ind1.ravel()), (s[1], s[2]), mode="clip"
Expand Down Expand Up @@ -1589,8 +1600,9 @@ def visualize(self, plot_convergence=True, figsize=(10, 10)):
)

ax = fig.add_subplot(spec[0, 1])
ind_diff = self._object_shape_6D[-1]//2
show(
self.object_6D.mean((0, 1, 2, 5)),
self.object_6D.mean((0, 1, 2))[:,:,ind_diff],
figax=(fig, ax),
cmap="magma",
title="diffraction space object",
Expand Down

0 comments on commit f9880ba

Please sign in to comment.