Skip to content

Commit

Permalink
Merge pull request #10 from fusion-energy/update_gui_to_latest_api
Browse files Browse the repository at this point in the history
Update GUI to make use of the latest api
  • Loading branch information
shimwell authored Sep 13, 2023
2 parents 77d97c9 + 048f495 commit 265c13b
Showing 1 changed file with 116 additions and 88 deletions.
204 changes: 116 additions & 88 deletions src/openmc_cylindrical_mesh_plotter/app.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import openmc
import streamlit as st
from matplotlib.colors import LogNorm
import openmc_cylindrical_mesh_plotter as cmp
import matplotlib.pyplot as plt
from openmc_cylindrical_mesh_plotter import (
plot_mesh_tally_rz_slice,
plot_mesh_tally_phir_slice,
)
import openmc


def save_uploadedfile(uploadedfile):
Expand All @@ -11,6 +13,40 @@ def save_uploadedfile(uploadedfile):
return st.success(f"Saved File to {uploadedfile.name}")


def get_tallies_with_cylindrical_mesh_filters(statepoint: openmc.StatePoint):
"""scans the statepoint object to find all tallies and with cylindrical mesh
filters, returns a list of tally indexes"""

matching_tally_ids = []
for _, tally in statepoint.tallies.items():
try:
mf = tally.find_filter(filter_type=openmc.MeshFilter)
if isinstance(mf.mesh, openmc.CylindricalMesh):
matching_tally_ids.append(tally.id)
print("found regmeshfilter")
except ValueError:
mf = None

return sorted(matching_tally_ids)


def get_cylindricalmesh_tallies_and_scores(statepoint: openmc.StatePoint):
"""scans the statepoint object to find all tallies and scores,
returns list of dictionaries. Each dictionary contains tally id,
score and tally name"""

tallies_of_interest = get_tallies_with_cylindrical_mesh_filters(statepoint)

tally_score_info = []
for tally_id in tallies_of_interest:
tally = statepoint.tallies[tally_id]
for score in tally.scores:
entry = {"id": tally.id, "score": score, "name": tally.name}
tally_score_info.append(entry)

return tally_score_info


def header():
"""This section writes out the page header common to all tabs"""

Expand Down Expand Up @@ -72,7 +108,7 @@ def main():
save_uploadedfile(statepoint_file)
statepoint = openmc.StatePoint(statepoint_file.name)

tally_description = cmp.get_cylindricalmesh_tallies_and_scores(statepoint)
tally_description = get_cylindricalmesh_tallies_and_scores(statepoint)
tally_description_str = [
f"ID={td['id']} score={td['score']} name={td['name']}"
for td in tally_description
Expand All @@ -82,65 +118,69 @@ def main():
label="Tally to plot", options=tally_description_str, index=0
)
tally_id_to_plot = tally_description_to_plot.split(" ")[0][3:]
tally_score_to_plot = tally_description_to_plot.split(" ")[1][6:]
score = tally_description_to_plot.split(" ")[1][6:]

view_direction = st.sidebar.selectbox(
basis = st.sidebar.selectbox(
label="view direction",
options=("PhiR", "RZ"),
index=0,
key="axis",
help="",
)

tally_or_std = st.sidebar.radio(
"Tally mean or std dev", options=["mean", "std_dev"]
value = st.sidebar.radio("Tally mean or std dev", options=["mean", "std_dev"])

axis_units = st.sidebar.selectbox(
"Axis units", ["km", "m", "cm", "mm"], index=2
)

volume_normalization = st.sidebar.radio(
"Divide value by mesh voxel volume", options=[True, False]
)

value_multiple = st.sidebar.number_input(
"Multiplier value",
scaling_factor = st.sidebar.number_input(
"Scaling factor",
value=1.0,
help="Input a number that will be used to scale the mesh values. For example a input of 2 would double all the values.",
)

my_tally = statepoint.get_tally(id=int(tally_id_to_plot))
score = my_tally.get_values(
scores=[tally_score_to_plot], value=tally_or_std
).flatten()
mesh = my_tally.find_filter(filter_type=openmc.MeshFilter).mesh
extent = mesh.get_mpl_plot_extent(view_direction=view_direction)

slice_index = st.sidebar.slider(
label="slice index",
min_value=1,
value=1,
max_value=mesh.get_number_of_slices(view_direction=view_direction),
)
tally = statepoint.get_tally(id=int(tally_id_to_plot))

contour_levels_str = st.sidebar.text_input(
"Contour levels",
help="Optionally add some comma deliminated contour values",
)
mesh = tally.find_filter(filter_type=openmc.MeshFilter).mesh

if contour_levels_str:
contour_levels = sorted(
[float(v) for v in contour_levels_str.strip().split(",")]
)
if basis == "RZ":
max_value = int(tally.shape[1] / 2) # index 1 is the phi value
if basis == "PhiR":
max_value = int(tally.shape[2] / 2) # index 2 is the z value

if max_value == 0:
slice_index = 0
else:
contour_levels = None
slice_index = st.sidebar.slider(
label="slice index",
min_value=0,
value=int(max_value / 2),
max_value=max_value,
)

# contour_levels_str = st.sidebar.text_input(
# "Contour levels",
# help="Optionally add some comma deliminated contour values",
# )

if tally_or_std == "mean":
cbar_label = tally_score_to_plot
else: # 'std dev'
cbar_label = f"standard deviation {tally_score_to_plot}"
# if contour_levels_str:
# contour_levels = sorted(
# [float(v) for v in contour_levels_str.strip().split(",")]
# )
# else:
# contour_levels = None

colorbar = st.sidebar.radio("Include colorbar", options=[True, False])

title = st.sidebar.text_input(
"Colorbar title",
help="Optionally set your own colorbar label for the plot",
value=cbar_label,
value="colorbar title",
)

log_lin_scale = st.sidebar.radio("Scale", options=["log", "linear"])
Expand All @@ -149,68 +189,56 @@ def main():
else:
norm = LogNorm()

xlabel, ylabel = mesh.get_axis_labels(view_direction=view_direction)

if view_direction == "RZ":
image_slice = mesh.slice_of_data(
dataset=score, # ,
view_direction=view_direction,
if basis == "RZ":
plot = plot_mesh_tally_rz_slice(
tally=tally,
slice_index=slice_index,
score=score,
axes=None,
axis_units=axis_units,
value=value,
# outline: bool = False,
# outline_by: str = "cell",
# geometry: Optional["openmc.Geometry"] = None,
# geometry_basis: str = "xz",
# pixels: int = 40000,
colorbar=colorbar,
volume_normalization=volume_normalization,
scaling_factor=scaling_factor,
colorbar_kwargs={"label": title},
norm=norm
# outline_kwargs: dict = _default_outline_kwargs,
# **kwargs,
)
image_slice = image_slice * value_multiple

axes = plt.subplot(1, 1, 1)
(xlabel, ylabel) = mesh.get_axis_labels(view_direction=view_direction)

axes.set_xlabel(xlabel)
axes.set_ylabel(ylabel)
im = axes.imshow(X=image_slice, extent=extent, norm=norm)

if contour_levels_str:
axes.contour(
image_slice,
levels=contour_levels,
colors="black",
linewidths=1,
extent=extent,
)
else:
extent = None # not yet figured out
theta, r, values = mesh.slice_of_data(
dataset=score, # ,
view_direction=view_direction,
slice_index=slice_index,
volume_normalization=volume_normalization,
elif basis == "PhiR":
plot = plot_mesh_tally_phir_slice(
tally=tally, # "openmc.Tally",
slice_index=slice_index, # Optional[int] = None,
score=score, # Optional[str] = None,
# axes,# Optional[str] = None,
axis_units=axis_units, # str = "cm",
value=value, # str = "mean",
# outline,# bool = False,
# outline_by,# str = "cell",
# geometry,# Optional["openmc.Geometry"] = None,
# pixels,# int = 40000,
colorbar=colorbar, # bool = True,
volume_normalization=volume_normalization, # bool = True,
scaling_factor=scaling_factor, # Optional[float] = None,
colorbar_kwargs={"label": title},
norm=norm,
# outline_kwargs,# d
)
values = values * value_multiple
fig, axes = plt.subplots(subplot_kw=dict(projection="polar"))
im = axes.contourf(
theta, r, values, extent=extent, norm=norm
) # , locator=ticker.LogLocator())

if contour_levels_str:
axes.contour(
theta,
r,
values,
levels=contour_levels,
colors="darkgrey",
linewidths=1,
extent=extent,
)

plt.colorbar(im, label=title, ax=axes)

plt.savefig("openmc_plot_cylindricalmesh_image.png")

plot.figure.savefig("openmc_plot_cylindricalmesh_image.png")
st.pyplot(plot.figure)
with open("openmc_plot_cylindricalmesh_image.png", "rb") as file:
st.download_button(
label="Download image",
data=file,
file_name="openmc_plot_cylindricalmesh_image.png",
mime="image/png",
)
st.pyplot(plt)


if __name__ == "__main__":
Expand Down

0 comments on commit 265c13b

Please sign in to comment.