Skip to content

Commit

Permalink
Move config.py and config_slice.py to config/runtime_params.py
Browse files Browse the repository at this point in the history
…and `config/runtime_params_slice.py`.

We are choosing a new naming scheme for the input parameters to TORAX. Some of the last PRs have also moved the repo in this direction:

 - RuntimeParams refer to the various runtime, JAX-compatible (read: PyTree) input parameters to the various functions/modules in TORAX.

 - RuntimeParamsSlice refers to a "slice" of those params at a single point of time. (It slices both the collection of params and slices in the time-dimension. Double slice!)

 - Config will refer to any setup code that configures the entire TORAX simulation run. I.e. the term "config" will refer to code that configures the `sim.Sim` object.

PiperOrigin-RevId: 628424832
  • Loading branch information
araju authored and Torax team committed Apr 26, 2024
1 parent bf569d6 commit c130b98
Show file tree
Hide file tree
Showing 111 changed files with 2,535 additions and 2,073 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ python3 run_simulation_main.py \

Once complete, the time history of a simulation state and derived quantities is written to `state_history.nc`. The output path is written to stdout

To take advantage of the in-memory (non-persistent) cache, the process does not end upon simulation termination. It is possible to modify the config, toggle the `log_progress` and `plot_progress` flags, and rerun the simulation. Only the following modifications will then trigger a recompilation:
To take advantage of the in-memory (non-persistent) cache, the process does not end upon simulation termination. It is possible to modify the runtime_params, toggle the `log_progress` and `plot_progress` flags, and rerun the simulation. Only the following modifications will then trigger a recompilation:

- Grid resolution
- Evolved variables (equations being solved)
Expand Down
30 changes: 16 additions & 14 deletions run_simulation_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def maybe_update_config_module(
def change_config(
sim: torax.Sim,
config_module_str: str,
) -> tuple[torax.Sim, torax.Config, str]:
) -> tuple[torax.Sim, torax.GeneralRuntimeParams, str]:
"""Returns a new Sim with the updated config but same SimulationStepFn.
This function gives the user a chance to reuse the SimulationStepFn without
Expand All @@ -193,13 +193,13 @@ def change_config(
config_module_str = maybe_update_config_module(config_module_str)
simulation_app.log_to_stdout(
f'Change {config_module_str} to include new values. Only changes to '
'get_config() will be picked up.',
'get_runtime_params() will be picked up.',
color=simulation_app.AnsiColors.BLUE,
)
input('Press Enter when ready.')
config_module, _ = _get_config_module(config_module_str)
new_config = config_module.get_config()
new_geo = config_module.get_geometry(new_config)
new_runtime_params = config_module.get_runtime_params()
new_geo = config_module.get_geometry(new_runtime_params)
new_transport_model = config_module.get_transport_model()
source_models = config_module.get_sources()
new_source_params = {
Expand All @@ -220,18 +220,18 @@ def change_config(
)
sim = simulation_app.update_sim(
sim=sim,
config=new_config,
runtime_params=new_runtime_params,
geo=new_geo,
transport_runtime_params=new_transport_model.runtime_params,
source_runtime_params=new_source_params,
stepper_runtime_params_getter=stepper_params_getter,
)
return sim, new_config, config_module_str
return sim, new_runtime_params, config_module_str


def change_sim_obj(
config_module_str: str,
) -> tuple[torax.Sim, torax.Config, str]:
) -> tuple[torax.Sim, torax.GeneralRuntimeParams, str]:
"""Builds a new Sim from the config module.
Unlike change_config(), this function builds a brand new Sim object with a
Expand All @@ -256,9 +256,9 @@ def change_sim_obj(
)
input('Press Enter when done changing the module.')
config_module, _ = _get_config_module(config_module_str)
new_config = config_module.get_config()
new_runtime_params = config_module.get_runtime_params()
sim = config_module.get_sim()
return sim, new_config, config_module_str
return sim, new_runtime_params, config_module_str


def _toggle_log_progress(log_sim_progress: bool) -> bool:
Expand Down Expand Up @@ -323,14 +323,14 @@ def _toggle_log_output(log_sim_output: bool) -> bool:

def main(_):
config_module, config_module_str = _get_config_module()
new_config = config_module.get_config()
new_runtime_params = config_module.get_runtime_params()
sim = config_module.get_sim()
log_sim_progress = _LOG_SIM_PROGRESS.value
plot_sim_progress = _PLOT_SIM_PROGRESS.value
log_sim_output = _LOG_SIM_OUTPUT.value
simulation_app.main(
lambda: sim,
output_dir=new_config.output_dir,
output_dir=new_runtime_params.output_dir,
log_sim_progress=log_sim_progress,
plot_sim_progress=plot_sim_progress,
log_sim_output=log_sim_output,
Expand All @@ -344,19 +344,21 @@ def main(_):
case _UserCommand.RUN:
simulation_app.main(
lambda: sim,
output_dir=new_config.output_dir,
output_dir=new_runtime_params.output_dir,
log_sim_progress=log_sim_progress,
plot_sim_progress=plot_sim_progress,
log_sim_output=log_sim_output,
)
case _UserCommand.CHANGE_CONFIG:
# See docstring for detailed info on what recompiles.
sim, new_config, config_module_str = change_config(
sim, new_runtime_params, config_module_str = change_config(
sim, config_module_str
)
case _UserCommand.CHANGE_SIM_OBJ:
# This always builds a new object and requires recompilation.
sim, new_config, config_module_str = change_sim_obj(config_module_str)
sim, new_runtime_params, config_module_str = change_sim_obj(
config_module_str
)
case _UserCommand.TOGGLE_LOG_SIM_PROGRESS:
log_sim_progress = _toggle_log_progress(log_sim_progress)
case _UserCommand.TOGGLE_PLOT_SIM_PROGRESS:
Expand Down
14 changes: 7 additions & 7 deletions torax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
import os

import jax
from torax import config
from torax import fvm
from torax import math_utils
from torax import physics
from torax.config import recursive_replace
from torax.config import runtime_params as general_runtime_params
from torax.config.config_args import recursive_replace
from torax.constants import CONSTANTS
from torax.geometry import build_chease_geometry
from torax.geometry import build_circular_geometry
Expand All @@ -43,7 +43,7 @@
from torax.time_step_calculator.time_step_calculator import TimeStepCalculator
# Unsure why but `from torax.config import Config` doesn't work in some
# circumstances.
Config = config.Config
GeneralRuntimeParams = general_runtime_params.GeneralRuntimeParams


# pylint: enable=g-importing-member
Expand All @@ -69,10 +69,10 @@
CANONICAL_ORDER = [
'dt',
'source_type',
'static_config_slice',
'dynamic_config_slice',
'dynamic_config_slice_t',
'dynamic_config_slice_t_plus_dt',
'static_runtime_params_slice',
'dynamic_runtime_params_slice',
'dynamic_runtime_params_slice_t',
'dynamic_runtime_params_slice_t_plus_dt',
'unused_config',
'dynamic_source_runtime_params',
'geo',
Expand Down
Loading

0 comments on commit c130b98

Please sign in to comment.