Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Callback function for collecting additional data from Monte Carlo sims #702

Merged
merged 15 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 20 additions & 13 deletions docs/notebooks/monte_carlo_analysis/monte_carlo_class_usage.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1129,7 +1129,7 @@
"source": [
"We have shown, so far, how to perform to use the `MonteCarlo` class and visualize its results. By default, some variables exported to the output files, such as *apogee* and *x_impact*. The `export_list` argument provides a simplified way for the user to export additional variables listed in the documentation, such as *inclination* and *heading*. \n",
"\n",
"There are applications in which you might need to extract more information in the results than the `export_list` argument can handle. To that end, the `MonteCarlo` class has a `export_function` argument which allows you customize further the output of the simulation.\n",
"There are applications in which you might need to extract more information in the results than the `export_list` argument can handle. To that end, the `MonteCarlo` class has a `data_collector` argument which allows you customize further the output of the simulation.\n",
"\n",
"To exemplify its use, we show how to export the *date* of the environment used in the simulation together with the *average reynolds number* along with the default variables."
]
Expand All @@ -1138,36 +1138,42 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"We will use the `stochastic_env`, `stochastic_rocket` and `stochastic_flight` objects previously defined, and only change the `MonteCarlo` object. First, we need to define our customized export function."
"We will use the `stochastic_env`, `stochastic_rocket` and `stochastic_flight` objects previously defined, and only change the `MonteCarlo` object. First, we need to define our customized data collector."
]
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def custom_export_function(flight):\n",
"# Defining custom callback functions\n",
"def get_average_reynolds_number(flight):\n",
" reynold_number_list = flight.reynolds_number(flight.time)\n",
" average_reynolds_number = np.mean(reynold_number_list)\n",
" custom_exports = {\n",
" \"average_reynolds_number\": average_reynolds_number,\n",
" \"date\": flight.env.date,\n",
" }\n",
" return custom_exports"
" return average_reynolds_number\n",
"\n",
"\n",
"def get_date(flight):\n",
" return flight.env.date\n",
"\n",
"\n",
"custom_data_collector = {\n",
" \"average_reynolds_number\": get_average_reynolds_number,\n",
" \"date\": get_date,\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `data_collector` must be a dictionary whose keys are the names of the variables we want to export and the values are callback functions (python callables) that compute these variable values. Notice how we can compute complex expressions in this function and just export the result. For instance, the *get_average_reynolds_number* calls the `flight.reynolds_number` method for each value in `flight.time` list and computes the average value using numpy's `mean`. The *date* variable is straightforward.\n",
"\n",
"The `export_function` must be a function which takes a `Flight` object and outputs a dictionary whose keys are variables names to export and their values. Notice how we can compute complex expressions in this function and just export the result. For instance, the *average_reynolds_number* calls the `flight.reynolds_number` method for each value in `flight.time` list and computes the average value using numpy's `mean`. The *date* variable is straightforward.\n",
"\n",
"After we define the export function, we pass it as an argument to the `MonteCarlo` class."
"After we define the data collector, we pass it as an argument to the `MonteCarlo` class."
]
},
{
Expand All @@ -1181,7 +1187,8 @@
" environment=stochastic_env,\n",
" rocket=stochastic_rocket,\n",
" flight=stochastic_flight,\n",
" export_function=custom_export_function,\n",
" export_list=[\"apogee\", \"apogee_time\", \"x_impact\"],\n",
" data_collector=custom_data_collector,\n",
")"
]
},
Expand Down
83 changes: 58 additions & 25 deletions rocketpy/simulation/monte_carlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ class MonteCarlo:
The stochastic flight object to be iterated over.
export_list : list
The list of variables to export at each simulation.
data_collector : dict
A dictionary whose keys are the names of the additional
exported variables and the values are callback functions.
inputs_log : list
List of dictionaries with the inputs used in each simulation.
outputs_log : list
Expand Down Expand Up @@ -86,7 +89,7 @@ def __init__(
rocket,
flight,
export_list=None,
export_function=None,
data_collector=None,
): # pylint: disable=too-many-statements
"""
Initialize a MonteCarlo object.
Expand All @@ -110,11 +113,17 @@ def __init__(
`out_of_rail_stability_margin`, `out_of_rail_time`,
`out_of_rail_velocity`, `max_mach_number`, `frontal_surface_wind`,
`lateral_surface_wind`. Default is None.
export_function : callable, optional
A function which gets called at the end of a simulation to collect
additional data to be exported that isn't pre-defined. Takes the
Flight object as an argument and returns a dictionary. Default is None.
data_collector : dict, optional
A dictionary whose keys are the names of the exported variables
and the values are callback functions. The keys (variable names) must not
overwrite the default names on 'export_list'. The callback functions receive
a Flight object and returns a value of that variable. For instance

.. code-block:: python
custom_data_collector = {
"max_acceleration": lambda flight: max(flight.acceleration(flight.time)),
"date": lambda flight: flight.env.date,
}

Returns
-------
Expand Down Expand Up @@ -143,7 +152,8 @@ def __init__(
self._last_print_len = 0 # used to print on the same line

self.export_list = self.__check_export_list(export_list)
self.export_function = export_function
self._check_data_collector(data_collector)
phmbressan marked this conversation as resolved.
Show resolved Hide resolved
self.data_collector = data_collector

try:
self.import_inputs()
Expand Down Expand Up @@ -371,20 +381,15 @@ def __export_flight_data(
for export_item in self.export_list
}

if self.export_function is not None:
try:
additional_exports = self.export_function(flight)
except Exception as e:
raise ValueError(
"An error was encountered running your custom export function. "
"Check for errors in 'export_function' definition."
) from e

for key in additional_exports.keys():
if key in self.export_list:
if self.data_collector is not None:
additional_exports = {}
for key, callback in self.data_collector.items():
try:
additional_exports[key] = callback(flight)
except Exception as e:
raise ValueError(
f"Invalid export function, returns dict which overwrites key, '{key}'"
)
f"An error was encountered running 'data_collector' callback {key}. "
) from e
results = results | additional_exports

input_file.write(json.dumps(inputs_dict, cls=RocketPyEncoder) + "\n")
Expand Down Expand Up @@ -494,6 +499,37 @@ def __check_export_list(self, export_list):

return export_list

def _check_data_collector(self, data_collector):
"""Check if data collector provided is a valid

Parameters
----------
data_collector : dict
A dictionary whose keys are the names of the exported variables
and the values are callback functions that receive a Flight object
and returns a value of that variable
"""

if data_collector is not None:

if not isinstance(data_collector, dict):
raise ValueError(
"Invalid 'data_collector' argument! "
"It must be a dict of callback functions."
)

for key, callback in data_collector.items():
if key in self.export_list:
raise ValueError(
"Invalid 'data_collector' key! "
f"Variable names overwrites 'export_list' key '{key}'."
)
if not callable(callback):
raise ValueError(
f"Invalid value in 'data_collector' for key '{key}'! "
"Values must be python callables (callback functions)."
)

def __reprint(self, msg, end="\n", flush=False):
"""
Prints a message on the same line as the previous one and replaces the
Expand Down Expand Up @@ -683,12 +719,9 @@ def set_processed_results(self):
self.processed_results = {}
for result, values in self.results.items():
try:
if isinstance(values[0], float):
mean = np.mean(values)
stdev = np.std(values)
self.processed_results[result] = (mean, stdev)
else:
self.processed_results[result] = (None, None)
mean = np.mean(values)
stdev = np.std(values)
self.processed_results[result] = (mean, stdev)
except TypeError:
self.processed_results[result] = (None, None)

Expand Down
52 changes: 27 additions & 25 deletions tests/integration/test_monte_carlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,49 +115,51 @@ def test_monte_carlo_export_ellipses_to_kml(monte_carlo_calisto_pre_loaded):

@pytest.mark.slow
def test_monte_carlo_callback(monte_carlo_calisto):
"""Tests the export_function argument of the MonteCarlo class.
"""Tests the data_collector argument of the MonteCarlo class.

Parameters
----------
monte_carlo_calisto : MonteCarlo
The MonteCarlo object, this is a pytest fixture.
"""

def valid_export_function(flight):
custom_export_dict = {
"name": flight.name,
"density_t0": flight.env.density(0),
}
return custom_export_dict
# define valid data collector
valid_data_collector = {
"name": lambda flight: flight.name,
"density_t0": lambda flight: flight.env.density(0),
}

monte_carlo_calisto.export_function = valid_export_function
monte_carlo_calisto.data_collector = valid_data_collector
# NOTE: this is really slow, it runs 10 flight simulations
monte_carlo_calisto.simulate(number_of_simulations=10, append=False)

# tests if print works when we have None in summary
monte_carlo_calisto.info()

# tests if logical errors in export functions raise errors
def export_function_with_logical_error(flight):
custom_export_dict = {
"date": flight.env.date,
"density_t0": flight.env.density(0) / "0",
}
return custom_export_dict
## tests if an error is raised for invalid data_collector definitions
# invalid type
def invalid_data_collector(flight):
return flight.name

monte_carlo_calisto.export_function = export_function_with_logical_error
# NOTE: this is really slow, it runs 10 flight simulations
with pytest.raises(ValueError):
monte_carlo_calisto.simulate(number_of_simulations=10, append=False)
monte_carlo_calisto._check_data_collector(invalid_data_collector)

# invalid key overwrite
invalid_data_collector = {"apogee": lambda flight: flight.apogee}
with pytest.raises(ValueError):
monte_carlo_calisto._check_data_collector(invalid_data_collector)

# tests if overwriting default exports raises errors
def export_function_with_overwriting_error(flight):
custom_export_dict = {
"apogee": flight.apogee,
}
return custom_export_dict
# invalid callback definition
invalid_data_collector = {"name": "Calisto"} # callbacks must be callables!
with pytest.raises(ValueError):
monte_carlo_calisto._check_data_collector(invalid_data_collector)

monte_carlo_calisto.export_function = export_function_with_overwriting_error
# invalid logic (division by zero)
invalid_data_collector = {
"density_t0": lambda flight: flight.env.density(0) / "0",
}
monte_carlo_calisto.data_collector = invalid_data_collector
# NOTE: this is really slow, it runs 10 flight simulations
with pytest.raises(ValueError):
monte_carlo_calisto.simulate(number_of_simulations=10, append=False)

Expand Down
Loading