diff --git a/src/openmc_source_plotter/core.py b/src/openmc_source_plotter/core.py index 34e7f83..6392b26 100644 --- a/src/openmc_source_plotter/core.py +++ b/src/openmc_source_plotter/core.py @@ -37,7 +37,7 @@ def sample_initial_particles(self, n_samples: int = 1000, prn_seed: int = None): materials = openmc.Materials() model.materials = materials - sph = openmc.Sphere(r=9999999999, boundary_type="vacuum") + sph = openmc.Sphere(r=99999999999, boundary_type="vacuum") cell = openmc.Cell(region=-sph) geometry = openmc.Geometry([cell]) model.geometry = geometry @@ -72,6 +72,9 @@ def plot_source_energy( prn_seed: int = 1, energy_bins: typing.Union[str, np.array] = "auto", name: typing.Optional[str] = None, + yaxis_type: str = "linear", + xaxis_type: str = "linear", + xaxis_units: str = "MeV", ): """makes a plot of the initial creation positions of an OpenMC source @@ -88,14 +91,22 @@ def plot_source_energy( Numpy bins can also be manually set by passing in a numpy array of bin edges. name: the legend name to use + yaxis_type: The type (scale) to use for the Y axis. Options are 'log' + or 'linear. + xaxis_type: The type (scale) to use for the Y axis. Options are 'log' + or 'linear. + xaxis_units: The units to use for the x axis. Options are 'eV' or 'MeV'. """ + if xaxis_units not in ["eV", "MeV"]: + raise ValueError(f"xaxis_units must be either 'eV' or 'MeV' not {xaxis_units}") + if figure is None: figure = plotly.graph_objects.Figure() figure.update_layout( title="Particle energy", - xaxis={"title": "Energy (eV)"}, - yaxis={"title": "Probability"}, + xaxis={"title": f"Energy [{xaxis_units}]", "type": xaxis_type}, + yaxis={"title": "Probability", "type": yaxis_type}, showlegend=True, ) @@ -109,11 +120,13 @@ def plot_source_energy( # scaling by strength if isinstance(self, openmc.SourceBase): probability = probability * self.strength - + energy = bin_edges[:-1] + if xaxis_units == "MeV": + energy = energy / 1e6 # Plot source energy histogram figure.add_trace( plotly.graph_objects.Scatter( - x=bin_edges[:-1], + x=energy, y=probability * np.diff(bin_edges), line={"shape": "hv"}, hoverinfo="text", diff --git a/tests/test_core_with_source.py b/tests/test_core_with_source.py index 56f75e6..ed54bf0 100644 --- a/tests/test_core_with_source.py +++ b/tests/test_core_with_source.py @@ -48,6 +48,17 @@ def test_energy_plot(test_source): assert len(plot.data[0]["x"]) == 1 +def test_energy_plot_axis(test_source): + plot = test_source.plot_source_energy( + n_samples=10, xaxis_type="log", yaxis_type="linear", xaxis_units="eV" + ) + plot = test_source.plot_source_energy( + n_samples=10, xaxis_type="linear", yaxis_type="log", xaxis_units="MeV" + ) + assert isinstance(plot, go.Figure) + assert len(plot.data[0]["x"]) == 1 + + def test_position_plot(test_source): plot = test_source.plot_source_position(n_samples=10) assert isinstance(plot, go.Figure)