Skip to content

Commit

Permalink
squeezing end of tally shape only
Browse files Browse the repository at this point in the history
  • Loading branch information
shimwell committed Nov 24, 2023
1 parent 138d70b commit 1f94d0e
Showing 1 changed file with 21 additions and 26 deletions.
47 changes: 21 additions & 26 deletions src/openmc_regular_mesh_plotter/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@

_default_outline_kwargs = {"colors": "black", "linestyles": "solid", "linewidths": 1}

def _squeeze_end_of_array(array, dims_required=3):
while len(array.shape) > dims_required:
array = np.squeeze(array, axis=len(array.shape)-1)
return array

def plot_mesh_tally(
tally: "openmc.Tally",
Expand Down Expand Up @@ -112,19 +116,29 @@ def plot_mesh_tally(

tally_slice = tally.get_slice(scores=[score])

# if mesh.n_dimension == 3:
basis_to_index = {"xy": 2, "xz": 1, "yz": 0}[basis]

print('mesh.dimension', mesh.dimension)
if 1 in mesh.dimension:
index_of_2d = mesh.dimension.index(1)
axis_of_2d = {0: "x", 1: "y", 2: "z"}[index_of_2d]
if axis_of_2d in basis: # checks if the axis is being plotted, e.g is 'x' in 'xy'
raise ValueError(
"The selected tally has a mesh that has 1 dimension in the "
f"{axis_of_2d} axis, minimum of 2 needed to plot with a basis "
f"of {basis}."
)

# todo check if 1 appears twice or three times, raise value error if so
# TODO check if 1 appears twice or three times, raise value error if so

tally_data = tally_slice.get_reshaped_data(expand_dims=True, value=value).squeeze()
tally_data = tally_slice.get_reshaped_data(expand_dims=True, value=value)#.squeeze()

basis_to_index = {"xy": 2, "xz": 1, "yz": 0}[basis]
if len(tally_data.shape) == 3:
tally_data = _squeeze_end_of_array(tally_data, dims_required=3)

# if len(tally_data.shape) == 3:
if mesh.n_dimension == 3:
if slice_index is None:
# finds the mid index
slice_index = int(tally_data.shape[basis_to_index] / 2)

if basis == "xz":
Expand All @@ -137,31 +151,12 @@ def plot_mesh_tally(
xlabel, ylabel = f"y [{axis_units}]", f"z [{axis_units}]"
else: # basis == 'xy'
slice_data = tally_data[:, :, slice_index]
print('shape slice_data', slice_data.shape)
data = np.rot90(slice_data, -3)
xlabel, ylabel = f"x [{axis_units}]", f"y [{axis_units}]"
# elif mesh.n_dimension == 2:
elif len(tally_data.shape) == 2:
if basis_to_index == index_of_2d:
slice_data = tally_data[:, :]
if basis == "xz":
data = np.flip(np.rot90(slice_data, -1))
xlabel, ylabel = f"x [{axis_units}]", f"z [{axis_units}]"
elif basis == "yz":
data = np.flip(np.rot90(slice_data, -1))
xlabel, ylabel = f"y [{axis_units}]", f"z [{axis_units}]"
else: # basis == 'xy'
data = np.rot90(slice_data, -3)
xlabel, ylabel = f"x [{axis_units}]", f"y [{axis_units}]"

else:
raise ValueError(
"The selected tally has a mesh that has 1 dimension in the "
f"{axis_of_2d} axis, minimum of 2 needed to plot with a basis "
f"of {basis}."
)

else:
raise ValueError("mesh n_dimension")
raise ValueError(f"mesh n_dimension is not 3 or 2 but is {mesh.n_dimension} which is not supported")

if volume_normalization:
# in a regular mesh all volumes are the same so we just divide by the first
Expand Down

0 comments on commit 1f94d0e

Please sign in to comment.