From 4656c0a5986e5d2ec8cec9e59d73a1043ee6ae3c Mon Sep 17 00:00:00 2001 From: Akhil Raju Date: Sat, 27 Apr 2024 08:04:16 -0700 Subject: [PATCH] Move `config.py` and `config_slice.py` to `config/runtime_params.py` and `config/runtime_params_slice.py`. We are choosing a new naming scheme for the input parameters to TORAX. Some of the last PRs have also moved the repo in this direction: - RuntimeParams refer to the various runtime, JAX-compatible (read: PyTree) input parameters to the various functions/modules in TORAX. - RuntimeParamsSlice refers to a "slice" of those params at a single point of time. (It slices both the collection of params and slices in the time-dimension. Double slice!) - Config will refer to any setup code that configures the entire TORAX simulation run. I.e. the term "config" will refer to code that configures the `sim.Sim` object. PiperOrigin-RevId: 628674938 --- .github/workflows/pytest.yml | 4 +- README.md | 2 +- run_simulation_main.py | 30 +- torax/__init__.py | 14 +- torax/calc_coeffs.py | 210 ++++++------ .../config_args.py} | 83 ++++- torax/{config.py => config/runtime_params.py} | 98 +----- .../runtime_params_slice.py} | 111 ++++--- .../tests/runtime_params.py} | 10 +- .../tests/runtime_params_slice.py} | 110 ++++--- torax/core_profile_setters.py | 265 ++++++++------- torax/fvm/block_1d_coeffs.py | 18 +- torax/fvm/newton_raphson_solve_block.py | 38 +-- torax/fvm/optimizer_solve_block.py | 46 +-- torax/fvm/residual_and_loss.py | 96 +++--- torax/fvm/tests/fvm.py | 169 ++++++---- torax/geometry.py | 58 ++-- torax/interpolated_param.py | 11 +- torax/sim.py | 305 ++++++++++-------- torax/simulation_app.py | 36 ++- torax/sources/bootstrap_current_source.py | 24 +- torax/sources/electron_density_sources.py | 22 +- torax/sources/external_current_source.py | 30 +- torax/sources/formula_config.py | 6 +- torax/sources/formulas.py | 6 +- torax/sources/fusion_heat_source.py | 6 +- torax/sources/generic_ion_el_heat_source.py | 10 +- torax/sources/qei_source.py | 32 +- torax/sources/runtime_params.py | 7 +- torax/sources/source.py | 45 +-- torax/sources/source_models.py | 96 +++--- .../sources/tests/bootstrap_current_source.py | 32 +- .../sources/tests/external_current_source.py | 38 +-- torax/sources/tests/formulas.py | 20 +- torax/sources/tests/fusion_heat_source.py | 30 +- torax/sources/tests/qei_source.py | 54 ++-- torax/sources/tests/source.py | 248 ++++++++------ torax/sources/tests/source_models.py | 56 ++-- torax/sources/tests/test_lib.py | 126 +++++--- torax/spectators/tests/plotting.py | 18 +- torax/state.py | 4 +- torax/stepper/linear_theta_method.py | 22 +- torax/stepper/nonlinear_theta_method.py | 55 ++-- torax/stepper/predictor_corrector_method.py | 33 +- torax/stepper/runtime_params.py | 9 +- torax/stepper/stepper.py | 66 ++-- torax/tests/boundary_conditions.py | 45 +-- torax/tests/geometry.py | 14 +- torax/tests/physics.py | 56 ++-- torax/tests/sim.py | 16 +- torax/tests/sim_custom_sources.py | 32 +- torax/tests/sim_output_source_profiles.py | 76 +++-- torax/tests/sim_time_dependence.py | 65 ++-- torax/tests/state.py | 140 ++++---- .../tests/test_data/compilation_benchmark.py | 22 +- torax/tests/test_data/default_config.py | 20 +- torax/tests/test_data/test_absolute_jext.py | 22 +- .../test_all_transport_crank_nicolson.py | 24 +- .../test_all_transport_fusion_qlknn.py | 22 +- torax/tests/test_data/test_bootstrap.py | 22 +- torax/tests/test_data/test_cgmheat.py | 20 +- torax/tests/test_data/test_chease.py | 22 +- torax/tests/test_data/test_crank_nicolson.py | 22 +- torax/tests/test_data/test_exact_finaltime.py | 22 +- torax/tests/test_data/test_explicit.py | 22 +- torax/tests/test_data/test_fixed_dt.py | 22 +- .../test_data/test_frozen_newton_raphson.py | 24 +- .../tests/test_data/test_frozen_optimizer.py | 24 +- torax/tests/test_data/test_fusion_power.py | 22 +- torax/tests/test_data/test_implicit.py | 22 +- .../test_implicit_short_optimizer.py | 24 +- .../test_data/test_iterbaseline_mockup.py | 24 +- .../tests/test_data/test_iterhybrid_mockup.py | 24 +- .../tests/test_data/test_iterhybrid_newton.py | 24 +- .../test_iterhybrid_predictor_corrector.py | 24 +- .../tests/test_data/test_iterhybrid_rampup.py | 24 +- .../test_data/test_ne_qlknn_deff_veff.py | 22 +- .../test_data/test_ne_qlknn_defromchie.py | 22 +- .../test_data/test_newton_raphson_zeroiter.py | 22 +- torax/tests/test_data/test_ohmic_power.py | 22 +- .../test_data/test_optimizer_zeroiter.py | 22 +- .../test_data/test_particle_sources_cgm.py | 22 +- .../test_particle_sources_constant.py | 22 +- torax/tests/test_data/test_pc_method_ne.py | 22 +- torax/tests/test_data/test_pedestal.py | 20 +- .../test_prescribed_timedependent_ne.py | 22 +- torax/tests/test_data/test_psi_and_heat.py | 22 +- torax/tests/test_data/test_psi_heat_dens.py | 22 +- .../test_data/test_psichease_ip_chease.py | 22 +- .../test_data/test_psichease_ip_parameters.py | 22 +- .../test_psichease_prescribed_johm.py | 22 +- .../test_psichease_prescribed_jtot.py | 22 +- torax/tests/test_data/test_psiequation.py | 22 +- torax/tests/test_data/test_qei.py | 22 +- .../test_data/test_qei_chease_highdens.py | 22 +- torax/tests/test_data/test_qlknnheat.py | 20 +- .../test_data/test_semiimplicit_convection.py | 22 +- torax/tests/test_data/test_timedependence.py | 22 +- torax/tests/test_lib/explicit_stepper.py | 26 +- torax/tests/test_lib/sim_test_case.py | 48 +-- torax/tests/test_lib/torax_refs.py | 39 +-- .../array_time_step_calculator.py | 18 +- .../chi_time_step_calculator.py | 18 +- .../fixed_time_step_calculator.py | 16 +- .../time_step_calculator.py | 10 +- torax/transport_model/constant.py | 33 +- torax/transport_model/critical_gradient.py | 44 +-- torax/transport_model/qlknn_wrapper.py | 48 +-- torax/transport_model/runtime_params.py | 4 +- torax/transport_model/tests/qlknn_wrapper.py | 28 +- .../transport_model/tests/transport_model.py | 42 +-- torax/transport_model/transport_model.py | 41 +-- 112 files changed, 2551 insertions(+), 2095 deletions(-) rename torax/{runtime_params/config_slice_args.py => config/config_args.py} (58%) rename torax/{config.py => config/runtime_params.py} (68%) rename torax/{config_slice.py => config/runtime_params_slice.py} (78%) rename torax/{tests/config.py => config/tests/runtime_params.py} (92%) rename torax/{tests/config_slice.py => config/tests/runtime_params_slice.py} (70%) diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index bf93bc99..fa994567 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -68,6 +68,8 @@ jobs: - name: Run core tests run: | pytest \ + torax/config/tests/runtime_params_slice.py \ + torax/config/tests/runtime_params.py \ torax/fvm/tests/fvm.py \ torax/sources/tests/bootstrap_current_source.py \ torax/sources/tests/current_density_sources.py \ @@ -82,8 +84,6 @@ jobs: torax/spectators/tests/plotting.py \ torax/spectators/tests/spectator.py \ torax/tests/boundary_conditions.py \ - torax/tests/config_slice.py \ - torax/tests/config.py \ torax/tests/geometry.py \ torax/tests/interpolated_param.py \ torax/tests/jax_utils.py \ diff --git a/README.md b/README.md index 393c2b7c..ecb0f5ea 100644 --- a/README.md +++ b/README.md @@ -201,7 +201,7 @@ python3 run_simulation_main.py \ Once complete, the time history of a simulation state and derived quantities is written to `state_history.nc`. The output path is written to stdout -To take advantage of the in-memory (non-persistent) cache, the process does not end upon simulation termination. It is possible to modify the config, toggle the `log_progress` and `plot_progress` flags, and rerun the simulation. Only the following modifications will then trigger a recompilation: +To take advantage of the in-memory (non-persistent) cache, the process does not end upon simulation termination. It is possible to modify the runtime_params, toggle the `log_progress` and `plot_progress` flags, and rerun the simulation. Only the following modifications will then trigger a recompilation: - Grid resolution - Evolved variables (equations being solved) diff --git a/run_simulation_main.py b/run_simulation_main.py index 6e876b98..43d4c6ef 100644 --- a/run_simulation_main.py +++ b/run_simulation_main.py @@ -168,7 +168,7 @@ def maybe_update_config_module( def change_config( sim: torax.Sim, config_module_str: str, -) -> tuple[torax.Sim, torax.Config, str]: +) -> tuple[torax.Sim, torax.GeneralRuntimeParams, str]: """Returns a new Sim with the updated config but same SimulationStepFn. This function gives the user a chance to reuse the SimulationStepFn without @@ -193,13 +193,13 @@ def change_config( config_module_str = maybe_update_config_module(config_module_str) simulation_app.log_to_stdout( f'Change {config_module_str} to include new values. Only changes to ' - 'get_config() will be picked up.', + 'get_runtime_params() will be picked up.', color=simulation_app.AnsiColors.BLUE, ) input('Press Enter when ready.') config_module, _ = _get_config_module(config_module_str) - new_config = config_module.get_config() - new_geo = config_module.get_geometry(new_config) + new_runtime_params = config_module.get_runtime_params() + new_geo = config_module.get_geometry(new_runtime_params) new_transport_model = config_module.get_transport_model() source_models = config_module.get_sources() new_source_params = { @@ -220,18 +220,18 @@ def change_config( ) sim = simulation_app.update_sim( sim=sim, - config=new_config, + runtime_params=new_runtime_params, geo=new_geo, transport_runtime_params=new_transport_model.runtime_params, source_runtime_params=new_source_params, stepper_runtime_params_getter=stepper_params_getter, ) - return sim, new_config, config_module_str + return sim, new_runtime_params, config_module_str def change_sim_obj( config_module_str: str, -) -> tuple[torax.Sim, torax.Config, str]: +) -> tuple[torax.Sim, torax.GeneralRuntimeParams, str]: """Builds a new Sim from the config module. Unlike change_config(), this function builds a brand new Sim object with a @@ -256,9 +256,9 @@ def change_sim_obj( ) input('Press Enter when done changing the module.') config_module, _ = _get_config_module(config_module_str) - new_config = config_module.get_config() + new_runtime_params = config_module.get_runtime_params() sim = config_module.get_sim() - return sim, new_config, config_module_str + return sim, new_runtime_params, config_module_str def _toggle_log_progress(log_sim_progress: bool) -> bool: @@ -323,14 +323,14 @@ def _toggle_log_output(log_sim_output: bool) -> bool: def main(_): config_module, config_module_str = _get_config_module() - new_config = config_module.get_config() + new_runtime_params = config_module.get_runtime_params() sim = config_module.get_sim() log_sim_progress = _LOG_SIM_PROGRESS.value plot_sim_progress = _PLOT_SIM_PROGRESS.value log_sim_output = _LOG_SIM_OUTPUT.value simulation_app.main( lambda: sim, - output_dir=new_config.output_dir, + output_dir=new_runtime_params.output_dir, log_sim_progress=log_sim_progress, plot_sim_progress=plot_sim_progress, log_sim_output=log_sim_output, @@ -344,19 +344,21 @@ def main(_): case _UserCommand.RUN: simulation_app.main( lambda: sim, - output_dir=new_config.output_dir, + output_dir=new_runtime_params.output_dir, log_sim_progress=log_sim_progress, plot_sim_progress=plot_sim_progress, log_sim_output=log_sim_output, ) case _UserCommand.CHANGE_CONFIG: # See docstring for detailed info on what recompiles. - sim, new_config, config_module_str = change_config( + sim, new_runtime_params, config_module_str = change_config( sim, config_module_str ) case _UserCommand.CHANGE_SIM_OBJ: # This always builds a new object and requires recompilation. - sim, new_config, config_module_str = change_sim_obj(config_module_str) + sim, new_runtime_params, config_module_str = change_sim_obj( + config_module_str + ) case _UserCommand.TOGGLE_LOG_SIM_PROGRESS: log_sim_progress = _toggle_log_progress(log_sim_progress) case _UserCommand.TOGGLE_PLOT_SIM_PROGRESS: diff --git a/torax/__init__.py b/torax/__init__.py index d8280491..785727fa 100644 --- a/torax/__init__.py +++ b/torax/__init__.py @@ -19,11 +19,11 @@ import os import jax -from torax import config from torax import fvm from torax import math_utils from torax import physics -from torax.config import recursive_replace +from torax.config import runtime_params as general_runtime_params +from torax.config.config_args import recursive_replace from torax.constants import CONSTANTS from torax.geometry import build_chease_geometry from torax.geometry import build_circular_geometry @@ -43,7 +43,7 @@ from torax.time_step_calculator.time_step_calculator import TimeStepCalculator # Unsure why but `from torax.config import Config` doesn't work in some # circumstances. -Config = config.Config +GeneralRuntimeParams = general_runtime_params.GeneralRuntimeParams # pylint: enable=g-importing-member @@ -69,10 +69,10 @@ CANONICAL_ORDER = [ 'dt', 'source_type', - 'static_config_slice', - 'dynamic_config_slice', - 'dynamic_config_slice_t', - 'dynamic_config_slice_t_plus_dt', + 'static_runtime_params_slice', + 'dynamic_runtime_params_slice', + 'dynamic_runtime_params_slice_t', + 'dynamic_runtime_params_slice_t_plus_dt', 'unused_config', 'dynamic_source_runtime_params', 'geo', diff --git a/torax/calc_coeffs.py b/torax/calc_coeffs.py index 4d0ec032..a075f465 100644 --- a/torax/calc_coeffs.py +++ b/torax/calc_coeffs.py @@ -21,12 +21,12 @@ import jax import jax.numpy as jnp -from torax import config_slice from torax import constants from torax import geometry from torax import jax_utils from torax import physics from torax import state +from torax.config import runtime_params_slice from torax.fvm import block_1d_coeffs from torax.sources import source_models as source_models_lib from torax.sources import source_profiles as source_profiles_lib @@ -34,7 +34,7 @@ def calculate_pereverzev_flux( - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, core_profiles: state.CoreProfiles, ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: @@ -42,10 +42,10 @@ def calculate_pereverzev_flux( consts = constants.CONSTANTS true_ne_face = ( - core_profiles.ne.face_value() * dynamic_config_slice.numerics.nref + core_profiles.ne.face_value() * dynamic_runtime_params_slice.numerics.nref ) true_ni_face = ( - core_profiles.ni.face_value() * dynamic_config_slice.numerics.nref + core_profiles.ni.face_value() * dynamic_runtime_params_slice.numerics.nref ) geo_factor = jnp.concatenate( @@ -56,7 +56,7 @@ def calculate_pereverzev_flux( geo.g1_over_vpr_face * true_ni_face * consts.keV2J - * dynamic_config_slice.stepper.chi_per + * dynamic_runtime_params_slice.stepper.chi_per / geo.rmax**2 ) @@ -64,11 +64,11 @@ def calculate_pereverzev_flux( geo.g1_over_vpr_face * true_ne_face * consts.keV2J - * dynamic_config_slice.stepper.chi_per + * dynamic_runtime_params_slice.stepper.chi_per / geo.rmax**2 ) - d_face_per_el = dynamic_config_slice.stepper.d_per / geo.rmax + d_face_per_el = dynamic_runtime_params_slice.stepper.d_per / geo.rmax v_face_per_el = ( core_profiles.ne.face_grad() / core_profiles.ne.face_value() @@ -80,16 +80,18 @@ def calculate_pereverzev_flux( # (for PDE stability) chi_face_per_ion = jnp.where( jnp.logical_and( - dynamic_config_slice.profile_conditions.set_pedestal, - geo.r_face_norm > dynamic_config_slice.profile_conditions.Ped_top, + dynamic_runtime_params_slice.profile_conditions.set_pedestal, + geo.r_face_norm + > dynamic_runtime_params_slice.profile_conditions.Ped_top, ), 0.0, chi_face_per_ion, ) chi_face_per_el = jnp.where( jnp.logical_and( - dynamic_config_slice.profile_conditions.set_pedestal, - geo.r_face_norm > dynamic_config_slice.profile_conditions.Ped_top, + dynamic_runtime_params_slice.profile_conditions.set_pedestal, + geo.r_face_norm + > dynamic_runtime_params_slice.profile_conditions.Ped_top, ), 0.0, chi_face_per_el, @@ -108,8 +110,9 @@ def calculate_pereverzev_flux( d_face_per_el = jnp.where( jnp.logical_and( - dynamic_config_slice.profile_conditions.set_pedestal, - geo.r_face_norm > dynamic_config_slice.profile_conditions.Ped_top, + dynamic_runtime_params_slice.profile_conditions.set_pedestal, + geo.r_face_norm + > dynamic_runtime_params_slice.profile_conditions.Ped_top, ), 0.0, d_face_per_el * geo.g1_over_vpr_face / geo.rmax, @@ -117,8 +120,9 @@ def calculate_pereverzev_flux( v_face_per_el = jnp.where( jnp.logical_and( - dynamic_config_slice.profile_conditions.set_pedestal, - geo.r_face_norm > dynamic_config_slice.profile_conditions.Ped_top, + dynamic_runtime_params_slice.profile_conditions.set_pedestal, + geo.r_face_norm + > dynamic_runtime_params_slice.profile_conditions.Ped_top, ), 0.0, v_face_per_el * geo.g0_face / geo.rmax, @@ -138,8 +142,8 @@ def calculate_pereverzev_flux( def calc_coeffs( - static_config_slice: config_slice.StaticConfigSlice, - dynamic_config_slice: config_slice.DynamicConfigSlice, + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, core_profiles: state.CoreProfiles, transport_model: transport_model_lib.TransportModel, @@ -152,11 +156,11 @@ def calc_coeffs( """Calculates Block1DCoeffs for the time step described by `core_profiles`. Args: - static_config_slice: General input parameters which are fixed through a - simulation run, and if changed, would trigger a recompile. - dynamic_config_slice: General input parameters that can change from time - step to time step or simulation run to run, and do so without triggering a - recompile. + static_runtime_params_slice: General input parameters which are fixed + through a simulation run, and if changed, would trigger a recompile. + dynamic_runtime_params_slice: General input parameters that can change from + time step to time step or simulation run to run, and do so without + triggering a recompile. geo: Geometry describing the torus. core_profiles: Core plasma profiles for this time step during this iteration of the solver. Depending on the type of stepper being used, this may or @@ -185,7 +189,7 @@ def calc_coeffs( # If we are fully implicit and we are making a call for calc_coeffs for the # explicit components of the PDE, only return a cheaper reduced Block1DCoeffs - if explicit_call and static_config_slice.stepper.theta_imp == 1.0: + if explicit_call and static_runtime_params_slice.stepper.theta_imp == 1.0: return _calc_coeffs_reduced( geo, core_profiles, @@ -193,8 +197,8 @@ def calc_coeffs( ) else: return _calc_coeffs_full( - static_config_slice, - dynamic_config_slice, + static_runtime_params_slice, + dynamic_runtime_params_slice, geo, core_profiles, transport_model, @@ -208,15 +212,15 @@ def calc_coeffs( @functools.partial( jax_utils.jit, static_argnames=[ - 'static_config_slice', + 'static_runtime_params_slice', 'transport_model', 'source_models', 'evolving_names', ], ) def _calc_coeffs_full( - static_config_slice: config_slice.StaticConfigSlice, - dynamic_config_slice: config_slice.DynamicConfigSlice, + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, core_profiles: state.CoreProfiles, transport_model: transport_model_lib.TransportModel, @@ -228,11 +232,11 @@ def _calc_coeffs_full( """Calculates Block1DCoeffs for the time step described by `core_profiles`. Args: - static_config_slice: General input parameters which are fixed through a - simulation run, and if changed, would trigger a recompile. - dynamic_config_slice: General input parameters that can change from time - step to time step or simulation run to run, and do so without triggering a - recompile. + static_runtime_params_slice: General input parameters which are fixed + through a simulation run, and if changed, would trigger a recompile. + dynamic_runtime_params_slice: General input parameters that can change from + time step to time step or simulation run to run, and do so without + triggering a recompile. geo: Geometry describing the torus. core_profiles: Core plasma profiles for this time step during this iteration of the solver. Depending on the type of stepper being used, this may or @@ -261,8 +265,8 @@ def _calc_coeffs_full( # model the pedestal. mask = physics.internal_boundary( geo, - dynamic_config_slice.profile_conditions.Ped_top, - dynamic_config_slice.profile_conditions.set_pedestal, + dynamic_runtime_params_slice.profile_conditions.Ped_top, + dynamic_runtime_params_slice.profile_conditions.set_pedestal, ) # This only calculates sources set to implicit in the config. All other @@ -270,7 +274,7 @@ def _calc_coeffs_full( # explicit_source_profiles). implicit_source_profiles = source_models_lib.build_source_profiles( source_models=source_models, - dynamic_config_slice=dynamic_config_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, geo=geo, core_profiles=core_profiles, explicit=False, @@ -281,22 +285,30 @@ def _calc_coeffs_full( # Decide which values to use depending on whether the source is explicit or # implicit. sigma = jax_utils.select( - dynamic_config_slice.sources[source_models.j_bootstrap_name].is_explicit, + dynamic_runtime_params_slice.sources[ + source_models.j_bootstrap_name + ].is_explicit, explicit_source_profiles.j_bootstrap.sigma, implicit_source_profiles.j_bootstrap.sigma, ) j_bootstrap = jax_utils.select( - dynamic_config_slice.sources[source_models.j_bootstrap_name].is_explicit, + dynamic_runtime_params_slice.sources[ + source_models.j_bootstrap_name + ].is_explicit, explicit_source_profiles.j_bootstrap.j_bootstrap, implicit_source_profiles.j_bootstrap.j_bootstrap, ) j_bootstrap_face = jax_utils.select( - dynamic_config_slice.sources[source_models.j_bootstrap_name].is_explicit, + dynamic_runtime_params_slice.sources[ + source_models.j_bootstrap_name + ].is_explicit, explicit_source_profiles.j_bootstrap.j_bootstrap_face, implicit_source_profiles.j_bootstrap.j_bootstrap_face, ) I_bootstrap = jax_utils.select( # pylint: disable=invalid-name - dynamic_config_slice.sources[source_models.j_bootstrap_name].is_explicit, + dynamic_runtime_params_slice.sources[ + source_models.j_bootstrap_name + ].is_explicit, explicit_source_profiles.j_bootstrap.I_bootstrap, implicit_source_profiles.j_bootstrap.I_bootstrap, ) @@ -333,24 +345,24 @@ def _calc_coeffs_full( ) true_ne_face = ( - core_profiles.ne.face_value() * dynamic_config_slice.numerics.nref + core_profiles.ne.face_value() * dynamic_runtime_params_slice.numerics.nref ) true_ni_face = ( - core_profiles.ni.face_value() * dynamic_config_slice.numerics.nref + core_profiles.ni.face_value() * dynamic_runtime_params_slice.numerics.nref ) # Transient term coefficient vector (has radial dependence through r, n) toc_temp_ion = ( - 1.5 * geo.vpr * consts.keV2J * dynamic_config_slice.numerics.nref + 1.5 * geo.vpr * consts.keV2J * dynamic_runtime_params_slice.numerics.nref ) tic_temp_ion = core_profiles.ni.value toc_temp_el = ( - 1.5 * geo.vpr * consts.keV2J * dynamic_config_slice.numerics.nref + 1.5 * geo.vpr * consts.keV2J * dynamic_runtime_params_slice.numerics.nref ) tic_temp_el = core_profiles.ne.value toc_psi = ( 1.0 - / dynamic_config_slice.numerics.resistivity_mult + / dynamic_runtime_params_slice.numerics.resistivity_mult * geo.r * sigma * consts.mu0 @@ -362,14 +374,16 @@ def _calc_coeffs_full( tic_dens_el = jnp.ones_like(geo.vpr) # Diffusion term coefficients - transport_coeffs = transport_model(dynamic_config_slice, geo, core_profiles) + transport_coeffs = transport_model( + dynamic_runtime_params_slice, geo, core_profiles + ) chi_face_ion = transport_coeffs.chi_face_ion chi_face_el = transport_coeffs.chi_face_el d_face_el = transport_coeffs.d_face_el v_face_el = transport_coeffs.v_face_el d_face_psi = geo.G2_face / geo.J_face / geo.rmax**2 - if static_config_slice.dens_eq: + if static_runtime_params_slice.dens_eq: if d_face_el is None or v_face_el is None: raise NotImplementedError( f'{type(transport_model)} does not support the density equation.' @@ -382,38 +396,38 @@ def _calc_coeffs_full( # transport regions, to avoid transient discontinuities chi_face_ion = jnp.where( jnp.logical_and( - dynamic_config_slice.transport.apply_inner_patch, + dynamic_runtime_params_slice.transport.apply_inner_patch, geo.r_face_norm - < dynamic_config_slice.transport.rho_inner + consts.eps, + < dynamic_runtime_params_slice.transport.rho_inner + consts.eps, ), - dynamic_config_slice.transport.chii_inner, + dynamic_runtime_params_slice.transport.chii_inner, chi_face_ion, ) chi_face_el = jnp.where( jnp.logical_and( - dynamic_config_slice.transport.apply_inner_patch, + dynamic_runtime_params_slice.transport.apply_inner_patch, geo.r_face_norm - < dynamic_config_slice.transport.rho_inner + consts.eps, + < dynamic_runtime_params_slice.transport.rho_inner + consts.eps, ), - dynamic_config_slice.transport.chie_inner, + dynamic_runtime_params_slice.transport.chie_inner, chi_face_el, ) d_face_el = jnp.where( jnp.logical_and( - dynamic_config_slice.transport.apply_inner_patch, + dynamic_runtime_params_slice.transport.apply_inner_patch, geo.r_face_norm - < dynamic_config_slice.transport.rho_inner + consts.eps, + < dynamic_runtime_params_slice.transport.rho_inner + consts.eps, ), - dynamic_config_slice.transport.De_inner, + dynamic_runtime_params_slice.transport.De_inner, d_face_el, ) v_face_el = jnp.where( jnp.logical_and( - dynamic_config_slice.transport.apply_inner_patch, + dynamic_runtime_params_slice.transport.apply_inner_patch, geo.r_face_norm - < dynamic_config_slice.transport.rho_inner + consts.eps, + < dynamic_runtime_params_slice.transport.rho_inner + consts.eps, ), - dynamic_config_slice.transport.Ve_inner, + dynamic_runtime_params_slice.transport.Ve_inner, v_face_el, ) @@ -423,57 +437,57 @@ def _calc_coeffs_full( chi_face_ion = jnp.where( jnp.logical_and( jnp.logical_and( - dynamic_config_slice.transport.apply_outer_patch, + dynamic_runtime_params_slice.transport.apply_outer_patch, jnp.logical_not( - dynamic_config_slice.profile_conditions.set_pedestal + dynamic_runtime_params_slice.profile_conditions.set_pedestal ), ), geo.r_face_norm - > dynamic_config_slice.transport.rho_outer - consts.eps, + > dynamic_runtime_params_slice.transport.rho_outer - consts.eps, ), - dynamic_config_slice.transport.chii_outer, + dynamic_runtime_params_slice.transport.chii_outer, chi_face_ion, ) chi_face_el = jnp.where( jnp.logical_and( jnp.logical_and( - dynamic_config_slice.transport.apply_outer_patch, + dynamic_runtime_params_slice.transport.apply_outer_patch, jnp.logical_not( - dynamic_config_slice.profile_conditions.set_pedestal + dynamic_runtime_params_slice.profile_conditions.set_pedestal ), ), geo.r_face_norm - > dynamic_config_slice.transport.rho_outer - consts.eps, + > dynamic_runtime_params_slice.transport.rho_outer - consts.eps, ), - dynamic_config_slice.transport.chie_outer, + dynamic_runtime_params_slice.transport.chie_outer, chi_face_el, ) d_face_el = jnp.where( jnp.logical_and( jnp.logical_and( - dynamic_config_slice.transport.apply_outer_patch, + dynamic_runtime_params_slice.transport.apply_outer_patch, jnp.logical_not( - dynamic_config_slice.profile_conditions.set_pedestal + dynamic_runtime_params_slice.profile_conditions.set_pedestal ), ), geo.r_face_norm - > dynamic_config_slice.transport.rho_outer - consts.eps, + > dynamic_runtime_params_slice.transport.rho_outer - consts.eps, ), - dynamic_config_slice.transport.De_outer, + dynamic_runtime_params_slice.transport.De_outer, d_face_el, ) v_face_el = jnp.where( jnp.logical_and( jnp.logical_and( - dynamic_config_slice.transport.apply_outer_patch, + dynamic_runtime_params_slice.transport.apply_outer_patch, jnp.logical_not( - dynamic_config_slice.profile_conditions.set_pedestal + dynamic_runtime_params_slice.profile_conditions.set_pedestal ), ), geo.r_face_norm - > dynamic_config_slice.transport.rho_outer - consts.eps, + > dynamic_runtime_params_slice.transport.rho_outer - consts.eps, ), - dynamic_config_slice.transport.Ve_outer, + dynamic_runtime_params_slice.transport.Ve_outer, v_face_el, ) @@ -525,26 +539,26 @@ def _calc_coeffs_full( # calculate neped # pylint: disable=invalid-name nGW = ( - dynamic_config_slice.profile_conditions.Ip + dynamic_runtime_params_slice.profile_conditions.Ip / (jnp.pi * geo.Rmin**2) * 1e20 - / dynamic_config_slice.numerics.nref + / dynamic_runtime_params_slice.numerics.nref ) # pylint: enable=invalid-name neped_unnorm = jnp.where( - dynamic_config_slice.profile_conditions.neped_is_fGW, - dynamic_config_slice.profile_conditions.neped * nGW, - dynamic_config_slice.profile_conditions.neped, + dynamic_runtime_params_slice.profile_conditions.neped_is_fGW, + dynamic_runtime_params_slice.profile_conditions.neped * nGW, + dynamic_runtime_params_slice.profile_conditions.neped, ) source_ne += jnp.where( - dynamic_config_slice.profile_conditions.set_pedestal, - mask * dynamic_config_slice.numerics.largeValue_n * neped_unnorm, + dynamic_runtime_params_slice.profile_conditions.set_pedestal, + mask * dynamic_runtime_params_slice.numerics.largeValue_n * neped_unnorm, 0.0, ) source_mat_nn += jnp.where( - dynamic_config_slice.profile_conditions.set_pedestal, - -(mask * dynamic_config_slice.numerics.largeValue_n), + dynamic_runtime_params_slice.profile_conditions.set_pedestal, + -(mask * dynamic_runtime_params_slice.numerics.largeValue_n), 0.0, ) @@ -563,7 +577,7 @@ def _calc_coeffs_full( ) = jax.lax.cond( use_pereverzev, lambda: calculate_pereverzev_flux( - dynamic_config_slice, + dynamic_runtime_params_slice, geo, core_profiles, ), @@ -577,9 +591,9 @@ def _calc_coeffs_full( # Ion and electron heat sources. qei = source_models.qei_source.get_qei( - static_config_slice=static_config_slice, - dynamic_config_slice=dynamic_config_slice, - dynamic_source_runtime_params=dynamic_config_slice.sources[ + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ source_models.qei_source_name ], geo=geo, @@ -632,28 +646,28 @@ def _calc_coeffs_full( # Pedestal source_i += jnp.where( - dynamic_config_slice.profile_conditions.set_pedestal, + dynamic_runtime_params_slice.profile_conditions.set_pedestal, mask - * dynamic_config_slice.numerics.largeValue_T - * dynamic_config_slice.profile_conditions.Tiped, + * dynamic_runtime_params_slice.numerics.largeValue_T + * dynamic_runtime_params_slice.profile_conditions.Tiped, 0.0, ) source_e += jnp.where( - dynamic_config_slice.profile_conditions.set_pedestal, + dynamic_runtime_params_slice.profile_conditions.set_pedestal, mask - * dynamic_config_slice.numerics.largeValue_T - * dynamic_config_slice.profile_conditions.Teped, + * dynamic_runtime_params_slice.numerics.largeValue_T + * dynamic_runtime_params_slice.profile_conditions.Teped, 0.0, ) source_mat_ii -= jnp.where( - dynamic_config_slice.profile_conditions.set_pedestal, - mask * dynamic_config_slice.numerics.largeValue_T, + dynamic_runtime_params_slice.profile_conditions.set_pedestal, + mask * dynamic_runtime_params_slice.numerics.largeValue_T, 0.0, ) source_mat_ee -= jnp.where( - dynamic_config_slice.profile_conditions.set_pedestal, - mask * dynamic_config_slice.numerics.largeValue_T, + dynamic_runtime_params_slice.profile_conditions.set_pedestal, + mask * dynamic_runtime_params_slice.numerics.largeValue_T, 0.0, ) diff --git a/torax/runtime_params/config_slice_args.py b/torax/config/config_args.py similarity index 58% rename from torax/runtime_params/config_slice_args.py rename to torax/config/config_args.py index 59c5aa5a..ce527248 100644 --- a/torax/runtime_params/config_slice_args.py +++ b/torax/config/config_args.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Functions to help build the arguments to config slice object constructors.""" +"""Functions for building arguments for configs and runtime input params.""" from __future__ import annotations @@ -125,3 +125,84 @@ def get_init_kwargs( config_val = config_val.build_dynamic_params(t) kwargs[field.name] = config_val return kwargs + + +def recursive_replace(obj: ..., **changes) -> ...: + """Recursive version of `dataclasses.replace`. + + This allows updating of nested dataclasses. + Assumes all dict-valued keys in `changes` are themselves changes to apply + to fields of obj. + + Args: + obj: Any dataclass instance. + **changes: Dict of updates to apply to fields of `obj`. + + Returns: + A copy of `obj` with the changes applied. + """ + + flattened_changes = {} + if dataclasses.is_dataclass(obj): + keys_to_types = { + field.name: field.type for field in dataclasses.fields(obj) + } + else: + # obj is another dict-like object that does not have typed fields. + keys_to_types = None + for key, value in changes.items(): + if isinstance(value, dict): + if dataclasses.is_dataclass(getattr(obj, key)): + # If obj[key] is another dataclass, recurse and populate that dataclass + # with the input changes. + flattened_changes[key] = recursive_replace(getattr(obj, key), **value) + elif keys_to_types is not None: + # obj[key] is likely just a dict, and each key needs to be treated + # separately. + # In order to support this, there needs to be some added type + # information for what the values of the dict should be. + typing_args = typing.get_args(keys_to_types[key]) + if len(typing_args) == 2: # the keys type, the values type. + inner_dict = {} + value_type = typing_args[1] + for inner_key, inner_value in value.items(): + if dataclasses.is_dataclass(value_type): + inner_dict[inner_key] = recursive_replace( + value_type(), **inner_value + ) + else: + inner_dict[inner_key] = value_type(inner_value) + flattened_changes[key] = inner_dict + else: + # If we don't have additional type information, just try using the + # value as is. + flattened_changes[key] = value + else: + # keys_to_types is None, so again, we don't have additional information. + flattened_changes[key] = value + else: + # For any value that should be an enum value but is not an enum already + # (could come a YAML file for instance and might be a string or int), + # this converts that value to an enum. + try: + if ( + # if obj is a dataclass + keys_to_types is not None + and + # and this param should be an enum + issubclass(keys_to_types[key], enum.Enum) + and + # but it is not already one. + not isinstance(value, enum.Enum) + ): + if isinstance(value, str): + value = keys_to_types[key][value.upper()] + else: + value = keys_to_types[key](value) + except TypeError: + # Ignore these errors. issubclass doesn't work with typing.Optional + # types. Note that this means that optional enum fields might not be + # cast properly, so avoid these when defining configs. + pass + flattened_changes[key] = value + return dataclasses.replace(obj, **flattened_changes) diff --git a/torax/config.py b/torax/config/runtime_params.py similarity index 68% rename from torax/config.py rename to torax/config/runtime_params.py index 30fb5046..9c4335e4 100644 --- a/torax/config.py +++ b/torax/config/runtime_params.py @@ -12,23 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Definition of configuration dataclass. +"""General runtime input parameters used throughout TORAX simulations.""" -Specifies parameter names and default values for all physics and solver -parameters. -""" +from __future__ import annotations import dataclasses -import enum -import typing + import chex from torax import interpolated_param # Type-alias for clarity. While the InterpolatedParams can vary across any -# field, in Config, we mainly use it to handle time-dependent parameters. +# field, in here, we mainly use it to handle time-dependent parameters. TimeDependentField = interpolated_param.InterpParamOrInterpParamInput -# Type-alias for brevity. Helps users only import this module. +# Type-alias for brevity. InterpolationMode = interpolated_param.InterpolationMode InterpolationParam = interpolated_param.InterpolatedParam @@ -176,8 +173,8 @@ class Numerics: # NOMUTANTS -- It's expected for the tests to pass with different defaults. @chex.dataclass -class Config: - """Configuration parameters for the `torax` module.""" +class GeneralRuntimeParams: + """General runtime input parameters for the `torax` module.""" plasma_composition: PlasmaComposition = dataclasses.field( default_factory=PlasmaComposition @@ -204,84 +201,3 @@ def sanity_check(self) -> None: def __post_init__(self): self.sanity_check() - - -def recursive_replace(obj: ..., **changes) -> ...: - """Recursive version of `dataclasses.replace`. - - This allows updating of nested dataclasses. - Assumes all dict-valued keys in `changes` are themselves changes to apply - to fields of obj. - - Args: - obj: Any dataclass instance. - **changes: Dict of updates to apply to fields of `obj`. - - Returns: - A copy of `obj` with the changes applied. - """ - - flattened_changes = {} - if dataclasses.is_dataclass(obj): - keys_to_types = { - field.name: field.type for field in dataclasses.fields(obj) - } - else: - # obj is another dict-like object that does not have typed fields. - keys_to_types = None - for key, value in changes.items(): - if isinstance(value, dict): - if dataclasses.is_dataclass(getattr(obj, key)): - # If obj[key] is another dataclass, recurse and populate that dataclass - # with the input changes. - flattened_changes[key] = recursive_replace(getattr(obj, key), **value) - elif keys_to_types is not None: - # obj[key] is likely just a dict, and each key needs to be treated - # separately. - # In order to support this, there needs to be some added type - # information for what the values of the dict should be. - typing_args = typing.get_args(keys_to_types[key]) - if len(typing_args) == 2: # the keys type, the values type. - inner_dict = {} - value_type = typing_args[1] - for inner_key, inner_value in value.items(): - if dataclasses.is_dataclass(value_type): - inner_dict[inner_key] = recursive_replace( - value_type(), **inner_value - ) - else: - inner_dict[inner_key] = value_type(inner_value) - flattened_changes[key] = inner_dict - else: - # If we don't have additional type information, just try using the - # value as is. - flattened_changes[key] = value - else: - # keys_to_types is None, so again, we don't have additional information. - flattened_changes[key] = value - else: - # For any value that should be an enum value but is not an enum already - # (could come a YAML file for instance and might be a string or int), - # this converts that value to an enum. - try: - if ( - # if obj is a dataclass - keys_to_types is not None - and - # and this param should be an enum - issubclass(keys_to_types[key], enum.Enum) - and - # but it is not already one. - not isinstance(value, enum.Enum) - ): - if isinstance(value, str): - value = keys_to_types[key][value.upper()] - else: - value = keys_to_types[key](value) - except TypeError: - # Ignore these errors. issubclass doesn't work with typing.Optional - # types. Note that this means that optional enum fields might not be - # cast properly, so avoid these when defining configs. - pass - flattened_changes[key] = value - return dataclasses.replace(obj, **flattened_changes) diff --git a/torax/config_slice.py b/torax/config/runtime_params_slice.py similarity index 78% rename from torax/config_slice.py rename to torax/config/runtime_params_slice.py index 7321dfa9..d2047ce7 100644 --- a/torax/config_slice.py +++ b/torax/config/runtime_params_slice.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Inputs to the TORAX steppers based on the input config. +"""Inputs to TORAX steppers and functions based on the input runtime parameters. When running a TORAX simulation, the steppers are (by default) JAX-compiled function, meaning it has two types of arguments: "dynamic" and "static". @@ -20,7 +20,7 @@ The "dynamic" arguments can change from call to call. These arguments must be arrays, scalars, or standard (possibly nested) Python containers. See the JAX docs for more info on allowed types. They cannot influence the logical branches -the JointStateStepper may take (again, see the sharp bits in the JAX docs to +the SimulationStepFn may take (again, see the sharp bits in the JAX docs to learn more about the how these "dynamic" args can be used within the function). Note that the "dynamic" arguments are NOT necessarily time-dependent. They do @@ -41,8 +41,8 @@ from typing import Callable import chex -from torax import config as config_lib -from torax.runtime_params import config_slice_args +from torax.config import config_args +from torax.config import runtime_params as general_runtime_params from torax.sources import runtime_params as sources_params from torax.stepper import runtime_params as stepper_params from torax.transport_model import runtime_params as transport_model_params @@ -54,13 +54,13 @@ @chex.dataclass(frozen=True) -class DynamicConfigSlice: +class DynamicRuntimeParamsSlice: """Input params that are ok to use as inputs to a JAX-compiled function. - This PyTree of params is input to the sim.JointStateStepper, which updates + This PyTree of params is input to the sim.SimulationStepFn, which updates the joint state and evolves the mesh state. This config includes various "dynamic" parameters which can change from step to step, or from - simulation run to simulation run, without requiring the JointStateStepper to + simulation run to simulation run, without requiring the SimulationStepFn to recompile. Note that "dynamic" does NOT mean time dependent necessarily (though these @@ -70,6 +70,16 @@ class DynamicConfigSlice: While the parameters are not necessarily time-dependent, that is how the class gets its name: a config "slice" refers to a subset of the overall TORAX config at a specific time t. + + This class contains "slices" of various RuntimeParams attributes defined + throughout TORAX: + - from the "general" runtime params + - from the transport model's runtime params + - from the stepper's runtime params + - from each of the sources' runtime params + + This class packages all these together for convenience, as it simplifies many + of the internal APIs within TORAX. """ transport: transport_model_params.DynamicRuntimeParams @@ -204,10 +214,10 @@ class DynamicNumerics: @chex.dataclass(frozen=True) -class StaticConfigSlice: - """Static arguments to JointStateStepper which cannot be changed. +class StaticRuntimeParamsSlice: + """Static arguments to SimulationStepFn which cannot be changed. - If any changes are made to these arguments, then the JointStateStepper must be + If any changes are made to these arguments, then the SimulationStepFn must be recompiled. NOTE: These are not the only parameters which can trigger a recompile! For @@ -237,52 +247,53 @@ class StaticConfigSlice: # iteratively at successively lower dt until convergence is reached adaptive_dt: bool + # pylint: enable=invalid-name -def build_dynamic_config_slice( - config: config_lib.Config, +def build_dynamic_runtime_params_slice( + runtime_params: general_runtime_params.GeneralRuntimeParams, transport: transport_model_params.RuntimeParams | None = None, sources: dict[str, sources_params.RuntimeParams] | None = None, stepper: stepper_params.RuntimeParams | None = None, t: chex.Numeric | None = None, -) -> DynamicConfigSlice: - """Builds a DynamicConfigSlice based on the input config.""" +) -> DynamicRuntimeParamsSlice: + """Builds a DynamicRuntimeParamsSlice.""" transport = transport or transport_model_params.RuntimeParams() sources = sources or {} stepper = stepper or stepper_params.RuntimeParams() - t = config.numerics.t_initial if t is None else t - # For each dataclass attribute under DynamicConfigSlice, build those objects - # explicitly, and then for all scalar attributes, fetch their values directly - # from the input config using config_slice_args.get_init_kwargs. - return DynamicConfigSlice( + t = runtime_params.numerics.t_initial if t is None else t + # For each dataclass attribute under DynamicRuntimeParamsSlice, build those + # objects explicitly, and then for all scalar attributes, fetch their values + # directly from the input runtime params using config_args.get_init_kwargs. + return DynamicRuntimeParamsSlice( transport=transport.build_dynamic_params(t), stepper=stepper.build_dynamic_params(t), sources=_build_dynamic_sources(sources, t), plasma_composition=DynamicPlasmaComposition( - **config_slice_args.get_init_kwargs( - input_config=config.plasma_composition, + **config_args.get_init_kwargs( + input_config=runtime_params.plasma_composition, output_type=DynamicPlasmaComposition, t=t, ) ), profile_conditions=DynamicProfileConditions( - **config_slice_args.get_init_kwargs( - input_config=config.profile_conditions, + **config_args.get_init_kwargs( + input_config=runtime_params.profile_conditions, output_type=DynamicProfileConditions, t=t, ) ), numerics=DynamicNumerics( - **config_slice_args.get_init_kwargs( - input_config=config.numerics, + **config_args.get_init_kwargs( + input_config=runtime_params.numerics, output_type=DynamicNumerics, t=t, ) ), - **config_slice_args.get_init_kwargs( - input_config=config, - output_type=DynamicConfigSlice, + **config_args.get_init_kwargs( + input_config=runtime_params, + output_type=DynamicRuntimeParamsSlice, t=t, skip=( 'transport', @@ -307,43 +318,43 @@ def _build_dynamic_sources( } -def build_static_config_slice( - config: config_lib.Config, +def build_static_runtime_params_slice( + runtime_params: general_runtime_params.GeneralRuntimeParams, stepper: stepper_params.RuntimeParams | None = None, -) -> StaticConfigSlice: - """Builds a StaticConfigSlice based on the input config.""" +) -> StaticRuntimeParamsSlice: + """Builds a StaticRuntimeParamsSlice.""" # t set to None because there shouldnt be time-dependent params in the static # config. stepper = stepper or stepper_params.RuntimeParams() - return StaticConfigSlice( + return StaticRuntimeParamsSlice( stepper=stepper.build_static_params(), - nr=config.numerics.nr, - ion_heat_eq=config.numerics.ion_heat_eq, - el_heat_eq=config.numerics.el_heat_eq, - current_eq=config.numerics.current_eq, - dens_eq=config.numerics.dens_eq, - adaptive_dt=config.numerics.adaptive_dt, + nr=runtime_params.numerics.nr, + ion_heat_eq=runtime_params.numerics.ion_heat_eq, + el_heat_eq=runtime_params.numerics.el_heat_eq, + current_eq=runtime_params.numerics.current_eq, + dens_eq=runtime_params.numerics.dens_eq, + adaptive_dt=runtime_params.numerics.adaptive_dt, ) -class DynamicConfigSliceProvider: - """Provides a DynamicConfigSlice to use during time t of the sim. +class DynamicRuntimeParamsSliceProvider: + """Provides a DynamicRuntimeParamsSlice to use during time t of the sim. - The DynamicConfigSlice may change from time step to time step, so this class - interpolates any time-dependent params in the input config to the values they - should be at time t. + The DynamicRuntimeParamsSlice may change from time step to time step, so this + class interpolates any time-dependent params in the input config to the values + they should be at time t. See `run_simulation()` for how this callable is used. """ def __init__( self, - config: config_lib.Config, + runtime_params: general_runtime_params.GeneralRuntimeParams, transport_getter: Callable[[], transport_model_params.RuntimeParams], sources_getter: Callable[[], dict[str, sources_params.RuntimeParams]], stepper_getter: Callable[[], stepper_params.RuntimeParams], ): - self._input_config = config + self._runtime_params = runtime_params self._transport_runtime_params_getter = transport_getter self._sources_getter = sources_getter self._stepper_getter = stepper_getter @@ -351,10 +362,10 @@ def __init__( def __call__( self, t: chex.Numeric, - ) -> DynamicConfigSlice: - """Returns a DynamicConfigSlice to use during time t of the sim.""" - return build_dynamic_config_slice( - config=self._input_config, + ) -> DynamicRuntimeParamsSlice: + """Returns a DynamicRuntimeParamsSlice to use during time t of the sim.""" + return build_dynamic_runtime_params_slice( + runtime_params=self._runtime_params, transport=self._transport_runtime_params_getter(), sources=self._sources_getter(), stepper=self._stepper_getter(), diff --git a/torax/tests/config.py b/torax/config/tests/runtime_params.py similarity index 92% rename from torax/tests/config.py rename to torax/config/tests/runtime_params.py index 47d0d3a5..276a5a2e 100644 --- a/torax/tests/config.py +++ b/torax/config/tests/runtime_params.py @@ -12,17 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Unit tests for torax.config.""" +"""Unit tests for torax.config.runtime_params.""" import dataclasses from absl.testing import absltest from absl.testing import parameterized -from torax import config as config_lib +from torax.config import config_args -class ConfigTest(parameterized.TestCase): - """Unit tests for the `torax.config` module.""" +class RuntimeParamsTest(parameterized.TestCase): + """Unit tests for the `torax.config.runtime_params` module.""" def test_recursive_replace(self): """Basic test of recursive replace.""" @@ -66,7 +66,7 @@ def test_recursive_replace(self): # Don't update a4, to test that it is untouched } - result = config_lib.recursive_replace(instance, **changes) + result = config_args.recursive_replace(instance, **changes) self.assertIsInstance(result, A) self.assertEqual(result.a1, -1) diff --git a/torax/tests/config_slice.py b/torax/config/tests/runtime_params_slice.py similarity index 70% rename from torax/tests/config_slice.py rename to torax/config/tests/runtime_params_slice.py index e04c271c..2e3242ec 100644 --- a/torax/tests/config_slice.py +++ b/torax/config/tests/runtime_params_slice.py @@ -12,94 +12,100 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Unit tests for torax.config_slice.""" +"""Unit tests for torax.config.runtime_params_slice.""" from absl.testing import absltest from absl.testing import parameterized import jax import numpy as np -from torax import config as config_lib -from torax import config_slice as config_slice_lib +from torax.config import runtime_params as general_runtime_params +from torax.config import runtime_params_slice as runtime_params_slice_lib from torax.sources import electron_density_sources from torax.sources import external_current_source from torax.sources import formula_config -from torax.sources import runtime_params +from torax.sources import runtime_params as sources_params_lib from torax.stepper import runtime_params as stepper_params_lib from torax.transport_model import runtime_params as transport_params_lib -class ConfigSliceTest(parameterized.TestCase): - """Unit tests for the `config_slice` module.""" +class RuntimeParamsSliceTest(parameterized.TestCase): + """Unit tests for the `runtime_params_slice` module.""" def test_dynamic_slice_can_be_input_to_jitted_function(self): - def foo(config_slice: config_slice_lib.DynamicConfigSlice): - _ = config_slice # do nothing. + """Tests that the slice can be input to a jitted function.""" + + def foo( + runtime_params_slice: runtime_params_slice_lib.DynamicRuntimeParamsSlice, + ): + _ = runtime_params_slice # do nothing. foo_jitted = jax.jit(foo) - config = config_lib.Config() - dynamic_slice = config_slice_lib.build_dynamic_config_slice(config) + runtime_params = general_runtime_params.GeneralRuntimeParams() + dynamic_slice = runtime_params_slice_lib.build_dynamic_runtime_params_slice( + runtime_params, + ) # Make sure you can call the function with dynamic_slice as an arg. foo_jitted(dynamic_slice) def test_time_dependent_provider_is_time_dependent(self): - """Tests that the config slice provider is time dependent.""" - config = config_lib.Config( - profile_conditions=config_lib.ProfileConditions( + """Tests that the runtime_params slice provider is time dependent.""" + runtime_params = general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( Ti_bound_right={0.0: 2.0, 4.0: 4.0}, ), ) - provider = config_slice_lib.DynamicConfigSliceProvider( - config=config, + provider = runtime_params_slice_lib.DynamicRuntimeParamsSliceProvider( + runtime_params=runtime_params, transport_getter=transport_params_lib.RuntimeParams, sources_getter=lambda: {}, stepper_getter=stepper_params_lib.RuntimeParams, ) - dynamic_config_slice = provider(t=1.0) + dynamic_runtime_params_slice = provider(t=1.0) np.testing.assert_allclose( - dynamic_config_slice.profile_conditions.Ti_bound_right, 2.5 + dynamic_runtime_params_slice.profile_conditions.Ti_bound_right, 2.5 ) - dynamic_config_slice = provider(t=2.0) + dynamic_runtime_params_slice = provider(t=2.0) np.testing.assert_allclose( - dynamic_config_slice.profile_conditions.Ti_bound_right, 3.0 + dynamic_runtime_params_slice.profile_conditions.Ti_bound_right, 3.0 ) def test_boundary_conditions_are_time_dependent(self): """Tests that the boundary conditions are time dependent params.""" # All of the following parameters are time-dependent fields, but they can # be initialized in different ways. - config = config_lib.Config( - profile_conditions=config_lib.ProfileConditions( + runtime_params = general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( Ti_bound_right={0.0: 2.0, 4.0: 4.0}, Te_bound_right=4.5, # not time-dependent. - ne_bound_right=config_lib.InterpolationParam( + ne_bound_right=general_runtime_params.InterpolationParam( {5.0: 6.0, 7.0: 8.0}, - interpolation_mode=config_lib.InterpolationMode.STEP, + interpolation_mode=general_runtime_params.InterpolationMode.STEP, ), ), ) np.testing.assert_allclose( - config_slice_lib.build_dynamic_config_slice( - config, t=2.0 + runtime_params_slice_lib.build_dynamic_runtime_params_slice( + runtime_params, t=2.0 ).profile_conditions.Ti_bound_right, 3.0, ) np.testing.assert_allclose( - config_slice_lib.build_dynamic_config_slice( - config, t=4.0 + runtime_params_slice_lib.build_dynamic_runtime_params_slice( + runtime_params, t=4.0 ).profile_conditions.Te_bound_right, 4.5, ) np.testing.assert_allclose( - config_slice_lib.build_dynamic_config_slice( - config, t=6.0 + runtime_params_slice_lib.build_dynamic_runtime_params_slice( + runtime_params, t=6.0 ).profile_conditions.ne_bound_right, 6.0, ) def test_pedestal_is_time_dependent(self): - """Tests that the pedestal config params are time dependent.""" - config = config_lib.Config( - profile_conditions=config_lib.ProfileConditions( + """Tests that the pedestal runtime params are time dependent.""" + runtime_params = general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( set_pedestal={0.0: True, 1.0: False}, Tiped={0.0: 0.0, 1.0: 1.0}, Teped={0.0: 1.0, 1.0: 2.0}, @@ -108,7 +114,9 @@ def test_pedestal_is_time_dependent(self): ), ) # Check at time 0. - dcs = config_slice_lib.build_dynamic_config_slice(config, t=0.0) + dcs = runtime_params_slice_lib.build_dynamic_runtime_params_slice( + runtime_params, t=0.0 + ) profile_conditions = dcs.profile_conditions np.testing.assert_allclose(profile_conditions.set_pedestal, True) np.testing.assert_allclose(profile_conditions.Tiped, 0.0) @@ -116,7 +124,9 @@ def test_pedestal_is_time_dependent(self): np.testing.assert_allclose(profile_conditions.neped, 2.0) np.testing.assert_allclose(profile_conditions.Ped_top, 3.0) # And check after the time limit. - dcs = config_slice_lib.build_dynamic_config_slice(config, t=1.0) + dcs = runtime_params_slice_lib.build_dynamic_runtime_params_slice( + runtime_params, t=1.0 + ) profile_conditions = dcs.profile_conditions np.testing.assert_allclose(profile_conditions.set_pedestal, False) np.testing.assert_allclose(profile_conditions.Tiped, 1.0) @@ -127,11 +137,11 @@ def test_pedestal_is_time_dependent(self): def test_source_formula_config_has_time_dependent_params(self): """Tests that the source formula config params are time-dependent.""" with self.subTest('default_ne_sources'): - # Check that the config params for the default ne sources are + # Check that the runtime params for the default ne sources are # time-dependent. - config = config_lib.Config() - dcs = config_slice_lib.build_dynamic_config_slice( - config=config, + runtime_params = general_runtime_params.GeneralRuntimeParams() + dcs = runtime_params_slice_lib.build_dynamic_runtime_params_slice( + runtime_params=runtime_params, sources={ 'gas_puff_source': electron_density_sources.GasPuffRuntimeParams( puff_decay_length={0.0: 0.0, 1.0: 4.0}, @@ -184,11 +194,11 @@ def test_source_formula_config_has_time_dependent_params(self): ) with self.subTest('exponential_formula'): - config = config_lib.Config() - dcs = config_slice_lib.build_dynamic_config_slice( - config=config, + runtime_params = general_runtime_params.GeneralRuntimeParams() + dcs = runtime_params_slice_lib.build_dynamic_runtime_params_slice( + runtime_params=runtime_params, sources={ - 'gas_puff_source': runtime_params.RuntimeParams( + 'gas_puff_source': sources_params_lib.RuntimeParams( formula=formula_config.Exponential( total={0.0: 0.0, 1.0: 1.0}, c1={0.0: 0.0, 1.0: 2.0}, @@ -211,11 +221,11 @@ def test_source_formula_config_has_time_dependent_params(self): ) with self.subTest('gaussian_formula'): - config = config_lib.Config() - dcs = config_slice_lib.build_dynamic_config_slice( - config=config, + runtime_params = general_runtime_params.GeneralRuntimeParams() + dcs = runtime_params_slice_lib.build_dynamic_runtime_params_slice( + runtime_params=runtime_params, sources={ - 'gas_puff_source': runtime_params.RuntimeParams( + 'gas_puff_source': sources_params_lib.RuntimeParams( formula=formula_config.Gaussian( total={0.0: 0.0, 1.0: 1.0}, c1={0.0: 0.0, 1.0: 2.0}, @@ -236,11 +246,11 @@ def test_source_formula_config_has_time_dependent_params(self): dcs.sources['gas_puff_source'].formula.c2, 0.75 ) - def test_wext_in_dynamic_config_cannot_be_negative(self): + def test_wext_in_dynamic_runtime_params_cannot_be_negative(self): """Tests that wext cannot be negative.""" - config = config_lib.Config() - dcs_provider = config_slice_lib.DynamicConfigSliceProvider( - config=config, + runtime_params = general_runtime_params.GeneralRuntimeParams() + dcs_provider = runtime_params_slice_lib.DynamicRuntimeParamsSliceProvider( + runtime_params=runtime_params, transport_getter=transport_params_lib.RuntimeParams, sources_getter=lambda: { 'jext': external_current_source.RuntimeParams( diff --git a/torax/core_profile_setters.py b/torax/core_profile_setters.py index d51cbcfa..528e01e3 100644 --- a/torax/core_profile_setters.py +++ b/torax/core_profile_setters.py @@ -22,7 +22,6 @@ import dataclasses import jax from jax import numpy as jnp -from torax import config_slice from torax import constants from torax import fvm from torax import geometry @@ -30,6 +29,7 @@ from torax import math_utils from torax import physics from torax import state +from torax.config import runtime_params_slice from torax.geometry import Geometry # pylint: disable=g-importing-member from torax.sources import external_current_source from torax.sources import source_models as source_models_lib @@ -39,23 +39,24 @@ def _updated_ti( - static_config_slice: config_slice.StaticConfigSlice, - dynamic_config_slice: config_slice.DynamicConfigSlice, + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: Geometry, ) -> fvm.CellVariable: """Updated ion temp. Used upon initialization and if temp_ion=False.""" # pylint: disable=invalid-name Ti_bound_left = jax_utils.error_if_not_positive( - dynamic_config_slice.profile_conditions.Ti_bound_left, 'Ti_bound_left' + dynamic_runtime_params_slice.profile_conditions.Ti_bound_left, + 'Ti_bound_left', ) Ti_bound_right = jax_utils.error_if_not_positive( - dynamic_config_slice.profile_conditions.Ti_bound_right, + dynamic_runtime_params_slice.profile_conditions.Ti_bound_right, 'Ti_bound_right', ) temp_ion_face = jnp.linspace( start=Ti_bound_left, stop=Ti_bound_right, - num=static_config_slice.nr + 1, + num=static_runtime_params_slice.nr + 1, ) temp_ion = geometry.face_to_cell(temp_ion_face) temp_ion = fvm.CellVariable( @@ -70,23 +71,24 @@ def _updated_ti( def _updated_te( - static_config_slice: config_slice.StaticConfigSlice, - dynamic_config_slice: config_slice.DynamicConfigSlice, + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: Geometry, ) -> fvm.CellVariable: """Updated electron temp. Used upon initialization and if temp_el=False.""" # pylint: disable=invalid-name Te_bound_left = jax_utils.error_if_not_positive( - dynamic_config_slice.profile_conditions.Te_bound_left, 'Te_bound_left' + dynamic_runtime_params_slice.profile_conditions.Te_bound_left, + 'Te_bound_left', ) Te_bound_right = jax_utils.error_if_not_positive( - dynamic_config_slice.profile_conditions.Te_bound_right, + dynamic_runtime_params_slice.profile_conditions.Te_bound_right, 'Te_bound_right', ) temp_el_face = jnp.linspace( start=Te_bound_left, stop=Te_bound_right, - num=static_config_slice.nr + 1, + num=static_runtime_params_slice.nr + 1, ) temp_el = geometry.face_to_cell(temp_el_face) temp_el = fvm.CellVariable( @@ -101,35 +103,35 @@ def _updated_te( def _updated_dens( - static_config_slice: config_slice.StaticConfigSlice, - dynamic_config_slice: config_slice.DynamicConfigSlice, + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: Geometry, ) -> tuple[fvm.CellVariable, fvm.CellVariable]: """Updated particle density. Used upon initialization and if dens_eq=False.""" # pylint: disable=invalid-name nGW = ( - dynamic_config_slice.profile_conditions.Ip + dynamic_runtime_params_slice.profile_conditions.Ip / (jnp.pi * geo.Rmin**2) * 1e20 - / dynamic_config_slice.numerics.nref + / dynamic_runtime_params_slice.numerics.nref ) nbar_unnorm = jnp.where( - dynamic_config_slice.profile_conditions.nbar_is_fGW, - dynamic_config_slice.profile_conditions.nbar * nGW, - dynamic_config_slice.profile_conditions.nbar, + dynamic_runtime_params_slice.profile_conditions.nbar_is_fGW, + dynamic_runtime_params_slice.profile_conditions.nbar * nGW, + dynamic_runtime_params_slice.profile_conditions.nbar, ) # calculate ne_bound_right ne_bound_right = jnp.where( - dynamic_config_slice.profile_conditions.ne_bound_right_is_fGW, - dynamic_config_slice.profile_conditions.ne_bound_right * nGW, - dynamic_config_slice.profile_conditions.ne_bound_right, + dynamic_runtime_params_slice.profile_conditions.ne_bound_right_is_fGW, + dynamic_runtime_params_slice.profile_conditions.ne_bound_right * nGW, + dynamic_runtime_params_slice.profile_conditions.ne_bound_right, ) # set peaking (limited to linear profile) nshape_face = jnp.linspace( - dynamic_config_slice.profile_conditions.npeak, + dynamic_runtime_params_slice.profile_conditions.npeak, 1, - static_config_slice.nr + 1, + static_runtime_params_slice.nr + 1, ) nshape = geometry.face_to_cell(nshape_face) @@ -153,8 +155,8 @@ def _updated_dens( # Zeff = (ni + Zimp**2 * nimp)/ne ; nimp*Zimp + ni = ne dilution_factor = physics.get_main_ion_dilution_factor( - dynamic_config_slice.plasma_composition.Zimp, - dynamic_config_slice.plasma_composition.Zeff, + dynamic_runtime_params_slice.plasma_composition.Zimp, + dynamic_runtime_params_slice.plasma_composition.Zeff, ) ni = fvm.CellVariable( @@ -167,14 +169,14 @@ def _updated_dens( def _prescribe_currents_no_bootstrap( - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: Geometry, source_models: source_models_lib.SourceModels, ) -> state.Currents: """Creates the initial Currents without the bootstrap current. Args: - dynamic_config_slice: General configuration parameters at t_initial. + dynamic_runtime_params_slice: General runtime parameters at t_initial. geo: Geometry of the tokamak. source_models: All TORAX source/sink functions. @@ -186,10 +188,12 @@ def _prescribe_currents_no_bootstrap( # notational conventions rather than on Google Python style # pylint: disable=invalid-name - # Calculate splitting of currents depending on config - Ip = dynamic_config_slice.profile_conditions.Ip + # Calculate splitting of currents depending on input runtime params. + Ip = dynamic_runtime_params_slice.profile_conditions.Ip - dynamic_jext_params = _get_jext_params(dynamic_config_slice, source_models) + dynamic_jext_params = _get_jext_params( + dynamic_runtime_params_slice, source_models + ) if dynamic_jext_params.use_absolute_jext: Iext = dynamic_jext_params.Iext else: @@ -205,7 +209,7 @@ def _prescribe_currents_no_bootstrap( # calculate "External" current profile (e.g. ECCD) # form of external current on face grid jext_face, jext = source_models.jext.get_value( - dynamic_config_slice=dynamic_config_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, dynamic_source_runtime_params=dynamic_jext_params, geo=geo, ) @@ -213,10 +217,10 @@ def _prescribe_currents_no_bootstrap( # construct prescribed current formula on grid. jformula_face = ( 1 - geo.r_face_norm**2 - ) ** dynamic_config_slice.profile_conditions.nu + ) ** dynamic_runtime_params_slice.profile_conditions.nu # calculate total and Ohmic current profiles denom = _trapz(jformula_face * geo.spr_face, geo.r_face) - if dynamic_config_slice.profile_conditions.initial_j_is_total_current: + if dynamic_runtime_params_slice.profile_conditions.initial_j_is_total_current: Ctot = Ip * 1e6 / denom jtot_face = jformula_face * Ctot johm_face = jtot_face - jext_face @@ -229,7 +233,7 @@ def _prescribe_currents_no_bootstrap( johm = geometry.face_to_cell(johm_face) jtot_hires = _get_jtot_hires( - dynamic_config_slice, + dynamic_runtime_params_slice, dynamic_jext_params, geo, bootstrap_profile, @@ -255,7 +259,7 @@ def _prescribe_currents_no_bootstrap( def _prescribe_currents_with_bootstrap( - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: Geometry, temp_ion: fvm.CellVariable, temp_el: fvm.CellVariable, @@ -268,7 +272,7 @@ def _prescribe_currents_with_bootstrap( """Creates the initial Currents. Args: - dynamic_config_slice: General configuration parameters at t_initial. + dynamic_runtime_params_slice: General runtime parameters at t_initial. geo: Geometry of the tokamak. temp_ion: Ion temperature. temp_el: Electron temperature. @@ -286,11 +290,11 @@ def _prescribe_currents_with_bootstrap( # Many variables throughout this function are capitalized based on physics # notational conventions rather than on Google Python style # pylint: disable=invalid-name - Ip = dynamic_config_slice.profile_conditions.Ip + Ip = dynamic_runtime_params_slice.profile_conditions.Ip bootstrap_profile = source_models.j_bootstrap.get_value( - dynamic_config_slice=dynamic_config_slice, - dynamic_source_runtime_params=dynamic_config_slice.sources[ + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ source_models.j_bootstrap_name ], geo=geo, @@ -303,8 +307,10 @@ def _prescribe_currents_with_bootstrap( ) f_bootstrap = bootstrap_profile.I_bootstrap / (Ip * 1e6) - # Calculate splitting of currents depending on config - dynamic_jext_params = _get_jext_params(dynamic_config_slice, source_models) + # Calculate splitting of currents depending on input runtime params + dynamic_jext_params = _get_jext_params( + dynamic_runtime_params_slice, source_models + ) if dynamic_jext_params.use_absolute_jext: Iext = dynamic_jext_params.Iext else: @@ -314,7 +320,7 @@ def _prescribe_currents_with_bootstrap( # calculate "External" current profile (e.g. ECCD) # form of external current on face grid jext_face, jext = source_models.jext.get_value( - dynamic_config_slice=dynamic_config_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, dynamic_source_runtime_params=dynamic_jext_params, geo=geo, ) @@ -322,10 +328,10 @@ def _prescribe_currents_with_bootstrap( # construct prescribed current formula on grid. jformula_face = ( 1 - geo.r_face_norm**2 - ) ** dynamic_config_slice.profile_conditions.nu + ) ** dynamic_runtime_params_slice.profile_conditions.nu denom = _trapz(jformula_face * geo.spr_face, geo.r_face) # calculate total and Ohmic current profiles - if dynamic_config_slice.profile_conditions.initial_j_is_total_current: + if dynamic_runtime_params_slice.profile_conditions.initial_j_is_total_current: Ctot = Ip * 1e6 / denom jtot_face = jformula_face * Ctot johm_face = jtot_face - jext_face - bootstrap_profile.j_bootstrap_face @@ -338,7 +344,7 @@ def _prescribe_currents_with_bootstrap( johm = geometry.face_to_cell(johm_face) jtot_hires = _get_jtot_hires( - dynamic_config_slice, + dynamic_runtime_params_slice, dynamic_jext_params, geo, bootstrap_profile, @@ -364,7 +370,7 @@ def _prescribe_currents_with_bootstrap( def _calculate_currents_from_psi( - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: Geometry, temp_ion: fvm.CellVariable, temp_el: fvm.CellVariable, @@ -376,7 +382,7 @@ def _calculate_currents_from_psi( """Creates the initial Currents using psi to calculate jtot. Args: - dynamic_config_slice: General configuration parameters at t_initial. + dynamic_runtime_params_slice: General runtime parameters at t_initial. geo: Geometry of the tokamak. temp_ion: Ion temperature. temp_el: Electron temperature. @@ -393,7 +399,7 @@ def _calculate_currents_from_psi( # Many variables throughout this function are capitalized based on physics # notational conventions rather than on Google Python style # pylint: disable=invalid-name - Ip = dynamic_config_slice.profile_conditions.Ip + Ip = dynamic_runtime_params_slice.profile_conditions.Ip jtot, jtot_face = physics.calc_jtot_from_psi( geo, @@ -401,8 +407,8 @@ def _calculate_currents_from_psi( ) bootstrap_profile = source_models.j_bootstrap.get_value( - dynamic_config_slice=dynamic_config_slice, - dynamic_source_runtime_params=dynamic_config_slice.sources[ + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ source_models.j_bootstrap_name ], geo=geo, @@ -414,8 +420,10 @@ def _calculate_currents_from_psi( psi=psi, ) - # Calculate splitting of currents depending on config - dynamic_jext_params = _get_jext_params(dynamic_config_slice, source_models) + # Calculate splitting of currents depending on input runtime params. + dynamic_jext_params = _get_jext_params( + dynamic_runtime_params_slice, source_models + ) if dynamic_jext_params.use_absolute_jext: Iext = dynamic_jext_params.Iext else: @@ -426,7 +434,7 @@ def _calculate_currents_from_psi( # calculate "External" current profile (e.g. ECCD) # form of external current on face grid jext_face, jext = source_models.jext.get_value( - dynamic_config_slice=dynamic_config_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, dynamic_source_runtime_params=dynamic_jext_params, geo=geo, ) @@ -439,7 +447,7 @@ def _calculate_currents_from_psi( # should be summing over all sources that can contribute current i.e. ECCD, # ICRH, NBI, LHCD. jtot_hires = _get_jtot_hires( - dynamic_config_slice, + dynamic_runtime_params_slice, dynamic_jext_params, geo, bootstrap_profile, @@ -466,7 +474,7 @@ def _calculate_currents_from_psi( def _update_psi_from_j( - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: Geometry, currents: state.Currents, ) -> fvm.CellVariable: @@ -476,7 +484,7 @@ def _update_psi_from_j( integration. Args: - dynamic_config_slice: Dynamic configuration parameters. + dynamic_runtime_params_slice: Dynamic runtime parameters. geo: Torus geometry. currents: Currents structure including high resolution version of jtot. @@ -485,7 +493,7 @@ def _update_psi_from_j( """ psi_constraint = ( - dynamic_config_slice.profile_conditions.Ip + dynamic_runtime_params_slice.profile_conditions.Ip * 1e6 * constants.CONSTANTS.mu0 / geo.G2_face[-1] @@ -524,16 +532,16 @@ def _update_psi_from_j( def initial_core_profiles( - static_config_slice: config_slice.StaticConfigSlice, - dynamic_config_slice: config_slice.DynamicConfigSlice, + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: Geometry, source_models: source_models_lib.SourceModels, ) -> state.CoreProfiles: """Calculates the initial core profiles. Args: - static_config_slice: Static simulation configuration parameters. - dynamic_config_slice: Dynamic configuration parameters at t=t_initial. + static_runtime_params_slice: Static simulation runtime parameters. + dynamic_runtime_params_slice: Dynamic runtime parameters at t=t_initial. geo: Torus geometry. source_models: All models for TORAX sources/sinks. @@ -544,33 +552,39 @@ def initial_core_profiles( # To set initial values and compute the boundary conditions, we need to handle # potentially time-varying inputs from the users. - # The default time in build_dynamic_config_slice is t_initial - temp_ion = _updated_ti(static_config_slice, dynamic_config_slice, geo) - temp_el = _updated_te(static_config_slice, dynamic_config_slice, geo) - ne, ni = _updated_dens(static_config_slice, dynamic_config_slice, geo) + # The default time in build_dynamic_runtime_params_slice is t_initial + temp_ion = _updated_ti( + static_runtime_params_slice, dynamic_runtime_params_slice, geo + ) + temp_el = _updated_te( + static_runtime_params_slice, dynamic_runtime_params_slice, geo + ) + ne, ni = _updated_dens( + static_runtime_params_slice, dynamic_runtime_params_slice, geo + ) # set up initial psi profile based on current profile if ( isinstance(geo, geometry.CircularGeometry) - or dynamic_config_slice.profile_conditions.initial_psi_from_j + or dynamic_runtime_params_slice.profile_conditions.initial_psi_from_j ): # set up initial current profile without bootstrap current, to get # q-profile approximation (needed for bootstrap) currents_no_bootstrap = _prescribe_currents_no_bootstrap( - dynamic_config_slice=dynamic_config_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, geo=geo, source_models=source_models, ) psi_no_bootstrap = _update_psi_from_j( - dynamic_config_slice, + dynamic_runtime_params_slice, geo, currents_no_bootstrap, ) # second iteration, with bootstrap current currents = _prescribe_currents_with_bootstrap( - dynamic_config_slice=dynamic_config_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, geo=geo, temp_ion=temp_ion, temp_el=temp_el, @@ -582,7 +596,7 @@ def initial_core_profiles( ) psi = _update_psi_from_j( - dynamic_config_slice, + dynamic_runtime_params_slice, geo, currents, ) @@ -591,19 +605,19 @@ def initial_core_profiles( geo=geo, psi=psi, jtot_face=currents.jtot_face, - q_correction_factor=dynamic_config_slice.numerics.q_correction_factor, + q_correction_factor=dynamic_runtime_params_slice.numerics.q_correction_factor, ) s_face = physics.calc_s_from_psi(geo, psi) elif ( isinstance(geo, geometry.CHEASEGeometry) - and not dynamic_config_slice.profile_conditions.initial_psi_from_j + and not dynamic_runtime_params_slice.profile_conditions.initial_psi_from_j ): # psi is already provided from the CHEASE equilibrium, so no need to first # calculate currents. However, non-inductive currents are still calculated # and used in current diffusion equation. psi_constraint = ( - dynamic_config_slice.profile_conditions.Ip + dynamic_runtime_params_slice.profile_conditions.Ip * 1e6 * constants.CONSTANTS.mu0 / geo.G2_face[-1] @@ -618,13 +632,13 @@ def initial_core_profiles( geo=geo, psi=psi, jtot_face=geo.jtot_face, - q_correction_factor=dynamic_config_slice.numerics.q_correction_factor, + q_correction_factor=dynamic_runtime_params_slice.numerics.q_correction_factor, ) s_face = physics.calc_s_from_psi(geo, psi) # calculation external currents currents = _calculate_currents_from_psi( - dynamic_config_slice=dynamic_config_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, geo=geo, source_models=source_models, temp_ion=temp_ion, @@ -658,7 +672,7 @@ def initial_core_profiles( psidot = dataclasses.replace( psidot, value=source_models_lib.calc_psidot( - dynamic_config_slice, + dynamic_runtime_params_slice, geo, core_profiles, source_models, @@ -671,7 +685,7 @@ def initial_core_profiles( core_profiles = physics.update_jtot_q_face_s_face( geo=geo, core_profiles=core_profiles, - q_correction_factor=dynamic_config_slice.numerics.q_correction_factor, + q_correction_factor=dynamic_runtime_params_slice.numerics.q_correction_factor, ) # pylint: enable=invalid-name @@ -679,8 +693,8 @@ def initial_core_profiles( def updated_prescribed_core_profiles( - static_config_slice: config_slice.StaticConfigSlice, - dynamic_config_slice: config_slice.DynamicConfigSlice, + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: Geometry, core_profiles: state.CoreProfiles, ) -> dict[str, jax.Array]: @@ -689,8 +703,8 @@ def updated_prescribed_core_profiles( Uses same functions as for profile initialization. Args: - static_config_slice: Static simulation configuration parameters. - dynamic_config_slice: Dynamic configuration parameters at t=t_initial. + static_runtime_params_slice: Static simulation runtime parameters. + dynamic_runtime_params_slice: Dynamic runtime parameters at t=t_initial. geo: Torus geometry. core_profiles: Core profiles dataclass to be updated @@ -700,26 +714,32 @@ def updated_prescribed_core_profiles( # pylint: disable=invalid-name # If profiles are not evolved, they can still potential be time-evolving, - # depending on the config. If so, they are updated below. + # depending on the runtime params. If so, they are updated below. if ( - not static_config_slice.ion_heat_eq - and dynamic_config_slice.numerics.enable_prescribed_profile_evolution + not static_runtime_params_slice.ion_heat_eq + and dynamic_runtime_params_slice.numerics.enable_prescribed_profile_evolution ): - temp_ion = _updated_ti(static_config_slice, dynamic_config_slice, geo).value + temp_ion = _updated_ti( + static_runtime_params_slice, dynamic_runtime_params_slice, geo + ).value else: temp_ion = core_profiles.temp_ion.value if ( - not static_config_slice.el_heat_eq - and dynamic_config_slice.numerics.enable_prescribed_profile_evolution + not static_runtime_params_slice.el_heat_eq + and dynamic_runtime_params_slice.numerics.enable_prescribed_profile_evolution ): - temp_el = _updated_te(static_config_slice, dynamic_config_slice, geo).value + temp_el = _updated_te( + static_runtime_params_slice, dynamic_runtime_params_slice, geo + ).value else: temp_el = core_profiles.temp_el.value if ( - not static_config_slice.dens_eq - and dynamic_config_slice.numerics.enable_prescribed_profile_evolution + not static_runtime_params_slice.dens_eq + and dynamic_runtime_params_slice.numerics.enable_prescribed_profile_evolution ): - ne, _ = _updated_dens(static_config_slice, dynamic_config_slice, geo) + ne, _ = _updated_dens( + static_runtime_params_slice, dynamic_runtime_params_slice, geo + ) ne = ne.value else: ne = core_profiles.ne.value @@ -729,7 +749,7 @@ def updated_prescribed_core_profiles( def update_evolving_core_profiles( x_new: tuple[fvm.cell_variable.CellVariable, ...], - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, core_profiles: state.CoreProfiles, evolving_names: tuple[str, ...], ) -> state.CoreProfiles: @@ -737,7 +757,7 @@ def update_evolving_core_profiles( Args: x_new: The new values of the evolving variables. - dynamic_config_slice: The dynamic config slice. + dynamic_runtime_params_slice: The dynamic runtime params slice. core_profiles: The old set of core plasma profiles. evolving_names: The names of the evolving variables. """ @@ -757,8 +777,8 @@ def get_update(x_new, var): core_profiles.ni, value=ne.value * physics.get_main_ion_dilution_factor( - dynamic_config_slice.plasma_composition.Zimp, - dynamic_config_slice.plasma_composition.Zeff, + dynamic_runtime_params_slice.plasma_composition.Zimp, + dynamic_runtime_params_slice.plasma_composition.Zeff, ), ) @@ -773,13 +793,13 @@ def get_update(x_new, var): def compute_boundary_conditions( - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, ) -> dict[str, dict[str, jax.Array | None]]: """Computes boundary conditions for time t and returns updates to State. Args: - dynamic_config_slice: Runtime configuration at time t. + dynamic_runtime_params_slice: Runtime parameters at time t. geo: Geometry object Returns: @@ -787,27 +807,29 @@ def compute_boundary_conditions( each CellVariable in the state. This dict can in theory recursively replace values in a State object. """ - Ip = dynamic_config_slice.profile_conditions.Ip # pylint: disable=invalid-name + Ip = dynamic_runtime_params_slice.profile_conditions.Ip # pylint: disable=invalid-name Ti_bound_right = jax_utils.error_if_not_positive( # pylint: disable=invalid-name - dynamic_config_slice.profile_conditions.Ti_bound_right, 'Ti_bound_right' + dynamic_runtime_params_slice.profile_conditions.Ti_bound_right, + 'Ti_bound_right', ) Te_bound_right = jax_utils.error_if_not_positive( # pylint: disable=invalid-name - dynamic_config_slice.profile_conditions.Te_bound_right, 'Te_bound_right' + dynamic_runtime_params_slice.profile_conditions.Te_bound_right, + 'Te_bound_right', ) # calculate ne_bound_right # pylint: disable=invalid-name nGW = ( - dynamic_config_slice.profile_conditions.Ip + dynamic_runtime_params_slice.profile_conditions.Ip / (jnp.pi * geo.Rmin**2) * 1e20 - / dynamic_config_slice.numerics.nref + / dynamic_runtime_params_slice.numerics.nref ) # pylint: enable=invalid-name ne_bound_right = jnp.where( - dynamic_config_slice.profile_conditions.ne_bound_right_is_fGW, - dynamic_config_slice.profile_conditions.ne_bound_right * nGW, - dynamic_config_slice.profile_conditions.ne_bound_right, + dynamic_runtime_params_slice.profile_conditions.ne_bound_right_is_fGW, + dynamic_runtime_params_slice.profile_conditions.ne_bound_right * nGW, + dynamic_runtime_params_slice.profile_conditions.ne_bound_right, ) # define ion profile based on (flat) Zeff and single assumed impurity # with Zimp. main ion limited to hydrogenic species for now. @@ -815,8 +837,8 @@ def compute_boundary_conditions( # Zeff = (ni + Zimp**2 * nimp)/ne ; nimp*Zimp + ni = ne dilution_factor = physics.get_main_ion_dilution_factor( - dynamic_config_slice.plasma_composition.Zimp, - dynamic_config_slice.plasma_composition.Zeff, + dynamic_runtime_params_slice.plasma_composition.Zimp, + dynamic_runtime_params_slice.plasma_composition.Zeff, ) return { 'temp_ion': dict( @@ -852,7 +874,7 @@ def compute_boundary_conditions( # pylint: disable=invalid-name def _get_jtot_hires( - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, dynamic_jext_params: external_current_source.DynamicRuntimeParams, geo: Geometry, bootstrap_profile: source_profiles_lib.BootstrapCurrentProfile, @@ -866,7 +888,7 @@ def _get_jtot_hires( # calculate hi-res "External" current profile (e.g. ECCD) on cell grid. jext_hires = jext_source.jext_hires( - dynamic_config_slice=dynamic_config_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, dynamic_source_runtime_params=dynamic_jext_params, geo=geo, ) @@ -874,10 +896,12 @@ def _get_jtot_hires( # calculate high resolution jtot and Ohmic current profile jformula_hires = ( 1 - geo.r_hires_norm**2 - ) ** dynamic_config_slice.profile_conditions.nu + ) ** dynamic_runtime_params_slice.profile_conditions.nu denom = _trapz(jformula_hires * geo.spr_hires, geo.r_hires) - if dynamic_config_slice.profile_conditions.initial_j_is_total_current: - Ctot_hires = dynamic_config_slice.profile_conditions.Ip * 1e6 / denom + if dynamic_runtime_params_slice.profile_conditions.initial_j_is_total_current: + Ctot_hires = ( + dynamic_runtime_params_slice.profile_conditions.Ip * 1e6 / denom + ) jtot_hires = jformula_hires * Ctot_hires else: Cohm_hires = Iohm * 1e6 / denom @@ -887,16 +911,19 @@ def _get_jtot_hires( def _get_jext_params( - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, source_models: source_models_lib.SourceModels, ) -> external_current_source.DynamicRuntimeParams: """Returns dynamic runtime params for the external current source.""" - assert source_models.jext_name in dynamic_config_slice.sources, ( - f'{source_models.jext_name} not found in dynamic_config_slice.sources.' - ' Check to make sure the DynamicConfigSlice was built with `sources`' - ' that include the external current source.' - ) - dynamic_jext_params = dynamic_config_slice.sources[source_models.jext_name] + assert source_models.jext_name in dynamic_runtime_params_slice.sources, ( + f'{source_models.jext_name} not found in' + ' dynamic_runtime_params_slice.sources. Check to make sure the' + ' DynamicRuntimeParamsSlice was built with `sources` that include the' + ' external current source.' + ) + dynamic_jext_params = dynamic_runtime_params_slice.sources[ + source_models.jext_name + ] assert isinstance( dynamic_jext_params, external_current_source.DynamicRuntimeParams ) diff --git a/torax/fvm/block_1d_coeffs.py b/torax/fvm/block_1d_coeffs.py index 25a86f3d..340ff392 100644 --- a/torax/fvm/block_1d_coeffs.py +++ b/torax/fvm/block_1d_coeffs.py @@ -24,7 +24,7 @@ import chex import jax -from torax import config_slice +from torax.config import runtime_params_slice from torax.fvm import cell_variable @@ -80,11 +80,11 @@ class Block1DCoeffs: channel i on the face grid. source_mat_cell: 2-D matrix of Tuples, with source_mat_cell[i][j] adding to block-row i a term of the form source_cell[j] * u[channel j]. Depending on - the source config, may be constant values for a timestep, or updated - iteratively with new states in a nonlinear solver - source_cell: Additional source terms on the cell grid for each channel. - Depending on the source config, may be constant values for a timestep, or + the source runtime_params, may be constant values for a timestep, or updated iteratively with new states in a nonlinear solver + source_cell: Additional source terms on the cell grid for each channel. + Depending on the source runtime_params, may be constant values for a + timestep, or updated iteratively with new states in a nonlinear solver auxiliary_outputs: Optional extra output which can include auxiliary state or information useful for inspecting the computation inside the callback which calculated these coeffs. @@ -103,7 +103,7 @@ class Block1DCoeffsCallback(Protocol): def __call__( self, - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, x: tuple[cell_variable.CellVariable, ...], allow_pereverzev: bool = False, explicit_call: bool = False, @@ -124,9 +124,9 @@ def __call__( final output x_new. Args: - dynamic_config_slice: Runtime configuration parameters. These values are - potentially time-dependent and should correspond to the time step of the - state x. + dynamic_runtime_params_slice: Runtime configuration parameters. These + values are potentially time-dependent and should correspond to the time + step of the state x. x: The state. allow_pereverzev: If True, then the coeffs are being called for an initial guess based on a linear step as opposed to just passing the iniitial diff --git a/torax/fvm/newton_raphson_solve_block.py b/torax/fvm/newton_raphson_solve_block.py index a8522a4f..32d36815 100644 --- a/torax/fvm/newton_raphson_solve_block.py +++ b/torax/fvm/newton_raphson_solve_block.py @@ -24,11 +24,11 @@ import jax from jax import numpy as jnp import numpy as np -from torax import config_slice from torax import fvm from torax import geometry from torax import jax_utils from torax import state as state_module +from torax.config import runtime_params_slice from torax.fvm import block_1d_coeffs from torax.fvm import cell_variable from torax.fvm import fvm_conversions @@ -90,9 +90,9 @@ def _log_iterations( def newton_raphson_solve_block( dt: jax.Array, - static_config_slice: config_slice.StaticConfigSlice, - dynamic_config_slice_t: config_slice.DynamicConfigSlice, - dynamic_config_slice_t_plus_dt: config_slice.DynamicConfigSlice, + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice_t: runtime_params_slice.DynamicRuntimeParamsSlice, + dynamic_runtime_params_slice_t_plus_dt: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, x_old: tuple[cell_variable.CellVariable, ...], core_profiles_t_plus_dt: state_module.CoreProfiles, @@ -138,12 +138,12 @@ def newton_raphson_solve_block( Args: dt: Discrete time step. - static_config_slice: Static runtime configuration. Changes to these config - params will trigger recompilation. - dynamic_config_slice_t: Runtime configuration for time t (the start time of - the step). These config params can change from step to step without - triggering a recompilation. - dynamic_config_slice_t_plus_dt: Runtime configuration for time t + dt. + static_runtime_params_slice: Static runtime parameters. Changes to these + runtime params will trigger recompilation. + dynamic_runtime_params_slice_t: Runtime parameters for time t (the start + time of the step). These config params can change from step to step + without triggering a recompilation. + dynamic_runtime_params_slice_t_plus_dt: Runtime parameters for time t + dt. geo: Geometry object. x_old: Tuple containing CellVariables for each channel with their values at the start of the time step. @@ -187,7 +187,7 @@ def newton_raphson_solve_block( # pyformat: enable coeffs_old = coeffs_callback( - dynamic_config_slice_t, x_old, explicit_call=True + dynamic_runtime_params_slice_t, x_old, explicit_call=True ) match initial_guess_mode: @@ -195,10 +195,10 @@ def newton_raphson_solve_block( # corrector method if predictor_corrector=True in the solver config case InitialGuessMode.LINEAR: # returns transport coefficients with additional pereverzev terms - # if set by config, needed if stiff transport models (e.g. qlknn) + # if set by runtime_params, needed if stiff transport models (e.g. qlknn) # are used. coeffs_exp_linear = coeffs_callback( - dynamic_config_slice_t, + dynamic_runtime_params_slice_t, x_old, allow_pereverzev=True, explicit_call=True, @@ -222,8 +222,8 @@ def newton_raphson_solve_block( ) init_x_new, _ = predictor_corrector_method.predictor_corrector_method( dt=dt, - static_config_slice=static_config_slice, - dynamic_config_slice_t_plus_dt=dynamic_config_slice_t_plus_dt, + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt, x_old=x_old, init_val=init_val, coeffs_exp=coeffs_exp_linear, @@ -244,8 +244,8 @@ def newton_raphson_solve_block( residual_fun = functools.partial( residual_and_loss.theta_method_block_residual, dt=dt, - static_config_slice=static_config_slice, - dynamic_config_slice_t_plus_dt=dynamic_config_slice_t_plus_dt, + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt, geo=geo, x_old=x_old, core_profiles_t_plus_dt=core_profiles_t_plus_dt, @@ -258,8 +258,8 @@ def newton_raphson_solve_block( jacobian_fun = functools.partial( residual_and_loss.theta_method_block_jacobian, dt=dt, - static_config_slice=static_config_slice, - dynamic_config_slice_t_plus_dt=dynamic_config_slice_t_plus_dt, + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt, geo=geo, x_old=x_old, core_profiles_t_plus_dt=core_profiles_t_plus_dt, diff --git a/torax/fvm/optimizer_solve_block.py b/torax/fvm/optimizer_solve_block.py index 2af1d876..4a468ce9 100644 --- a/torax/fvm/optimizer_solve_block.py +++ b/torax/fvm/optimizer_solve_block.py @@ -17,10 +17,10 @@ """ import jax -from torax import config_slice from torax import fvm from torax import geometry from torax import state +from torax.config import runtime_params_slice from torax.fvm import block_1d_coeffs from torax.fvm import cell_variable from torax.fvm import fvm_conversions @@ -45,9 +45,9 @@ def optimizer_solve_block( dt: jax.Array, - static_config_slice: config_slice.StaticConfigSlice, - dynamic_config_slice_t: config_slice.DynamicConfigSlice, - dynamic_config_slice_t_plus_dt: config_slice.DynamicConfigSlice, + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice_t: runtime_params_slice.DynamicRuntimeParamsSlice, + dynamic_runtime_params_slice_t_plus_dt: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, x_old: tuple[cell_variable.CellVariable, ...], core_profiles_t_plus_dt: state.CoreProfiles, @@ -72,18 +72,18 @@ def optimizer_solve_block( Args: dt: Discrete time step. - static_config_slice: Static runtime configuration. Changes to these config - params will trigger recompilation. A key parameter in static_config slice - is theta_imp, a coefficient in [0, 1] determining which solution method to - use. We solve transient_coeff (x_new - x_old) / dt = theta_imp F(t_new) + - (1 - theta_imp) F(t_old). Three values of theta_imp correspond to named - solution methods: theta_imp = 1: Backward Euler implicit method (default). - theta_imp = 0.5: Crank-Nicolson. theta_imp = 0: Forward Euler explicit - method. - dynamic_config_slice_t: Runtime configuration for time t (the start time of - the step). These config params can change from step to step without + static_runtime_params_slice: Static runtime parameters. Changes to these + runtime params will trigger recompilation. A key parameter in this params + slice is theta_imp, a coefficient in [0, 1] determining which solution + method to use. We solve transient_coeff (x_new - x_old) / dt = theta_imp + F(t_new) + (1 - theta_imp) F(t_old). Three values of theta_imp correspond + to named solution methods: theta_imp = 1: Backward Euler implicit method + (default). theta_imp = 0.5: Crank-Nicolson. theta_imp = 0: Forward Euler + explicit method. + dynamic_runtime_params_slice_t: Runtime params for time t (the start time of + the step). These runtime params can change from step to step without triggering a recompilation. - dynamic_config_slice_t_plus_dt: Runtime configuration for time t + dt. + dynamic_runtime_params_slice_t_plus_dt: Runtime params for time t + dt. geo: Geometry object used to initialize auxiliary outputs. x_old: Tuple containing CellVariables for each channel with their values at the start of the time step. @@ -116,18 +116,18 @@ def optimizer_solve_block( # pyformat: enable coeffs_old = coeffs_callback( - dynamic_config_slice_t, x_old, explicit_call=True + dynamic_runtime_params_slice_t, x_old, explicit_call=True ) match initial_guess_mode: # LINEAR initial guess will provide the initial guess using the predictor- - # corrector method if predictor_corrector=True in the solver config + # corrector method if predictor_corrector=True in the stepper runtime params case InitialGuessMode.LINEAR: # returns transport coefficients with additional pereverzev terms - # if set by config, needed if stiff transport models (e.g. qlknn) + # if set by runtime_params, needed if stiff transport models (e.g. qlknn) # are used. coeffs_exp_linear = coeffs_callback( - dynamic_config_slice_t, + dynamic_runtime_params_slice_t, x_old, allow_pereverzev=True, explicit_call=True, @@ -150,8 +150,8 @@ def optimizer_solve_block( ) init_x_new, _ = predictor_corrector_method.predictor_corrector_method( dt=dt, - static_config_slice=static_config_slice, - dynamic_config_slice_t_plus_dt=dynamic_config_slice_t_plus_dt, + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt, init_val=init_val, x_old=x_old, coeffs_exp=coeffs_exp_linear, @@ -168,8 +168,8 @@ def optimizer_solve_block( # Advance jaxopt_solver by one timestep x_new_vec, final_loss, aux_output = residual_and_loss.jaxopt_solver( dt=dt, - static_config_slice=static_config_slice, - dynamic_config_slice_t_plus_dt=dynamic_config_slice_t_plus_dt, + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt, geo=geo, x_old=x_old, init_x_new_vec=init_x_new_vec, diff --git a/torax/fvm/residual_and_loss.py b/torax/fvm/residual_and_loss.py index 76e1cc77..2a5e469f 100644 --- a/torax/fvm/residual_and_loss.py +++ b/torax/fvm/residual_and_loss.py @@ -25,11 +25,11 @@ from jax import numpy as jnp import jaxopt from torax import calc_coeffs -from torax import config_slice from torax import core_profile_setters from torax import geometry from torax import jax_utils from torax import state +from torax.config import runtime_params_slice from torax.fvm import block_1d_coeffs from torax.fvm import cell_variable from torax.fvm import discrete_system @@ -193,7 +193,7 @@ def theta_method_matrix_equation( @functools.partial( jax_utils.jit, static_argnames=[ - 'static_config_slice', + 'static_runtime_params_slice', 'transport_model', 'source_models', 'evolving_names', @@ -202,8 +202,8 @@ def theta_method_matrix_equation( def theta_method_block_residual( x_new_guess_vec: jax.Array, dt: jax.Array, - static_config_slice: config_slice.StaticConfigSlice, - dynamic_config_slice_t_plus_dt: config_slice.DynamicConfigSlice, + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice_t_plus_dt: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, x_old: tuple[cell_variable.CellVariable, ...], core_profiles_t_plus_dt: state.CoreProfiles, @@ -219,15 +219,15 @@ def theta_method_block_residual( x_new_guess_vec: Flattened array of current guess of x_new for all evolving core profiles. dt: Time step duration. - static_config_slice: Static runtime configuration. Changes to these config - params will trigger recompilation. A key parameter in static_config slice - is theta_imp, a coefficient in [0, 1] determining which solution method to - use. We solve transient_coeff (x_new - x_old) / dt = theta_imp F(t_new) + - (1 - theta_imp) F(t_old). Three values of theta_imp correspond to named - solution methods: theta_imp = 1: Backward Euler implicit method (default). - theta_imp = 0.5: Crank-Nicolson. theta_imp = 0: Forward Euler explicit - method. - dynamic_config_slice_t_plus_dt: Runtime configuration for time t + dt. + static_runtime_params_slice: Static runtime parameters. Changes to these + runtime params will trigger recompilation. A key parameter in this params + slice is theta_imp, a coefficient in [0, 1] determining which solution + method to use. We solve transient_coeff (x_new - x_old) / dt = theta_imp + F(t_new) + (1 - theta_imp) F(t_old). Three values of theta_imp correspond + to named solution methods: theta_imp = 1: Backward Euler implicit method + (default). theta_imp = 0.5: Crank-Nicolson. theta_imp = 0: Forward Euler + explicit method. + dynamic_runtime_params_slice_t_plus_dt: Runtime parameters for time t + dt. geo: Geometry object. x_old: The starting x defined as a tuple of CellVariables. core_profiles_t_plus_dt: Core plasma profiles which contain all available @@ -254,13 +254,13 @@ def theta_method_block_residual( ) core_profiles_t_plus_dt = core_profile_setters.update_evolving_core_profiles( x_new_guess, - dynamic_config_slice_t_plus_dt, + dynamic_runtime_params_slice_t_plus_dt, core_profiles_t_plus_dt, evolving_names, ) coeffs_new = calc_coeffs.calc_coeffs( - static_config_slice=static_config_slice, - dynamic_config_slice=dynamic_config_slice_t_plus_dt, + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice_t_plus_dt, geo=geo, core_profiles=core_profiles_t_plus_dt, transport_model=transport_model, @@ -276,9 +276,9 @@ def theta_method_block_residual( x_new_guess=x_new_guess, coeffs_old=coeffs_old, coeffs_new=coeffs_new, - theta_imp=static_config_slice.stepper.theta_imp, - convection_dirichlet_mode=static_config_slice.stepper.convection_dirichlet_mode, - convection_neumann_mode=static_config_slice.stepper.convection_neumann_mode, + theta_imp=static_runtime_params_slice.stepper.theta_imp, + convection_dirichlet_mode=static_runtime_params_slice.stepper.convection_dirichlet_mode, + convection_neumann_mode=static_runtime_params_slice.stepper.convection_neumann_mode, ) lhs = jnp.dot(lhs_mat, x_new_guess_vec) + lhs_vec @@ -294,7 +294,7 @@ def theta_method_block_residual( theta_method_block_jacobian = jax_utils.jit( theta_method_block_jacobian, static_argnames=[ - 'static_config_slice', + 'static_runtime_params_slice', 'transport_model', 'source_models', 'evolving_names', @@ -305,7 +305,7 @@ def theta_method_block_residual( @functools.partial( jax_utils.jit, static_argnames=[ - 'static_config_slice', + 'static_runtime_params_slice', 'transport_model', 'source_models', 'evolving_names', @@ -314,8 +314,8 @@ def theta_method_block_residual( def theta_method_block_loss( x_new_guess_vec: jax.Array, dt: jax.Array, - static_config_slice: config_slice.StaticConfigSlice, - dynamic_config_slice_t_plus_dt: config_slice.DynamicConfigSlice, + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice_t_plus_dt: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, x_old: tuple[cell_variable.CellVariable, ...], core_profiles_t_plus_dt: state.CoreProfiles, @@ -331,15 +331,15 @@ def theta_method_block_loss( x_new_guess_vec: Flattened array of current guess of x_new for all evolving core profiles. dt: Time step duration. - static_config_slice: Static runtime configuration. Changes to these config - params will trigger recompilation. A key parameter in static_config slice - is theta_imp, a coefficient in [0, 1] determining which solution method to - use. We solve transient_coeff (x_new - x_old) / dt = theta_imp F(t_new) + - (1 - theta_imp) F(t_old). Three values of theta_imp correspond to named - solution methods: theta_imp = 1: Backward Euler implicit method (default). - theta_imp = 0.5: Crank-Nicolson. theta_imp = 0: Forward Euler explicit - method. - dynamic_config_slice_t_plus_dt: Runtime configuration for time t + dt. + static_runtime_params_slice: Static runtime parameters. Changes to these + runtime params will trigger recompilation. A key parameter in this params + slice is theta_imp, a coefficient in [0, 1] determining which solution + method to use. We solve transient_coeff (x_new - x_old) / dt = theta_imp + F(t_new) + (1 - theta_imp) F(t_old). Three values of theta_imp correspond + to named solution methods: theta_imp = 1: Backward Euler implicit method + (default). theta_imp = 0.5: Crank-Nicolson. theta_imp = 0: Forward Euler + explicit method. + dynamic_runtime_params_slice_t_plus_dt: Runtime parameters for time t + dt. geo: geometry object x_old: The starting x defined as a tuple of CellVariables. core_profiles_t_plus_dt: Core plasma profiles which contain all available @@ -361,8 +361,8 @@ def theta_method_block_loss( residual, aux_output = theta_method_block_residual( dt=dt, - static_config_slice=static_config_slice, - dynamic_config_slice_t_plus_dt=dynamic_config_slice_t_plus_dt, + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt, geo=geo, x_old=x_old, x_new_guess_vec=x_new_guess_vec, @@ -380,7 +380,7 @@ def theta_method_block_loss( @functools.partial( jax_utils.jit, static_argnames=[ - 'static_config_slice', + 'static_runtime_params_slice', 'transport_model', 'source_models', 'evolving_names', @@ -388,8 +388,8 @@ def theta_method_block_loss( ) def jaxopt_solver( dt: jax.Array, - static_config_slice: config_slice.StaticConfigSlice, - dynamic_config_slice_t_plus_dt: config_slice.DynamicConfigSlice, + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice_t_plus_dt: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, x_old: tuple[cell_variable.CellVariable, ...], init_x_new_vec: jax.Array, @@ -406,15 +406,15 @@ def jaxopt_solver( Args: dt: Time step duration. - static_config_slice: Static runtime configuration. Changes to these config - params will trigger recompilation. A key parameter in static_config slice - is theta_imp, a coefficient in [0, 1] determining which solution method to - use. We solve transient_coeff (x_new - x_old) / dt = theta_imp F(t_new) + - (1 - theta_imp) F(t_old). Three values of theta_imp correspond to named - solution methods: theta_imp = 1: Backward Euler implicit method (default). - theta_imp = 0.5: Crank-Nicolson. theta_imp = 0: Forward Euler explicit - method. - dynamic_config_slice_t_plus_dt: Runtime configuration for time t + dt. + static_runtime_params_slice: Static runtime parameters. Changes to these + runtime params will trigger recompilation. A key parameter in this params + slice is theta_imp, a coefficient in [0, 1] determining which solution + method to use. We solve transient_coeff (x_new - x_old) / dt = theta_imp + F(t_new) + (1 - theta_imp) F(t_old). Three values of theta_imp correspond + to named solution methods: theta_imp = 1: Backward Euler implicit method + (default). theta_imp = 0.5: Crank-Nicolson. theta_imp = 0: Forward Euler + explicit method. + dynamic_runtime_params_slice_t_plus_dt: Runtime parameters for time t + dt. geo: geometry object. x_old: The starting x defined as a tuple of CellVariables. init_x_new_vec: Flattened array of initial guess of x_new for all evolving @@ -442,8 +442,8 @@ def jaxopt_solver( loss = functools.partial( theta_method_block_loss, dt=dt, - static_config_slice=static_config_slice, - dynamic_config_slice_t_plus_dt=dynamic_config_slice_t_plus_dt, + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt, geo=geo, x_old=x_old, core_profiles_t_plus_dt=core_profiles_t_plus_dt, diff --git a/torax/fvm/tests/fvm.py b/torax/fvm/tests/fvm.py index b485e1a7..50504752 100644 --- a/torax/fvm/tests/fvm.py +++ b/torax/fvm/tests/fvm.py @@ -22,11 +22,11 @@ from jax import numpy as jnp import numpy as np from torax import calc_coeffs -from torax import config as config_lib -from torax import config_slice from torax import core_profile_setters from torax import fvm from torax import geometry +from torax.config import runtime_params as general_runtime_params +from torax.config import runtime_params_slice from torax.fvm import implicit_solve_block from torax.fvm import residual_and_loss from torax.sources import default_sources @@ -43,7 +43,9 @@ class FVMTest(torax_refs.ReferenceValueTest): @parameterized.parameters([ dict(references_getter=torax_refs.circular_references), dict(references_getter=torax_refs.chease_references_Ip_from_chease), - dict(references_getter=torax_refs.chease_references_Ip_from_config), + dict( + references_getter=torax_refs.chease_references_Ip_from_runtime_params + ), ]) def test_face_grad( self, @@ -59,7 +61,9 @@ def test_face_grad( @parameterized.parameters([ dict(references_getter=torax_refs.circular_references), dict(references_getter=torax_refs.chease_references_Ip_from_chease), - dict(references_getter=torax_refs.chease_references_Ip_from_config), + dict( + references_getter=torax_refs.chease_references_Ip_from_runtime_params + ), ]) def test_underconstrained( self, @@ -69,7 +73,7 @@ def test_underconstrained( references = references_getter() # Use ref_config to configure size, so we can also use ref_geo - value = jnp.zeros(references.config.numerics.nr) + value = jnp.zeros(references.runtime_params.numerics.nr) cell_variable = fvm.CellVariable(value=value, dr=references.geo.dr) # Underconstrain the left with self.assertRaises(AssertionError): @@ -89,7 +93,9 @@ def test_underconstrained( @parameterized.parameters([ dict(references_getter=torax_refs.circular_references), dict(references_getter=torax_refs.chease_references_Ip_from_chease), - dict(references_getter=torax_refs.chease_references_Ip_from_config), + dict( + references_getter=torax_refs.chease_references_Ip_from_runtime_params + ), ]) def test_overconstrained( self, @@ -99,7 +105,7 @@ def test_overconstrained( references = references_getter() # Use ref_config to configure size, so we can also use ref_geo - value = jnp.zeros(references.config.numerics.nr) + value = jnp.zeros(references.runtime_params.numerics.nr) cell_variable = fvm.CellVariable(value=value, dr=references.geo.dr) # Overconstrain the left with self.assertRaises(AssertionError): @@ -125,7 +131,7 @@ def test_overconstrained( ), dict( seed=20221114, - references_getter=torax_refs.chease_references_Ip_from_config, + references_getter=torax_refs.chease_references_Ip_from_runtime_params, ), ]) def test_face_grad_constraints(self, seed, references_getter): @@ -133,7 +139,7 @@ def test_face_grad_constraints(self, seed, references_getter): references = references_getter() # Use ref_config to configure size, so we can also use ref_geo - dim = references.config.numerics.nr + dim = references.runtime_params.numerics.nr value = jnp.zeros(dim) rng_state = jax.random.PRNGKey(seed) @@ -350,11 +356,11 @@ def test_nonlinear_solve_block_loss_minimum( self, num_cells, theta_imp, time_steps ): """Tests that the linear solution for a linear problem yields zero residual and loss.""" - config = config_lib.Config( - profile_conditions=config_lib.ProfileConditions( + runtime_params = general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( set_pedestal=False, ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( nr=num_cells, el_heat_eq=False, ), @@ -363,7 +369,7 @@ def test_nonlinear_solve_block_loss_minimum( predictor_corrector=False, theta_imp=theta_imp, ) - geo = geometry.build_circular_geometry(config) + geo = geometry.build_circular_geometry(runtime_params) transport_model = constant_transport_model.ConstantTransportModel( runtime_params=constant_transport_model.RuntimeParams( chimin=0, @@ -381,29 +387,36 @@ def test_nonlinear_solve_block_loss_minimum( source_models.sources['ohmic_heat_source'].runtime_params.mode = ( source_runtime_params.Mode.ZERO ) - dynamic_config_slice = config_slice.build_dynamic_config_slice( - config, - transport=transport_model.runtime_params, - sources=source_models.runtime_params, - stepper=stepper_params, + dynamic_runtime_params_slice = ( + runtime_params_slice.build_dynamic_runtime_params_slice( + runtime_params, + transport=transport_model.runtime_params, + sources=source_models.runtime_params, + stepper=stepper_params, + ) ) - static_config_slice = config_slice.build_static_config_slice( - config, stepper=stepper_params + static_runtime_params_slice = ( + runtime_params_slice.build_static_runtime_params_slice( + runtime_params, stepper=stepper_params + ) ) core_profiles = core_profile_setters.initial_core_profiles( - static_config_slice, dynamic_config_slice, geo, source_models + static_runtime_params_slice, + dynamic_runtime_params_slice, + geo, + source_models, ) evolving_names = tuple(['temp_ion']) explicit_source_profiles = source_models_lib.build_source_profiles( source_models=source_models, - dynamic_config_slice=dynamic_config_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, geo=geo, core_profiles=core_profiles, explicit=True, ) coeffs = calc_coeffs.calc_coeffs( - static_config_slice=static_config_slice, - dynamic_config_slice=dynamic_config_slice, + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, geo=geo, core_profiles=core_profiles, transport_model=transport_model, @@ -434,8 +447,8 @@ def test_nonlinear_solve_block_loss_minimum( # core_profiles_t_plus_dt is not updated since coeffs stay constant here loss, _ = residual_and_loss.theta_method_block_loss( dt=dt, - static_config_slice=static_config_slice, - dynamic_config_slice_t_plus_dt=dynamic_config_slice, + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice, geo=geo, x_old=x_old, x_new_guess_vec=jnp.concatenate([var.value for var in x_new]), @@ -449,8 +462,8 @@ def test_nonlinear_solve_block_loss_minimum( residual, _ = residual_and_loss.theta_method_block_residual( dt=dt, - static_config_slice=static_config_slice, - dynamic_config_slice_t_plus_dt=dynamic_config_slice, + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice, geo=geo, x_new_guess_vec=jnp.concatenate([var.value for var in x_new]), x_old=x_old, @@ -470,11 +483,11 @@ def test_implicit_solve_block_uses_updated_boundary_conditions(self): # Create a system with diffusive transport and no sources. When initialized # flat, x_new should remain zero unless boundary conditions change. num_cells = 2 - config = config_lib.Config( - profile_conditions=config_lib.ProfileConditions( + runtime_params = general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( set_pedestal=False, ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( nr=num_cells, el_heat_eq=False, ), @@ -500,22 +513,29 @@ def test_implicit_solve_block_uses_updated_boundary_conditions(self): source_models.sources['ohmic_heat_source'].runtime_params.mode = ( source_runtime_params.Mode.ZERO ) - dynamic_config_slice = config_slice.build_dynamic_config_slice( - config, - transport=transport_model.runtime_params, - sources=source_models.runtime_params, - stepper=stepper_params, + dynamic_runtime_params_slice = ( + runtime_params_slice.build_dynamic_runtime_params_slice( + runtime_params, + transport=transport_model.runtime_params, + sources=source_models.runtime_params, + stepper=stepper_params, + ) ) - static_config_slice = config_slice.build_static_config_slice( - config, stepper=stepper_params + static_runtime_params_slice = ( + runtime_params_slice.build_static_runtime_params_slice( + runtime_params, stepper=stepper_params + ) ) - geo = geometry.build_circular_geometry(config) + geo = geometry.build_circular_geometry(runtime_params) source_models = source_models_lib.SourceModels() initial_core_profiles = core_profile_setters.initial_core_profiles( - static_config_slice, dynamic_config_slice, geo, source_models + static_runtime_params_slice, + dynamic_runtime_params_slice, + geo, + source_models, ) explicit_source_profiles = source_models_lib.build_source_profiles( - dynamic_config_slice=dynamic_config_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, geo=geo, core_profiles=initial_core_profiles, source_models=source_models, @@ -526,8 +546,8 @@ def test_implicit_solve_block_uses_updated_boundary_conditions(self): evolving_names = tuple(['temp_ion']) coeffs = calc_coeffs.calc_coeffs( - static_config_slice=static_config_slice, - dynamic_config_slice=dynamic_config_slice, + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, geo=geo, core_profiles=initial_core_profiles, transport_model=transport_model, @@ -593,11 +613,11 @@ def test_theta_residual_uses_updated_boundary_conditions(self): # Create a system with diffusive transport and no sources. When initialized # flat, residual should remain zero unless boundary conditions change. num_cells = 2 - config = config_lib.Config( - profile_conditions=config_lib.ProfileConditions( + runtime_params = general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( set_pedestal=False, ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( nr=num_cells, el_heat_eq=False, ), @@ -606,7 +626,7 @@ def test_theta_residual_uses_updated_boundary_conditions(self): predictor_corrector=False, theta_imp=0.0, ) - geo = geometry.build_circular_geometry(config) + geo = geometry.build_circular_geometry(runtime_params) transport_model = constant_transport_model.ConstantTransportModel( runtime_params=constant_transport_model.RuntimeParams( chimin=0, @@ -624,28 +644,35 @@ def test_theta_residual_uses_updated_boundary_conditions(self): source_models.sources['ohmic_heat_source'].runtime_params.mode = ( source_runtime_params.Mode.ZERO ) - dynamic_config_slice = config_slice.build_dynamic_config_slice( - config, - transport=transport_model.runtime_params, - sources=source_models.runtime_params, - stepper=stepper_params, - ) - static_config_slice_theta0 = config_slice.build_static_config_slice( - config, stepper=stepper_params - ) - static_config_slice_theta05 = dataclasses.replace( - static_config_slice_theta0, + dynamic_runtime_params_slice = ( + runtime_params_slice.build_dynamic_runtime_params_slice( + runtime_params, + transport=transport_model.runtime_params, + sources=source_models.runtime_params, + stepper=stepper_params, + ) + ) + static_runtime_params_slice_theta0 = ( + runtime_params_slice.build_static_runtime_params_slice( + runtime_params, stepper=stepper_params + ) + ) + static_runtime_params_slice_theta05 = dataclasses.replace( + static_runtime_params_slice_theta0, stepper=dataclasses.replace( - static_config_slice_theta0.stepper, theta_imp=0.5 + static_runtime_params_slice_theta0.stepper, theta_imp=0.5 ), ) source_models = source_models_lib.SourceModels() initial_core_profiles = core_profile_setters.initial_core_profiles( - static_config_slice_theta0, dynamic_config_slice, geo, source_models + static_runtime_params_slice_theta0, + dynamic_runtime_params_slice, + geo, + source_models, ) explicit_source_profiles = source_models_lib.build_source_profiles( - dynamic_config_slice=dynamic_config_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, geo=geo, core_profiles=initial_core_profiles, source_models=source_models, @@ -656,8 +683,8 @@ def test_theta_residual_uses_updated_boundary_conditions(self): evolving_names = tuple(['temp_ion']) coeffs_old = calc_coeffs.calc_coeffs( - static_config_slice=static_config_slice_theta05, - dynamic_config_slice=dynamic_config_slice, + static_runtime_params_slice=static_runtime_params_slice_theta05, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, geo=geo, core_profiles=initial_core_profiles, transport_model=transport_model, @@ -675,8 +702,8 @@ def test_theta_residual_uses_updated_boundary_conditions(self): right_face_constraint=initial_right_boundary, ) core_profiles_t_plus_dt = core_profile_setters.initial_core_profiles( - static_config_slice=static_config_slice_theta0, - dynamic_config_slice=dynamic_config_slice, + static_runtime_params_slice=static_runtime_params_slice_theta0, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, geo=geo, source_models=source_models, ) @@ -691,8 +718,8 @@ def test_theta_residual_uses_updated_boundary_conditions(self): # at all 0, and the residual should be 0. residual, _ = residual_and_loss.theta_method_block_residual( dt=dt, - static_config_slice=static_config_slice_theta05, - dynamic_config_slice_t_plus_dt=dynamic_config_slice, + static_runtime_params_slice=static_runtime_params_slice_theta05, + dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice, geo=geo, x_old=(x_0,), x_new_guess_vec=x_0.value, @@ -711,8 +738,8 @@ def test_theta_residual_uses_updated_boundary_conditions(self): final_right_boundary = jnp.array(1.0) residual, _ = residual_and_loss.theta_method_block_residual( dt=dt, - static_config_slice=static_config_slice_theta0, - dynamic_config_slice_t_plus_dt=dynamic_config_slice, + static_runtime_params_slice=static_runtime_params_slice_theta0, + dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice, geo=geo, x_old=(x_0,), x_new_guess_vec=x_0.value, @@ -732,8 +759,8 @@ def test_theta_residual_uses_updated_boundary_conditions(self): # But when theta_imp > 0, the residual should be non-zero. residual, _ = residual_and_loss.theta_method_block_residual( dt=dt, - static_config_slice=static_config_slice_theta05, - dynamic_config_slice_t_plus_dt=dynamic_config_slice, + static_runtime_params_slice=static_runtime_params_slice_theta05, + dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice, geo=geo, x_old=(x_0,), core_profiles_t_plus_dt=dataclasses.replace( diff --git a/torax/geometry.py b/torax/geometry.py index 100d7175..adba9ca4 100644 --- a/torax/geometry.py +++ b/torax/geometry.py @@ -22,12 +22,12 @@ import chex import jax import jax.numpy as jnp -from torax import config as config_lib -from torax import config_slice from torax import constants from torax import geometry_loader from torax import jax_utils from torax import math_utils +from torax.config import runtime_params as general_runtime_params +from torax.config import runtime_params_slice @chex.dataclass(frozen=True) @@ -175,7 +175,7 @@ class CHEASEGeometry(Geometry): def build_circular_geometry( - config: config_lib.Config, + runtime_params: general_runtime_params.GeneralRuntimeParams, kappa: float = 1.72, Rmaj: float = 6.2, Rmin: float = 2.0, @@ -191,7 +191,7 @@ def build_circular_geometry( for this object, Fiddle-ify this function, not CircularGeometry.__init__. Args: - config: General TORAX config. + runtime_params: General TORAX runtime input parameters. kappa: Elogination. Defaults to 1.72 for the ITER elongation, to approximately correct volume and area integral Jacobians. Rmaj: major radius (R) in meters @@ -207,8 +207,8 @@ def build_circular_geometry( # r_norm coordinate is r/Rmin in circular, and rho_norm in standard # geometry (CHEASE/EQDSK) # Define mesh (Slab Uniform 1D with Jacobian = 1) - dr_norm = jnp.array(1) / config.numerics.nr - mesh = Grid1D.construct(nx=config.numerics.nr, dx=dr_norm) + dr_norm = jnp.array(1) / runtime_params.numerics.nr + mesh = Grid1D.construct(nx=runtime_params.numerics.nr, dx=dr_norm) rmax = jnp.array(Rmin) # helper variables for mesh cells and faces # r coordinate of faces @@ -253,16 +253,11 @@ def build_circular_geometry( delta_face = jnp.zeros(len(r_face)) # uses <1/R^2> with circular geometry - G2 = vpr / ( - 4 * jnp.pi**2 * Rmaj**2 * jnp.sqrt(1 - (r / Rmaj) ** 2) - ) + G2 = vpr / (4 * jnp.pi**2 * Rmaj**2 * jnp.sqrt(1 - (r / Rmaj) ** 2)) # generate G2_face by hand G2_outer_face = vpr_face[-1] / ( - 4 - * jnp.pi**2 - * Rmaj**2 - * jnp.sqrt(1 - (r_face[-1] / Rmaj) ** 2) + 4 * jnp.pi**2 * Rmaj**2 * jnp.sqrt(1 - (r_face[-1] / Rmaj) ** 2) ) G2_outer_face = jnp.expand_dims(G2_outer_face, 0) G2_face = jnp.concatenate( @@ -291,7 +286,7 @@ def build_circular_geometry( # High resolution versions for j (plasma current) and psi (poloidal flux) # manipulations. Needed if psi is initialized from plasma current, which is # the only option for ad-hoc circular geometry. - r_hires_norm = jnp.linspace(0, 1, config.numerics.nr * hires_fac) + r_hires_norm = jnp.linspace(0, 1, runtime_params.numerics.nr * hires_fac) r_hires = r_hires_norm * rmax Rout = Rmaj + r @@ -318,12 +313,7 @@ def build_circular_geometry( ) # uses <1/R^2> with circular geometry - denom = ( - 4 - * jnp.pi**2 - * Rmaj**2 - * jnp.sqrt(1 - (r_hires / Rmaj) ** 2) - ) + denom = 4 * jnp.pi**2 * Rmaj**2 * jnp.sqrt(1 - (r_hires / Rmaj) ** 2) G2_hires = vpr_hires / denom # terms applied in transport equations and dt calculation. @@ -400,7 +390,7 @@ def build_circular_geometry( def build_chease_geometry( - config: config_lib.Config, + runtime_params: general_runtime_params.GeneralRuntimeParams, geometry_dir: str | None = None, geometry_file: str = 'ITER_hybrid_citrin_equil_cheasedata.mat2cols', Rmaj: float = 6.2, @@ -418,7 +408,7 @@ def build_chease_geometry( for this object, Fiddle-ify this function, not CHEASEGeometry.__init__. Args: - config: General TORAX config. + runtime_params: General TORAX runtime input parameters. geometry_dir: Directory where to find the CHEASE file describing the magnetic geometry. If None, uses the environment variable TORAX_GEOMETRY_DIR if available. If that variable is not set and @@ -450,8 +440,10 @@ def build_chease_geometry( ) # TODO(b/326406367): incorporate time dependent geometry - # build t_initial config_slice - dynamic_config_slice = config_slice.build_dynamic_config_slice(config) + # build t_initial runtime_params_slice + dynamic_runtime_params_slice = ( + runtime_params_slice.build_dynamic_runtime_params_slice(runtime_params) + ) # Prepare variables from CHEASE to be interpolated into our simulation # grid. CHEASE variables are normalized. Need to unnormalize them with @@ -461,9 +453,7 @@ def build_chease_geometry( B0 = jnp.array(B0) psiunnormfactor = (Rmaj**2 * B0) * 2 * jnp.pi psi_chease = chease_data['PSIchease=psi/2pi'] * psiunnormfactor - Ip_chease = ( - chease_data['Ipprofile'] / constants.CONSTANTS.mu0 * Rmaj * B0 - ) + Ip_chease = chease_data['Ipprofile'] / constants.CONSTANTS.mu0 * Rmaj * B0 # toroidal flux coordinate rho = chease_data['RHO_TOR=sqrt(Phi/pi/B0)'] * Rmaj @@ -524,16 +514,16 @@ def build_chease_geometry( # if Ip from parameter file, renormalize psi to match desired current if Ip_from_parameters: Ip_scale_factor = ( - dynamic_config_slice.profile_conditions.Ip * 1e6 / Ip_chease[-1] + dynamic_runtime_params_slice.profile_conditions.Ip * 1e6 / Ip_chease[-1] ) psi_from_chease_Ip *= Ip_scale_factor else: - # This overwrites the config.profile_conditions.Ip, even if it's time - # dependent, to be consistent with the geometry file being processed. + # This overwrites the runtime_params.profile_conditions.Ip, even if it's + # time dependent, to be consistent with the geometry file being processed. # TODO(b/326406367): Do not rely on writing back to the config to # make this work. We should not rely on the geometry being computed for the # config to have the correct Ip. - config.profile_conditions.Ip = Ip_chease[-1] / 1e6 + runtime_params.profile_conditions.Ip = Ip_chease[-1] / 1e6 Ip_scale_factor = 1 # volume, area, and dV/drho, dS/drho @@ -556,9 +546,9 @@ def build_chease_geometry( # fill geometry structure # r_norm coordinate is rho_tor_norm - dr_norm = jnp.array(1) / config.numerics.nr + dr_norm = jnp.array(1) / runtime_params.numerics.nr # normalized grid - mesh = Grid1D.construct(nx=config.numerics.nr, dx=dr_norm) + mesh = Grid1D.construct(nx=runtime_params.numerics.nr, dx=dr_norm) rmax = rho[-1] # radius denormalization constant # helper variables for mesh cells and faces r_face_norm = mesh.face_centers @@ -570,7 +560,7 @@ def build_chease_geometry( # High resolution versions for j (plasma current) and psi (poloidal flux) # manipulations. Needed if psi is initialized from plasma current. - r_hires_norm = jnp.linspace(0, 1, config.numerics.nr * hires_fac) + r_hires_norm = jnp.linspace(0, 1, runtime_params.numerics.nr * hires_fac) r_hires = r_hires_norm * rmax interp_func = lambda x: jnp.interp(x, rhon, vpr_chease) diff --git a/torax/interpolated_param.py b/torax/interpolated_param.py index 1e30da8a..16912190 100644 --- a/torax/interpolated_param.py +++ b/torax/interpolated_param.py @@ -166,11 +166,12 @@ def _convert_value_to_floats( class InterpolatedParam(InterpolatedParamBase): """Parameter that may vary based on an input coordinate. - This class is useful for defining time-dependent config parameters, but can + This class is useful for defining time-dependent runtime parameters, but can be used to define any parameters that vary across some range. This class is the main "user-facing" class defined in this module. - See `config.Config` and associated tests to see how this is used. + See `config.runtime_params.RuntimeParams` and associated tests to see how this + is used. """ def __init__( @@ -222,7 +223,7 @@ def is_bool_param(self) -> bool: return self._is_bool_param -# In Config, users should be able to either specify the InterpolatedParam object -# directly or the values that go in the constructor. This helps with brevity -# since a lot of these params are fixed floats. +# In runtime_params, users should be able to either specify the +# InterpolatedParam object directly or the values that go in the constructor. +# This helps with brevity since a lot of these params are fixed floats. InterpParamOrInterpParamInput = InterpolatedParam | InterpolatedParamInput diff --git a/torax/sim.py b/torax/sim.py index 7949b1a8..02dddc60 100644 --- a/torax/sim.py +++ b/torax/sim.py @@ -34,14 +34,15 @@ import jax import jax.numpy as jnp from torax import calc_coeffs -from torax import config as config_lib -from torax import config_slice from torax import core_profile_setters from torax import fvm from torax import geometry from torax import jax_utils from torax import physics from torax import state +from torax.config import config_args +from torax.config import runtime_params as general_runtime_params +from torax.config import runtime_params_slice from torax.sources import source_models as source_models_lib from torax.sources import source_profiles as source_profiles_lib from torax.spectators import spectator as spectator_lib @@ -66,7 +67,7 @@ class CoeffsCallback: """Implements fvm.Block1DCoeffsCallback using calc_coeffs. Attributes: - static_config_slice: See the docstring for `stepper.Stepper`. + static_runtime_params_slice: See the docstring for `stepper.Stepper`. geo: See the docstring for `stepper.Stepper`. core_profiles_t: The core plasma profiles at the start of the time step. core_profiles_t_plus_dt: Core plasma profiles at the end of the time step. @@ -78,7 +79,7 @@ class CoeffsCallback: def __init__( self, - static_config_slice: config_slice.StaticConfigSlice, + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, geo: geometry.Geometry, core_profiles_t: state.CoreProfiles, core_profiles_t_plus_dt: state.CoreProfiles, @@ -87,7 +88,7 @@ def __init__( source_models: source_models_lib.SourceModels, evolving_names: tuple[str, ...], ): - self.static_config_slice = static_config_slice + self.static_runtime_params_slice = static_runtime_params_slice self.geo = geo self.core_profiles_t = core_profiles_t self.core_profiles_t_plus_dt = core_profiles_t_plus_dt @@ -98,7 +99,7 @@ def __init__( def __call__( self, - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, x: tuple[fvm.CellVariable, ...], allow_pereverzev: bool = False, # Checks if reduced calc_coeffs for explicit terms when theta_imp=1 @@ -107,11 +108,11 @@ def __call__( ): replace = {k: v for k, v in zip(self.evolving_names, x)} if explicit_call: - core_profiles = config_lib.recursive_replace( + core_profiles = config_args.recursive_replace( self.core_profiles_t, **replace ) else: - core_profiles = config_lib.recursive_replace( + core_profiles = config_args.recursive_replace( self.core_profiles_t_plus_dt, **replace ) # update ion density in core_profiles if ne is being evolved. @@ -121,20 +122,20 @@ def __call__( core_profiles.ni, value=core_profiles.ne.value * physics.get_main_ion_dilution_factor( - dynamic_config_slice.plasma_composition.Zimp, - dynamic_config_slice.plasma_composition.Zeff, + dynamic_runtime_params_slice.plasma_composition.Zimp, + dynamic_runtime_params_slice.plasma_composition.Zeff, ), ) core_profiles = dataclasses.replace(core_profiles, ni=ni) if allow_pereverzev: - use_pereverzev = self.static_config_slice.stepper.use_pereverzev + use_pereverzev = self.static_runtime_params_slice.stepper.use_pereverzev else: use_pereverzev = False return calc_coeffs.calc_coeffs( - static_config_slice=self.static_config_slice, - dynamic_config_slice=dynamic_config_slice, + static_runtime_params_slice=self.static_runtime_params_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, geo=self.geo, core_profiles=core_profiles, transport_model=self.transport_model, @@ -154,17 +155,24 @@ class FrozenCoeffsCallback(CoeffsCallback): """ def __init__(self, *args, **kwargs): - if 'dynamic_config_slice' not in kwargs: - raise ValueError('dynamic_config_slice must be provided.') - dynamic_config_slice = kwargs.pop('dynamic_config_slice') + if 'dynamic_runtime_params_slice' not in kwargs: + raise ValueError('dynamic_runtime_params_slice must be provided.') + dynamic_runtime_params_slice = kwargs.pop('dynamic_runtime_params_slice') super().__init__(*args, **kwargs) x = tuple([self.core_profiles_t[name] for name in self.evolving_names]) self.frozen_coeffs = super().__call__( - dynamic_config_slice, x, allow_pereverzev=False, explicit_call=False + dynamic_runtime_params_slice, + x, + allow_pereverzev=False, + explicit_call=False, ) def __call__( - self, dynamic_config_slice, x, allow_pereverzev=False, explicit_call=False + self, + dynamic_runtime_params_slice, + x, + allow_pereverzev=False, + explicit_call=False, ): return self.frozen_coeffs @@ -217,8 +225,8 @@ def transport_model(self) -> transport_model_lib.TransportModel: def __call__( self, - static_config_slice: config_slice.StaticConfigSlice, - dynamic_config_slice_provider: config_slice.DynamicConfigSliceProvider, + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice_provider: runtime_params_slice.DynamicRuntimeParamsSliceProvider, geo: geometry.Geometry, input_state: state.ToraxSimState, explicit_source_profiles: source_profiles_lib.SourceProfiles, @@ -226,12 +234,12 @@ def __call__( """Advances the simulation state one time step. Args: - static_config_slice: Static parameters that, if they change, should - trigger a recompilation of the SimulationStepFn. - dynamic_config_slice_provider: Object that returns a set of runtime - parameters which may change from time step to time step or simulation - run to run. If these config parameters change, it does NOT trigger a JAX - recompilation. + static_runtime_params_slice: Static parameters that, if they change, + should trigger a recompilation of the SimulationStepFn. + dynamic_runtime_params_slice_provider: Object that returns a set of + runtime parameters which may change from time step to time step or + simulation run to run. If these runtime parameters change, it does NOT + trigger a JAX recompilation. geo: The geometry of the torus during this time step of the simulation. While the geometry may change, any changes to the grid size can trigger recompilation of the stepper (if it is jitted) or an error (assuming it @@ -252,7 +260,9 @@ def __call__( 2 if solver converged within coarse tolerance. Allowed to pass with a warning. Occasional error=2 has low impact on final sim state. """ - dynamic_config_slice_t = dynamic_config_slice_provider(input_state.t) + dynamic_runtime_params_slice_t = dynamic_runtime_params_slice_provider( + input_state.t + ) # TODO(b/335598388): We call the transport model both here and in the the # Stepper / CoeffsCallback. This isn't a problem *so long as all of those # calls fall within the same jit scope* because can use @@ -261,12 +271,12 @@ def __call__( # calculate transport coeffs at delta_t = 0 in only one place, so that we # have some flexibility in where to place the jit boundaries. transport_coeffs = self._jitted_transport_model( - dynamic_config_slice_t, geo, input_state.core_profiles + dynamic_runtime_params_slice_t, geo, input_state.core_profiles ) # initialize new dt and reset stepper iterations. dt, time_step_calculator_state = self._time_step_calculator.next_dt( - dynamic_config_slice_t, + dynamic_runtime_params_slice_t, geo, input_state.core_profiles, input_state.time_step_calculator_state, @@ -274,25 +284,28 @@ def __call__( ) crosses_t_final = ( - input_state.t < dynamic_config_slice_t.numerics.t_final + input_state.t < dynamic_runtime_params_slice_t.numerics.t_final ) * ( - input_state.t + input_state.dt > dynamic_config_slice_t.numerics.t_final + input_state.t + input_state.dt + > dynamic_runtime_params_slice_t.numerics.t_final ) dt = jnp.where( jnp.logical_and( - dynamic_config_slice_t.numerics.exact_t_final, + dynamic_runtime_params_slice_t.numerics.exact_t_final, crosses_t_final, ), - dynamic_config_slice_t.numerics.t_final - input_state.t, + dynamic_runtime_params_slice_t.numerics.t_final - input_state.t, dt, ) if jnp.any(jnp.isnan(dt)): raise ValueError('dt is NaN.') - # The stepper needs the dynamic_config_slice at time t + dt for implicit - # computations in the solver. - dynamic_config_slice_t_plus_dt = dynamic_config_slice_provider( - input_state.t + dt, + # The stepper needs the dynamic_runtime_params_slice at time t + dt for + # implicit computations in the solver. + dynamic_runtime_params_slice_t_plus_dt = ( + dynamic_runtime_params_slice_provider( + input_state.t + dt, + ) ) core_profiles_t = input_state.core_profiles @@ -301,8 +314,8 @@ def __call__( # conditions and time-dependent prescribed profiles not directly solved by # PDE system. core_profiles_t_plus_dt = provide_core_profiles_t_plus_dt( - static_config_slice=static_config_slice, - dynamic_config_slice_t_plus_dt=dynamic_config_slice_t_plus_dt, + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt, geo=geo, core_profiles_t=core_profiles_t, ) @@ -314,9 +327,9 @@ def __call__( core_profiles, core_sources, core_transport, stepper_error_state = ( self._stepper_fn( dt=dt, - static_config_slice=static_config_slice, - dynamic_config_slice_t=dynamic_config_slice_t, - dynamic_config_slice_t_plus_dt=dynamic_config_slice_t_plus_dt, + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice_t=dynamic_runtime_params_slice_t, + dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt, geo=geo, core_profiles_t=core_profiles_t, core_profiles_t_plus_dt=core_profiles_t_plus_dt, @@ -336,7 +349,7 @@ def __call__( stepper_error_state=stepper_error_state, ) - if static_config_slice.adaptive_dt: + if static_runtime_params_slice.adaptive_dt: # Check if stepper converged. If not, proceed to body_fun def cond_fun(updated_output: state.ToraxSimState) -> bool: if updated_output.stepper_error_state == 1: @@ -354,28 +367,30 @@ def body_fun( dt = ( updated_output.dt - / dynamic_config_slice_t.numerics.dt_reduction_factor + / dynamic_runtime_params_slice_t.numerics.dt_reduction_factor ) if jnp.any(jnp.isnan(dt)): raise ValueError('dt is NaN.') - if dt < dynamic_config_slice_t.numerics.mindt: + if dt < dynamic_runtime_params_slice_t.numerics.mindt: raise ValueError('dt below minimum timestep following adaptation') - dynamic_config_slice_t_plus_dt = dynamic_config_slice_provider( - input_state.t + dt, + dynamic_runtime_params_slice_t_plus_dt = ( + dynamic_runtime_params_slice_provider( + input_state.t + dt, + ) ) core_profiles_t_plus_dt = provide_core_profiles_t_plus_dt( core_profiles_t=core_profiles_t, - dynamic_config_slice_t_plus_dt=dynamic_config_slice_t_plus_dt, - static_config_slice=static_config_slice, + dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt, + static_runtime_params_slice=static_runtime_params_slice, geo=geo, ) core_profiles, core_sources, core_transport, stepper_error_state = ( self._stepper_fn( dt=dt, - static_config_slice=static_config_slice, - dynamic_config_slice_t=dynamic_config_slice_t, - dynamic_config_slice_t_plus_dt=dynamic_config_slice_t_plus_dt, + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice_t=dynamic_runtime_params_slice_t, + dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt, geo=geo, core_profiles_t=core_profiles_t, core_profiles_t_plus_dt=core_profiles_t_plus_dt, @@ -396,10 +411,12 @@ def body_fun( output_state = jax_utils.py_while(cond_fun, body_fun, output_state) # Update total current, q, and s profiles based on new psi - dynamic_config_slice_t_plus_dt = dynamic_config_slice_provider( - input_state.t + output_state.dt, + dynamic_runtime_params_slice_t_plus_dt = ( + dynamic_runtime_params_slice_provider( + input_state.t + output_state.dt, + ) ) - q_corr = dynamic_config_slice_t_plus_dt.numerics.q_correction_factor + q_corr = dynamic_runtime_params_slice_t_plus_dt.numerics.q_correction_factor output_state.core_profiles = physics.update_jtot_q_face_s_face( geo=geo, core_profiles=output_state.core_profiles, @@ -409,7 +426,7 @@ def body_fun( # Update ohmic and bootstrap current based on the new core profiles. output_state.core_profiles = update_current_distribution( source_models=self._stepper_fn.source_models, - dynamic_config_slice=dynamic_config_slice_t_plus_dt, + dynamic_runtime_params_slice=dynamic_runtime_params_slice_t_plus_dt, geo=geo, core_profiles=output_state.core_profiles, ) @@ -417,7 +434,7 @@ def body_fun( # Update psidot based on the new core profiles output_state.core_profiles = update_psidot( source_models=self._stepper_fn.source_models, - dynamic_config_slice=dynamic_config_slice_t_plus_dt, + dynamic_runtime_params_slice=dynamic_runtime_params_slice_t_plus_dt, geo=geo, core_profiles=output_state.core_profiles, ) @@ -426,18 +443,21 @@ def body_fun( def get_initial_state( - static_config_slice: config_slice.StaticConfigSlice, - dynamic_config_slice: config_slice.DynamicConfigSlice, + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, source_models: source_models_lib.SourceModels, time_step_calculator: ts.TimeStepCalculator, ) -> state.ToraxSimState: """Returns the initial state to be used by run_simulation().""" initial_core_profiles = core_profile_setters.initial_core_profiles( - static_config_slice, dynamic_config_slice, geo, source_models + static_runtime_params_slice, + dynamic_runtime_params_slice, + geo, + source_models, ) return state.ToraxSimState( - t=jnp.array(dynamic_config_slice.numerics.t_initial), + t=jnp.array(dynamic_runtime_params_slice.numerics.t_initial), dt=jnp.zeros(()), core_profiles=initial_core_profiles, # This will be overridden within run_simulation(). @@ -534,8 +554,8 @@ class Sim: def __init__( self, - static_config_slice: config_slice.StaticConfigSlice, - dynamic_config_slice_provider: config_slice.DynamicConfigSliceProvider, + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice_provider: runtime_params_slice.DynamicRuntimeParamsSliceProvider, geometry_provider: GeometryProvider, initial_state: state.ToraxSimState, time_step_calculator: ts.TimeStepCalculator, @@ -543,8 +563,10 @@ def __init__( stepper: stepper_lib.Stepper | None = None, step_fn: SimulationStepFn | None = None, ): - self._static_config_slice = static_config_slice - self._dynamic_config_slice_provider = dynamic_config_slice_provider + self._static_runtime_params_slice = static_runtime_params_slice + self._dynamic_runtime_params_slice_provider = ( + dynamic_runtime_params_slice_provider + ) self._geometry_provider = geometry_provider self._initial_state = initial_state self._time_step_calculator = time_step_calculator @@ -586,14 +608,16 @@ def geometry_provider(self) -> GeometryProvider: return self._geometry_provider @property - def dynamic_config_slice_provider( + def dynamic_runtime_params_slice_provider( self, - ) -> config_slice.DynamicConfigSliceProvider: - return self._dynamic_config_slice_provider + ) -> runtime_params_slice.DynamicRuntimeParamsSliceProvider: + return self._dynamic_runtime_params_slice_provider @property - def static_config_slice(self) -> config_slice.StaticConfigSlice: - return self._static_config_slice + def static_runtime_params_slice( + self, + ) -> runtime_params_slice.StaticRuntimeParamsSlice: + return self._static_runtime_params_slice @property def step_fn(self) -> SimulationStepFn | None: @@ -647,8 +671,8 @@ def run( if spectator is not None: spectator.reset() return run_simulation( - static_config_slice=self.static_config_slice, - dynamic_config_slice_provider=self.dynamic_config_slice_provider, + static_runtime_params_slice=self.static_runtime_params_slice, + dynamic_runtime_params_slice_provider=self.dynamic_runtime_params_slice_provider, geometry_provider=self.geometry_provider, initial_state=self.initial_state, time_step_calculator=self.time_step_calculator, @@ -659,22 +683,22 @@ def run( def build_sim_from_config( - config: config_lib.Config, + runtime_params: general_runtime_params.GeneralRuntimeParams, geo: geometry.Geometry, stepper_builder: stepper_lib.StepperBuilder, transport_model: transport_model_lib.TransportModel, source_models: source_models_lib.SourceModels, time_step_calculator: Optional[ts.TimeStepCalculator] = None, ) -> Sim: - """Builds a Sim object from a Config file. + """Builds a Sim object from the input runtime params and objects. Over time we expect to transition to functions that just build Sim objects directly. This function is needed during the transitional stage during which many objects still require - a Config. + a config. Args: - config: The Config used to build everything. + runtime_params: The input runtime params used throughout the simulation run. geo: Describes the magnetic geometry. stepper_builder: A callable to build the stepper. The stepper has already been factored out of the config. @@ -688,36 +712,40 @@ def build_sim_from_config( sim: The built Sim instance. """ - static_config_slice = config_slice.build_static_config_slice( - config, - stepper=stepper_builder.runtime_params, + static_runtime_params_slice = ( + runtime_params_slice.build_static_runtime_params_slice( + runtime_params=runtime_params, + stepper=stepper_builder.runtime_params, + ) ) - dynamic_config_slice_provider = config_slice.DynamicConfigSliceProvider( - config=config, - transport_getter=lambda: transport_model.runtime_params, - sources_getter=lambda: source_models.runtime_params, - stepper_getter=lambda: stepper_builder.runtime_params, + dynamic_runtime_params_slice_provider = ( + runtime_params_slice.DynamicRuntimeParamsSliceProvider( + runtime_params=runtime_params, + transport_getter=lambda: transport_model.runtime_params, + sources_getter=lambda: source_models.runtime_params, + stepper_getter=lambda: stepper_builder.runtime_params, + ) ) stepper = stepper_builder(transport_model, source_models) if time_step_calculator is None: time_step_calculator = chi_time_step_calculator.ChiTimeStepCalculator() - # build dynamic_config_slice at t_initial for initial conditions - dynamic_config_slice = dynamic_config_slice_provider( - config.numerics.t_initial + # build dynamic_runtime_params_slice at t_initial for initial conditions + dynamic_runtime_params_slice = dynamic_runtime_params_slice_provider( + runtime_params.numerics.t_initial ) initial_state = get_initial_state( - static_config_slice=static_config_slice, - dynamic_config_slice=dynamic_config_slice, + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, geo=geo, source_models=stepper.source_models, time_step_calculator=time_step_calculator, ) return Sim( - static_config_slice=static_config_slice, - dynamic_config_slice_provider=dynamic_config_slice_provider, + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice_provider=dynamic_runtime_params_slice_provider, geometry_provider=ConstantGeometryProvider(geo), initial_state=initial_state, time_step_calculator=time_step_calculator, @@ -727,8 +755,8 @@ def build_sim_from_config( def run_simulation( - static_config_slice: config_slice.StaticConfigSlice, - dynamic_config_slice_provider: config_slice.DynamicConfigSliceProvider, + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice_provider: runtime_params_slice.DynamicRuntimeParamsSliceProvider, geometry_provider: GeometryProvider, initial_state: state.ToraxSimState, time_step_calculator: ts.TimeStepCalculator, @@ -753,16 +781,18 @@ def run_simulation( history. Args: - static_config_slice: A static set of arguments to provide to the step_fn. If - step_fn is JAX-compiled, then these params are "compile-time constant" - meaning that they are considered static to the compiled function. If they - change (i.e. the same step_fn is called again with a different - static_config_slice), then the step_fn will be recompiled. JAX determines - if recompilation is necessary via the hash of the static_config_slice. - dynamic_config_slice_provider: Provides a DynamicConfigSlice to use as input - for each time step. See static_config_slice and the config_slice module - docstring for config_slice to understand why we need the dynamic and - static config slices and what they control. + static_runtime_params_slice: A static set of arguments to provide to the + step_fn. If step_fn is JAX-compiled, then these params are "compile-time + constant" meaning that they are considered static to the compiled + function. If they change (i.e. the same step_fn is called again with a + different static_runtime_params_slice), then the step_fn will be + recompiled. JAX determines if recompilation is necessary via the hash of + the static_runtime_params_slice. + dynamic_runtime_params_slice_provider: Provides a DynamicRuntimeParamsSlice + to use as input for each time step. See static_runtime_params_slice and + the runtime_params_slice module docstring for runtime_params_slice to + understand why we need the dynamic and static config slices and what they + control. geometry_provider: Provides the geometry of the torus for each time step based on the ToraxSimState at the start of the time step. The geometry may change from time step to time step, so the sim needs a function to provide @@ -809,15 +839,17 @@ def run_simulation( initial_state, ] stepper_error_state = 0 - dynamic_config_slice = dynamic_config_slice_provider(initial_state.t) + dynamic_runtime_params_slice = dynamic_runtime_params_slice_provider( + initial_state.t + ) geo = geometry_provider(initial_state) # Populate the starting state with source profiles from the implicit sources # before starting the run-loop. The explicit source profiles will be computed # inside the loop and will be merged with these implicit source profiles. initial_state.core_sources = _get_initial_source_profiles( - static_config_slice=static_config_slice, - dynamic_config_slice=dynamic_config_slice, + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, geo=geo, core_profiles=initial_state.core_profiles, source_models=step_fn.stepper.source_models, @@ -832,7 +864,7 @@ def run_simulation( # done. while time_step_calculator.not_done( sim_state.t, - dynamic_config_slice, + dynamic_runtime_params_slice, sim_state.time_step_calculator_state, ): # Measure how long in wall clock time each simulation step takes. @@ -840,7 +872,7 @@ def run_simulation( if log_timestep_info: _log_timestep(sim_state.t, sim_state.dt, sim_state.stepper_iterations) # TODO(b/330172917): once tol and coarse_tol are configurable in the - # config, also log the value of tol and coarse_tol below + # runtime_params, also log the value of tol and coarse_tol below match stepper_error_state: case 0: pass @@ -855,7 +887,7 @@ def run_simulation( # DynamicSourceConfigSlice. All implicit sources will have their profiles # set to 0. explicit_source_profiles = source_models_lib.build_source_profiles( - dynamic_config_slice=dynamic_config_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, geo=geo, core_profiles=sim_state.core_profiles, source_models=step_fn.stepper.source_models, @@ -881,15 +913,17 @@ def run_simulation( # Now prep the spectator for the following time step. spectator.before_step() sim_state = step_fn( - static_config_slice, - dynamic_config_slice_provider, + static_runtime_params_slice, + dynamic_runtime_params_slice_provider, geo, sim_state, explicit_source_profiles, ) stepper_error_state = sim_state.stepper_error_state # Update the runtime config for the next iteration. - dynamic_config_slice = dynamic_config_slice_provider(sim_state.t) + dynamic_runtime_params_slice = dynamic_runtime_params_slice_provider( + sim_state.t + ) torax_outputs.append(sim_state) geo = geometry_provider(sim_state) wall_clock_step_times.append(time.time() - step_start_time) @@ -902,7 +936,7 @@ def run_simulation( # profiles computed based on the final state. logging.info("Updating last step's source profiles.") explicit_source_profiles = source_models_lib.build_source_profiles( - dynamic_config_slice=dynamic_config_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, geo=geo, core_profiles=sim_state.core_profiles, source_models=step_fn.stepper.source_models, @@ -1015,7 +1049,7 @@ def _update_spectator( def update_current_distribution( - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, core_profiles: state.CoreProfiles, source_models: source_models_lib.SourceModels, @@ -1023,8 +1057,8 @@ def update_current_distribution( """Update bootstrap current based on the new core_profiles.""" bootstrap_profile = source_models.j_bootstrap.get_value( - dynamic_config_slice=dynamic_config_slice, - dynamic_source_runtime_params=dynamic_config_slice.sources[ + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ source_models.j_bootstrap_name ], geo=geo, @@ -1058,7 +1092,7 @@ def update_current_distribution( def update_psidot( - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, core_profiles: state.CoreProfiles, source_models: source_models_lib.SourceModels, @@ -1068,7 +1102,7 @@ def update_psidot( psidot = dataclasses.replace( core_profiles.psidot, value=source_models_lib.calc_psidot( - dynamic_config_slice, + dynamic_runtime_params_slice, geo, core_profiles, source_models, @@ -1083,21 +1117,21 @@ def update_psidot( def provide_core_profiles_t_plus_dt( - static_config_slice: config_slice.StaticConfigSlice, - dynamic_config_slice_t_plus_dt: config_slice.DynamicConfigSlice, + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice_t_plus_dt: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, core_profiles_t: state.CoreProfiles, ) -> state.CoreProfiles: """Provides state at t_plus_dt with new boundary conditions and prescribed profiles.""" updated_boundary_conditions = ( core_profile_setters.compute_boundary_conditions( - dynamic_config_slice_t_plus_dt, + dynamic_runtime_params_slice_t_plus_dt, geo, ) ) updated_values = core_profile_setters.updated_prescribed_core_profiles( - static_config_slice=static_config_slice, - dynamic_config_slice=dynamic_config_slice_t_plus_dt, + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice_t_plus_dt, geo=geo, core_profiles=core_profiles_t, ) @@ -1123,8 +1157,8 @@ def provide_core_profiles_t_plus_dt( core_profiles_t.ni, value=ne.value * physics.get_main_ion_dilution_factor( - dynamic_config_slice_t_plus_dt.plasma_composition.Zimp, - dynamic_config_slice_t_plus_dt.plasma_composition.Zeff, + dynamic_runtime_params_slice_t_plus_dt.plasma_composition.Zimp, + dynamic_runtime_params_slice_t_plus_dt.plasma_composition.Zeff, ), ) core_profiles_t_plus_dt = dataclasses.replace( @@ -1134,8 +1168,8 @@ def provide_core_profiles_t_plus_dt( def _get_initial_source_profiles( - static_config_slice: config_slice.StaticConfigSlice, - dynamic_config_slice: config_slice.DynamicConfigSlice, + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, core_profiles: state.CoreProfiles, source_models: source_models_lib.SourceModels, @@ -1154,10 +1188,11 @@ def _get_initial_source_profiles( core profiles. Args: - static_config_slice: Config parameters which, when they change, trigger - recompilations. They should not change within a single run of the sim. - dynamic_config_slice: Runtime parameters which may change from time step to - time step without triggering recompilations. + static_runtime_params_slice: Runtime parameters which, when they change, + trigger recompilations. They should not change within a single run of the + sim. + dynamic_runtime_params_slice: Runtime parameters which may change from time + step to time step without triggering recompilations. geo: The geometry of the torus during this time step of the simulation. core_profiles: Core profiles that may evolve throughout the course of a simulation. These values here are, of course, only the original states. @@ -1168,16 +1203,16 @@ def _get_initial_source_profiles( the starting state. """ implicit_profiles = source_models_lib.build_source_profiles( - dynamic_config_slice=dynamic_config_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, geo=geo, core_profiles=core_profiles, source_models=source_models, explicit=False, ) qei = source_models.qei_source.get_qei( - static_config_slice=static_config_slice, - dynamic_config_slice=dynamic_config_slice, - dynamic_source_runtime_params=dynamic_config_slice.sources[ + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ source_models.qei_source_name ], geo=geo, diff --git a/torax/simulation_app.py b/torax/simulation_app.py index 518a3e75..6a11f99f 100644 --- a/torax/simulation_app.py +++ b/torax/simulation_app.py @@ -52,10 +52,10 @@ def get_sim(): from jax import numpy as jnp from matplotlib import pyplot as plt import torax -from torax import config_slice from torax import geometry from torax import sim as sim_lib from torax import state as state_lib +from torax.config import runtime_params_slice from torax.sources import runtime_params as source_runtime_params_lib from torax.spectators import plotting from torax.stepper import runtime_params as stepper_runtime_params_lib @@ -204,7 +204,7 @@ def _get_output_dir( def update_sim( sim: sim_lib.Sim, - config: torax.Config, + runtime_params: torax.GeneralRuntimeParams, geo: geometry.Geometry, transport_runtime_params: transport_runtime_params_lib.RuntimeParams, source_runtime_params: dict[str, source_runtime_params_lib.RuntimeParams], @@ -212,7 +212,7 @@ def update_sim( [], stepper_runtime_params_lib.RuntimeParams ], ) -> sim_lib.Sim: - """Updates the sim with a new config and geometry.""" + """Updates the sim with a new set of runtime params and geometry.""" # NOTE: This function will NOT update any of the following: # - stepper (for the mesh state) # - transport model object (runtime params are updated) @@ -221,21 +221,25 @@ def update_sim( # - source objects (runtime params are updated) sim.transport_model.runtime_params = transport_runtime_params _update_source_params(sim, source_runtime_params) - static_config_slice = config_slice.build_static_config_slice( - config, - stepper=stepper_runtime_params_getter(), + static_runtime_params_slice = ( + runtime_params_slice.build_static_runtime_params_slice( + runtime_params, + stepper=stepper_runtime_params_getter(), + ) ) - dynamic_config_slice_provider = config_slice.DynamicConfigSliceProvider( - config=config, - transport_getter=lambda: sim.transport_model.runtime_params, - sources_getter=lambda: sim.source_models.runtime_params, - stepper_getter=stepper_runtime_params_getter, + dynamic_runtime_params_slice_provider = ( + runtime_params_slice.DynamicRuntimeParamsSliceProvider( + runtime_params=runtime_params, + transport_getter=lambda: sim.transport_model.runtime_params, + sources_getter=lambda: sim.source_models.runtime_params, + stepper_getter=stepper_runtime_params_getter, + ) ) initial_state = sim_lib.get_initial_state( - dynamic_config_slice=dynamic_config_slice_provider( - t=config.numerics.t_initial + dynamic_runtime_params_slice=dynamic_runtime_params_slice_provider( + t=runtime_params.numerics.t_initial ), - static_config_slice=static_config_slice, + static_runtime_params_slice=static_runtime_params_slice, geo=geo, time_step_calculator=sim.time_step_calculator, source_models=sim.source_models, @@ -244,8 +248,8 @@ def update_sim( time_step_calculator=sim.time_step_calculator, initial_state=initial_state, geometry_provider=sim_lib.ConstantGeometryProvider(geo), - dynamic_config_slice_provider=dynamic_config_slice_provider, - static_config_slice=static_config_slice, + dynamic_runtime_params_slice_provider=dynamic_runtime_params_slice_provider, + static_runtime_params_slice=static_runtime_params_slice, step_fn=sim.step_fn, ) diff --git a/torax/sources/bootstrap_current_source.py b/torax/sources/bootstrap_current_source.py index c4917e68..6928dd83 100644 --- a/torax/sources/bootstrap_current_source.py +++ b/torax/sources/bootstrap_current_source.py @@ -21,14 +21,14 @@ import chex from jax import numpy as jnp from jax.scipy import integrate -from torax import config_slice from torax import constants from torax import geometry from torax import jax_utils from torax import physics from torax import state +from torax.config import config_args +from torax.config import runtime_params_slice from torax.fvm import cell_variable -from torax.runtime_params import config_slice_args from torax.sources import runtime_params as runtime_params_lib from torax.sources import source from torax.sources import source_profiles @@ -41,7 +41,7 @@ class RuntimeParams(runtime_params_lib.RuntimeParams): def build_dynamic_params(self, t: chex.Numeric) -> DynamicRuntimeParams: return DynamicRuntimeParams( - **config_slice_args.get_init_kwargs( + **config_args.get_init_kwargs( input_config=self, output_type=DynamicRuntimeParams, t=t, @@ -95,7 +95,7 @@ class BootstrapCurrentSource(source.Source): def get_value( self, - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, core_profiles: state.CoreProfiles | None = None, @@ -134,7 +134,7 @@ def get_value( psi = psi or core_profiles.psi # pytype: enable=attribute-error return calc_neoclassical( - dynamic_config_slice=dynamic_config_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, dynamic_source_runtime_params=dynamic_source_runtime_params, geo=geo, temp_ion=temp_ion, @@ -160,7 +160,7 @@ def get_source_profile_for_affected_core_profile( @jax_utils.jit def calc_neoclassical( - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, dynamic_source_runtime_params: DynamicRuntimeParams, geo: geometry.Geometry, temp_ion: cell_variable.CellVariable, @@ -173,7 +173,7 @@ def calc_neoclassical( """Calculates sigmaneo, j_bootstrap, and I_bootstrap. Args: - dynamic_config_slice: General configuration parameters. + dynamic_runtime_params_slice: General configuration parameters. dynamic_source_runtime_params: Source-specific runtime parameters. geo: Torus geometry. temp_ion: Ion temperature. We don't pass in a full `State` here because this @@ -195,9 +195,9 @@ def calc_neoclassical( # Formulas from Sauter PoP 1999. Future work can include Redl PoP 2021 # corrections. - true_ne_face = ne.face_value() * dynamic_config_slice.numerics.nref - true_ni_face = ni.face_value() * dynamic_config_slice.numerics.nref - Zeff = dynamic_config_slice.plasma_composition.Zeff + true_ne_face = ne.face_value() * dynamic_runtime_params_slice.numerics.nref + true_ni_face = ni.face_value() * dynamic_runtime_params_slice.numerics.nref + Zeff = dynamic_runtime_params_slice.plasma_composition.Zeff # # local r/R0 on face grid epsilon = (geo.Rout_face - geo.Rin_face) / (geo.Rout_face + geo.Rin_face) @@ -214,7 +214,7 @@ def calc_neoclassical( lnLame = 31.3 - jnp.log(jnp.sqrt(true_ne_face) / (temp_el.face_value() * 1e3)) # TODO(b/335599537) use ni instead of ne lnLami = 30 - jnp.log( - dynamic_config_slice.plasma_composition.Zi**3 + dynamic_runtime_params_slice.plasma_composition.Zi**3 * jnp.sqrt(true_ne_face) / ((temp_ion.face_value() * 1e3) ** 1.5) ) @@ -227,7 +227,7 @@ def calc_neoclassical( geo=geo, psi=psi, jtot_face=jtot_face, - q_correction_factor=dynamic_config_slice.numerics.q_correction_factor, + q_correction_factor=dynamic_runtime_params_slice.numerics.q_correction_factor, ) nuestar = ( 6.921e-18 diff --git a/torax/sources/electron_density_sources.py b/torax/sources/electron_density_sources.py index 9b00c40d..571acadf 100644 --- a/torax/sources/electron_density_sources.py +++ b/torax/sources/electron_density_sources.py @@ -20,10 +20,10 @@ import chex from jax import numpy as jnp -from torax import config_slice from torax import geometry from torax import state -from torax.runtime_params import config_slice_args +from torax.config import config_args +from torax.config import runtime_params_slice from torax.sources import formulas from torax.sources import runtime_params as runtime_params_lib from torax.sources import source @@ -44,7 +44,7 @@ def build_dynamic_params( t: chex.Numeric, ) -> DynamicGasPuffRuntimeParams: return DynamicGasPuffRuntimeParams( - **config_slice_args.get_init_kwargs( + **config_args.get_init_kwargs( input_config=self, output_type=DynamicGasPuffRuntimeParams, t=t, @@ -60,7 +60,7 @@ class DynamicGasPuffRuntimeParams(runtime_params_lib.DynamicRuntimeParams): # Default formula: exponential with nref normalization. def _calc_puff_source( - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, unused_state: state.CoreProfiles | None = None, @@ -72,7 +72,7 @@ def _calc_puff_source( c2=dynamic_source_runtime_params.puff_decay_length, total=( dynamic_source_runtime_params.S_puff_tot - / dynamic_config_slice.numerics.nref + / dynamic_runtime_params_slice.numerics.nref ), use_normalized_r=True, geo=geo, @@ -106,7 +106,7 @@ def build_dynamic_params( t: chex.Numeric, ) -> DynamicNBIParticleRuntimeParams: return DynamicNBIParticleRuntimeParams( - **config_slice_args.get_init_kwargs( + **config_args.get_init_kwargs( input_config=self, output_type=DynamicNBIParticleRuntimeParams, t=t, @@ -122,7 +122,7 @@ class DynamicNBIParticleRuntimeParams(runtime_params_lib.DynamicRuntimeParams): def _calc_nbi_source( - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, unused_state: state.CoreProfiles | None = None, @@ -136,7 +136,7 @@ def _calc_nbi_source( c2=dynamic_source_runtime_params.nbi_particle_width, total=( dynamic_source_runtime_params.S_nbi_tot - / dynamic_config_slice.numerics.nref + / dynamic_runtime_params_slice.numerics.nref ), use_normalized_r=True, geo=geo, @@ -172,7 +172,7 @@ def build_dynamic_params( t: chex.Numeric, ) -> DynamicPelletRuntimeParams: return DynamicPelletRuntimeParams( - **config_slice_args.get_init_kwargs( + **config_args.get_init_kwargs( input_config=self, output_type=DynamicPelletRuntimeParams, t=t, @@ -188,7 +188,7 @@ class DynamicPelletRuntimeParams(runtime_params_lib.DynamicRuntimeParams): def _calc_pellet_source( - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, unused_state: state.CoreProfiles | None = None, @@ -200,7 +200,7 @@ def _calc_pellet_source( c2=dynamic_source_runtime_params.pellet_width, total=( dynamic_source_runtime_params.S_pellet_tot - / dynamic_config_slice.numerics.nref + / dynamic_runtime_params_slice.numerics.nref ), use_normalized_r=True, geo=geo, diff --git a/torax/sources/external_current_source.py b/torax/sources/external_current_source.py index d5ed5a83..512b8b69 100644 --- a/torax/sources/external_current_source.py +++ b/torax/sources/external_current_source.py @@ -21,11 +21,11 @@ import chex from jax import numpy as jnp from jax.scipy import integrate -from torax import config_slice from torax import geometry from torax import jax_utils from torax import state -from torax.runtime_params import config_slice_args +from torax.config import config_args +from torax.config import runtime_params_slice from torax.sources import runtime_params as runtime_params_lib from torax.sources import source @@ -54,7 +54,7 @@ def build_dynamic_params( t: chex.Numeric, ) -> DynamicRuntimeParams: return DynamicRuntimeParams( - **config_slice_args.get_init_kwargs( + **config_args.get_init_kwargs( input_config=self, output_type=DynamicRuntimeParams, t=t, @@ -87,7 +87,7 @@ def __post_init__(self): def _calculate_jext_face( - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, unused_state: state.CoreProfiles | None = None, @@ -95,7 +95,7 @@ def _calculate_jext_face( """Calculates the external current density profiles. Args: - dynamic_config_slice: Parameter configuration at present timestep. + dynamic_runtime_params_slice: Parameter configuration at present timestep. dynamic_source_runtime_params: Source-specific parameters at the present timestep. geo: Tokamak geometry. @@ -107,7 +107,7 @@ def _calculate_jext_face( """ assert isinstance(dynamic_source_runtime_params, DynamicRuntimeParams) Iext = _calculate_Iext( - dynamic_config_slice, + dynamic_runtime_params_slice, dynamic_source_runtime_params, ) # form of external current on face grid @@ -123,7 +123,7 @@ def _calculate_jext_face( def _calculate_jext_hires( - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, unused_state: state.CoreProfiles | None = None, @@ -131,7 +131,7 @@ def _calculate_jext_hires( """Calculates the external current density profile along the hires grid. Args: - dynamic_config_slice: Parameter configuration at present timestep. + dynamic_runtime_params_slice: Parameter configuration at present timestep. dynamic_source_runtime_params: Source-specific parameters at the present timestep. geo: Tokamak geometry. @@ -143,7 +143,7 @@ def _calculate_jext_hires( """ assert isinstance(dynamic_source_runtime_params, DynamicRuntimeParams) Iext = _calculate_Iext( - dynamic_config_slice, + dynamic_runtime_params_slice, dynamic_source_runtime_params, ) # calculate "External" current profile (e.g. ECCD) @@ -159,7 +159,7 @@ def _calculate_jext_hires( def _calculate_Iext( - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, dynamic_source_runtime_params: DynamicRuntimeParams, ) -> chex.Numeric: """Calculates the total value of external current.""" @@ -167,7 +167,7 @@ def _calculate_Iext( dynamic_source_runtime_params.use_absolute_jext, dynamic_source_runtime_params.Iext, ( - dynamic_config_slice.profile_conditions.Ip + dynamic_runtime_params_slice.profile_conditions.Ip * dynamic_source_runtime_params.fext ), ) @@ -200,7 +200,7 @@ class ExternalCurrentSource(source.Source): def get_value( self, - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, core_profiles: state.CoreProfiles | None = None, @@ -209,7 +209,7 @@ def get_value( assert isinstance(dynamic_source_runtime_params, DynamicRuntimeParams) self.check_mode(dynamic_source_runtime_params.mode) profile = source.get_source_profiles( - dynamic_config_slice=dynamic_config_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, dynamic_source_runtime_params=dynamic_source_runtime_params, geo=geo, core_profiles=core_profiles, @@ -224,7 +224,7 @@ def get_value( def jext_hires( self, - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, ) -> jnp.ndarray: @@ -232,7 +232,7 @@ def jext_hires( assert isinstance(dynamic_source_runtime_params, DynamicRuntimeParams) self.check_mode(dynamic_source_runtime_params.mode) return source.get_source_profiles( - dynamic_config_slice=dynamic_config_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, dynamic_source_runtime_params=dynamic_source_runtime_params, geo=geo, core_profiles=None, diff --git a/torax/sources/formula_config.py b/torax/sources/formula_config.py index 140a552c..3f07c9a7 100644 --- a/torax/sources/formula_config.py +++ b/torax/sources/formula_config.py @@ -20,7 +20,7 @@ import chex from torax import interpolated_param -from torax.runtime_params import config_slice_args +from torax.config import config_args # Type-alias for clarity. @@ -67,7 +67,7 @@ class Exponential(FormulaConfig): def build_dynamic_params(self, t: chex.Numeric) -> DynamicExponential: return DynamicExponential( - **config_slice_args.get_init_kwargs( + **config_args.get_init_kwargs( input_config=self, output_type=DynamicExponential, t=t, @@ -101,7 +101,7 @@ class Gaussian: def build_dynamic_params(self, t: chex.Numeric) -> DynamicGaussian: return DynamicGaussian( - **config_slice_args.get_init_kwargs( + **config_args.get_init_kwargs( input_config=self, output_type=DynamicGaussian, t=t, diff --git a/torax/sources/formulas.py b/torax/sources/formulas.py index cc4a0577..523a34e5 100644 --- a/torax/sources/formulas.py +++ b/torax/sources/formulas.py @@ -17,10 +17,10 @@ import dataclasses import jax from jax import numpy as jnp -from torax import config_slice from torax import geometry from torax import jax_utils from torax import state +from torax.config import runtime_params_slice from torax.sources import formula_config from torax.sources import runtime_params @@ -114,7 +114,7 @@ class Exponential: def __call__( self, - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, dynamic_source_runtime_params: runtime_params.DynamicRuntimeParams, geo: geometry.Geometry, unused_state: state.CoreProfiles | None, @@ -136,7 +136,7 @@ class Gaussian: def __call__( self, - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, dynamic_source_runtime_params: runtime_params.DynamicRuntimeParams, geo: geometry.Geometry, unused_state: state.CoreProfiles | None, diff --git a/torax/sources/fusion_heat_source.py b/torax/sources/fusion_heat_source.py index 717ef2bc..94941eeb 100644 --- a/torax/sources/fusion_heat_source.py +++ b/torax/sources/fusion_heat_source.py @@ -20,10 +20,10 @@ import jax from jax import numpy as jnp -from torax import config_slice from torax import constants from torax import geometry from torax import state +from torax.config import runtime_params_slice from torax.sources import runtime_params as runtime_params_lib from torax.sources import source @@ -121,7 +121,7 @@ def calc_fusion( def fusion_heat_model_func( - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, core_profiles: state.CoreProfiles, @@ -129,7 +129,7 @@ def fusion_heat_model_func( del dynamic_source_runtime_params # Unused. # pylint: disable=invalid-name _, Pfus_i, Pfus_e = calc_fusion( - geo, core_profiles, dynamic_config_slice.numerics.nref + geo, core_profiles, dynamic_runtime_params_slice.numerics.nref ) return jnp.stack((Pfus_i, Pfus_e)) # pylint: enable=invalid-name diff --git a/torax/sources/generic_ion_el_heat_source.py b/torax/sources/generic_ion_el_heat_source.py index 48799d6b..a96b2592 100644 --- a/torax/sources/generic_ion_el_heat_source.py +++ b/torax/sources/generic_ion_el_heat_source.py @@ -21,10 +21,10 @@ import chex import jax from jax import numpy as jnp -from torax import config_slice from torax import geometry from torax import state -from torax.runtime_params import config_slice_args +from torax.config import config_args +from torax.config import runtime_params_slice from torax.sources import runtime_params as runtime_params_lib from torax.sources import source @@ -50,7 +50,7 @@ class RuntimeParams(runtime_params_lib.RuntimeParams): def build_dynamic_params(self, t: chex.Numeric) -> DynamicRuntimeParams: return DynamicRuntimeParams( - **config_slice_args.get_init_kwargs( + **config_args.get_init_kwargs( input_config=self, output_type=DynamicRuntimeParams, t=t, @@ -102,13 +102,13 @@ def calc_generic_heat_source( def _default_formula( - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, core_profiles: state.CoreProfiles, ) -> jnp.ndarray: """Returns the default formula-based ion/electron heat source profile.""" - del dynamic_config_slice, core_profiles # Unused. + del dynamic_runtime_params_slice, core_profiles # Unused. assert isinstance(dynamic_source_runtime_params, DynamicRuntimeParams) ion, el = calc_generic_heat_source( geo, diff --git a/torax/sources/qei_source.py b/torax/sources/qei_source.py index 3af886af..3ecf4ad4 100644 --- a/torax/sources/qei_source.py +++ b/torax/sources/qei_source.py @@ -21,11 +21,11 @@ import chex import jax from jax import numpy as jnp -from torax import config_slice from torax import geometry from torax import physics from torax import state -from torax.runtime_params import config_slice_args +from torax.config import config_args +from torax.config import runtime_params_slice from torax.sources import runtime_params as runtime_params_lib from torax.sources import source from torax.sources import source_profiles @@ -41,7 +41,7 @@ class RuntimeParams(runtime_params_lib.RuntimeParams): def build_dynamic_params(self, t: chex.Numeric) -> DynamicRuntimeParams: return DynamicRuntimeParams( - **config_slice_args.get_init_kwargs( + **config_args.get_init_kwargs( input_config=self, output_type=DynamicRuntimeParams, t=t, @@ -86,8 +86,8 @@ class QeiSource(source.Source): def get_qei( self, - static_config_slice: config_slice.StaticConfigSlice, - dynamic_config_slice: config_slice.DynamicConfigSlice, + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, core_profiles: state.CoreProfiles, @@ -98,8 +98,8 @@ def get_qei( dynamic_source_runtime_params.mode == runtime_params_lib.Mode.MODEL_BASED.value, lambda: _model_based_qei( - static_config_slice, - dynamic_config_slice, + static_runtime_params_slice, + dynamic_runtime_params_slice, dynamic_source_runtime_params, geo, core_profiles, @@ -110,7 +110,7 @@ def get_qei( def get_value( self, source_type: int, - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, core_profiles: state.CoreProfiles | None = None, ) -> source_profiles.QeiInfo: @@ -126,8 +126,8 @@ def get_source_profile_for_affected_core_profile( def _model_based_qei( - static_config_slice: config_slice.StaticConfigSlice, - dynamic_config_slice: config_slice.DynamicConfigSlice, + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, core_profiles: state.CoreProfiles, @@ -137,8 +137,8 @@ def _model_based_qei( zeros = jnp.zeros_like(geo.r_norm) qei_coef = physics.coll_exchange( core_profiles=core_profiles, - nref=dynamic_config_slice.numerics.nref, - Ai=dynamic_config_slice.plasma_composition.Ai, + nref=dynamic_runtime_params_slice.numerics.nref, + Ai=dynamic_runtime_params_slice.plasma_composition.Ai, Qei_mult=dynamic_source_runtime_params.Qei_mult, ) implicit_ii = -qei_coef @@ -146,9 +146,13 @@ def _model_based_qei( if ( # if only a single heat equation is being evolved - (static_config_slice.ion_heat_eq and not static_config_slice.el_heat_eq) + ( + static_runtime_params_slice.ion_heat_eq + and not static_runtime_params_slice.el_heat_eq + ) or ( - static_config_slice.el_heat_eq and not static_config_slice.ion_heat_eq + static_runtime_params_slice.el_heat_eq + and not static_runtime_params_slice.ion_heat_eq ) ): explicit_i = qei_coef * core_profiles.temp_el.value diff --git a/torax/sources/runtime_params.py b/torax/sources/runtime_params.py index a2ee75c2..89bc5fcf 100644 --- a/torax/sources/runtime_params.py +++ b/torax/sources/runtime_params.py @@ -21,7 +21,7 @@ import chex from torax import interpolated_param -from torax.runtime_params import config_slice_args +from torax.config import config_args from torax.sources import formula_config @@ -50,7 +50,8 @@ class Mode(enum.Enum): class RuntimeParams: """Configures a single source/sink term. - This is a RUNTIME config, meaning its values can change from run to run + This is a RUNTIME runtime_params, meaning its values can change from run to + run without trigerring a recompile. This config defines the runtime config for the entire simulation run. The DynamicRuntimeParams, which is derived from this class, only contains information for a single time step. @@ -82,7 +83,7 @@ class RuntimeParams: def build_dynamic_params(self, t: chex.Numeric) -> DynamicRuntimeParams: return DynamicRuntimeParams( - **config_slice_args.get_init_kwargs( + **config_args.get_init_kwargs( input_config=self, output_type=DynamicRuntimeParams, t=t, diff --git a/torax/sources/source.py b/torax/sources/source.py index 94a00b3a..5a258f62 100644 --- a/torax/sources/source.py +++ b/torax/sources/source.py @@ -29,17 +29,17 @@ import chex from jax import numpy as jnp -from torax import config_slice from torax import geometry from torax import jax_utils from torax import state +from torax.config import runtime_params_slice from torax.sources import runtime_params as runtime_params_lib # Sources implement these functions to be able to provide source profiles. SourceProfileFunction = Callable[ [ # Arguments - config_slice.DynamicConfigSlice, # General config params. + runtime_params_slice.DynamicRuntimeParamsSlice, # General config params runtime_params_lib.DynamicRuntimeParams, # Source-specific params. geometry.Geometry, state.CoreProfiles | None, @@ -49,8 +49,8 @@ ] -# Any callable which takes the dynamic config, geometry, and optional core -# profiles, and outputs a shape corresponding to the expected output of a +# Any callable which takes the dynamic runtime_params, geometry, and optional +# core profiles, and outputs a shape corresponding to the expected output of a # source. See how these types of functions are used in the Source class below. SourceOutputShapeFunction = Callable[ [ # Arguments @@ -186,7 +186,7 @@ def _unsupported_mode_error_msg( def get_value( self, - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, core_profiles: state.CoreProfiles | None = None, @@ -194,8 +194,8 @@ def get_value( """Returns the profile for this source during one time step. Args: - dynamic_config_slice: Slice of the general TORAX config that can be used - as input for this time step. + dynamic_runtime_params_slice: Slice of the general TORAX config that can + be used as input for this time step. dynamic_source_runtime_params: Slice of this source's runtime parameters at a specific time t. geo: Geometry of the torus. @@ -222,7 +222,7 @@ def get_value( else self.formula ) return get_source_profiles( - dynamic_config_slice=dynamic_config_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, dynamic_source_runtime_params=dynamic_source_runtime_params, geo=geo, core_profiles=core_profiles, @@ -319,7 +319,8 @@ class SingleProfileSource(Source): If you want to create a subclass of SingleProfileSource with frozen parameters, you can provide default implementations/attributes. This is an example of a model-based source with a frozen custom model that cannot be - changed by a config, along with custom runtime parameters specific to this + changed by a runtime_params, along with custom runtime parameters specific to + this source: ```python @@ -330,7 +331,7 @@ class FooRuntimeParams(runtime_params_lib.RuntimeParams): def build_dynamic_params(self, t: chex.Numeric) -> DynamicFooRuntimeParams: return DynamicFooRuntimeParams( - **config_slice_args.get_init_kwargs( + **config_args.get_init_kwargs( input_config=self, output_type=DynamicFooRuntimeParams, t=t, @@ -343,7 +344,7 @@ class DynamicFooRuntimeParams(runtime_params_lib.DynamicRuntimeParams): bar_param: float def _my_foo_model( - dynamic_config_slice, + dynamic_runtime_params_slice, dynamic_source_runtime_params, geo, core_profiles, @@ -385,7 +386,7 @@ class FooSource(SingleProfileSource): def get_value( self, - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, core_profiles: state.CoreProfiles | None = None, @@ -393,7 +394,7 @@ def get_value( """Returns the profile for this source during one time step.""" output_shape = self.output_shape_getter(geo) profile = super().get_value( - dynamic_config_slice=dynamic_config_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, dynamic_source_runtime_params=dynamic_source_runtime_params, geo=geo, core_profiles=core_profiles, @@ -439,7 +440,7 @@ def get_zero_profile(self, geo: geometry.Geometry) -> jnp.ndarray: def get_source_profiles( - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, core_profiles: state.CoreProfiles | None, @@ -453,8 +454,8 @@ def get_source_profiles( source types will be ignored. Args: - dynamic_config_slice: Slice of the general TORAX config that can be used as - input for this time step. + dynamic_runtime_params_slice: Slice of the general TORAX config that can be + used as input for this time step. dynamic_source_runtime_params: Slice of this source's runtime parameters at a specific time t. geo: Geometry information. Used as input to the source profile functions. @@ -473,7 +474,7 @@ def get_source_profiles( output += jnp.where( mode == runtime_params_lib.Mode.MODEL_BASED.value, model_func( - dynamic_config_slice, + dynamic_runtime_params_slice, dynamic_source_runtime_params, geo, core_profiles, @@ -483,7 +484,7 @@ def get_source_profiles( output += jnp.where( mode == runtime_params_lib.Mode.FORMULA_BASED.value, formula( - dynamic_config_slice, + dynamic_runtime_params_slice, dynamic_source_runtime_params, geo, core_profiles, @@ -569,7 +570,7 @@ class IonElectronSource(Source): def get_value( self, - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, core_profiles: state.CoreProfiles | None = None, @@ -577,8 +578,8 @@ def get_value( """Computes the ion and electron values of the source. Args: - dynamic_config_slice: Input config which can change from time step to time - step. + dynamic_runtime_params_slice: Input config which can change from time step + to time step. dynamic_source_runtime_params: Slice of this source's runtime parameters at a specific time t. geo: Geometry of the torus. @@ -590,7 +591,7 @@ def get_value( """ output_shape = self.output_shape_getter(geo) profile = super().get_value( - dynamic_config_slice=dynamic_config_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, dynamic_source_runtime_params=dynamic_source_runtime_params, geo=geo, core_profiles=core_profiles, diff --git a/torax/sources/source_models.py b/torax/sources/source_models.py index 15af48ca..13276755 100644 --- a/torax/sources/source_models.py +++ b/torax/sources/source_models.py @@ -20,12 +20,12 @@ import functools import jax.numpy as jnp -from torax import config_slice from torax import constants from torax import geometry from torax import jax_utils from torax import physics from torax import state +from torax.config import runtime_params_slice from torax.fvm import diffusion_terms from torax.sources import bootstrap_current_source from torax.sources import external_current_source @@ -42,7 +42,7 @@ ], ) def build_source_profiles( - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, core_profiles: state.CoreProfiles, source_models: SourceModels, @@ -51,8 +51,8 @@ def build_source_profiles( """Builds explicit or implicit source profiles. Args: - dynamic_config_slice: Input config for this time step. Can change from time - step to time step. + dynamic_runtime_params_slice: Input config for this time step. Can change + from time step to time step. geo: Geometry of the torus. core_profiles: Core plasma profiles, either at the start of the time step (if explicit) or the live profiles being evolved during the time step (if @@ -67,11 +67,11 @@ def build_source_profiles( """ # Bootstrap current is a special-case source with multiple outputs, so handle # it here. - dynamic_bootstrap_runtime_params = dynamic_config_slice.sources[ + dynamic_bootstrap_runtime_params = dynamic_runtime_params_slice.sources[ source_models.j_bootstrap_name ] bootstrap_profiles = _build_bootstrap_profiles( - dynamic_config_slice=dynamic_config_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, dynamic_source_runtime_params=dynamic_bootstrap_runtime_params, geo=geo, core_profiles=core_profiles, @@ -81,17 +81,25 @@ def build_source_profiles( other_profiles = {} other_profiles.update( _build_psi_profiles( - dynamic_config_slice, geo, core_profiles, source_models, explicit + dynamic_runtime_params_slice, + geo, + core_profiles, + source_models, + explicit, ) ) other_profiles.update( _build_ne_profiles( - dynamic_config_slice, geo, core_profiles, source_models, explicit + dynamic_runtime_params_slice, + geo, + core_profiles, + source_models, + explicit, ) ) other_profiles.update( _build_temp_ion_el_profiles( - dynamic_config_slice, + dynamic_runtime_params_slice, geo, core_profiles, source_models, @@ -108,7 +116,7 @@ def build_source_profiles( def _build_bootstrap_profiles( - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, core_profiles: state.CoreProfiles, @@ -119,8 +127,8 @@ def _build_bootstrap_profiles( """Computes the bootstrap current profile. Args: - dynamic_config_slice: Input config for this time step. Can change from time - step to time step. + dynamic_runtime_params_slice: Input config for this time step. Can change + from time step to time step. dynamic_source_runtime_params: Input runtime parameters for this time step, specific to the bootstrap current source. geo: Geometry of the torus. @@ -139,7 +147,7 @@ def _build_bootstrap_profiles( Bootstrap current profile. """ bootstrap_profile = j_bootstrap_source.get_value( - dynamic_config_slice=dynamic_config_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, dynamic_source_runtime_params=dynamic_source_runtime_params, geo=geo, core_profiles=core_profiles, @@ -185,7 +193,7 @@ def _build_bootstrap_profiles( def _build_psi_profiles( - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, core_profiles: state.CoreProfiles, source_models: SourceModels, @@ -195,8 +203,8 @@ def _build_psi_profiles( """Computes psi sources and builds a kwargs dict for SourceProfiles. Args: - dynamic_config_slice: Input config for this time step. Can change from time - step to time step. + dynamic_runtime_params_slice: Input config for this time step. Can change + from time step to time step. geo: Geometry of the torus. core_profiles: Core plasma profiles, either at the start of the time step (if explicit) or the live profiles being evolved during the time step (if @@ -214,7 +222,7 @@ def _build_psi_profiles( """ psi_profiles = {} # jext is precomputed in the core profiles. - dynamic_jext_runtime_params = dynamic_config_slice.sources[ + dynamic_jext_runtime_params = dynamic_runtime_params_slice.sources[ source_models.jext_name ] psi_profiles[source_models.jext_name] = jax_utils.select( @@ -228,14 +236,16 @@ def _build_psi_profiles( # Iterate through the rest of the sources and compute profiles for the ones # which relate to psi. jext is not part of the "standard sources." for source_name, source in source_models.psi_sources.items(): - dynamic_source_runtime_params = dynamic_config_slice.sources[source_name] + dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[ + source_name + ] psi_profiles[source_name] = jax_utils.select( jnp.logical_or( explicit == dynamic_source_runtime_params.is_explicit, calculate_anyway, ), source.get_value( - dynamic_config_slice, + dynamic_runtime_params_slice, dynamic_source_runtime_params, geo, core_profiles, @@ -246,7 +256,7 @@ def _build_psi_profiles( def _build_ne_profiles( - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, core_profiles: state.CoreProfiles, source_models: SourceModels, @@ -255,8 +265,8 @@ def _build_ne_profiles( """Computes ne sources and builds a kwargs dict for SourceProfiles. Args: - dynamic_config_slice: Input config for this time step. Can change from time - step to time step. + dynamic_runtime_params_slice: Input config for this time step. Can change + from time step to time step. geo: Geometry of the torus. core_profiles: Core plasma profiles, either at the start of the time step (if explicit) or the live profiles being evolved during the time step (if @@ -275,11 +285,13 @@ def _build_ne_profiles( # Iterate through the sources and compute profiles for the ones which relate # to ne. for source_name, source in source_models.ne_sources.items(): - dynamic_source_runtime_params = dynamic_config_slice.sources[source_name] + dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[ + source_name + ] ne_profiles[source_name] = jax_utils.select( explicit == dynamic_source_runtime_params.is_explicit, source.get_value( - dynamic_config_slice, + dynamic_runtime_params_slice, dynamic_source_runtime_params, geo, core_profiles, @@ -290,7 +302,7 @@ def _build_ne_profiles( def _build_temp_ion_el_profiles( - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, core_profiles: state.CoreProfiles, source_models: SourceModels, @@ -299,8 +311,8 @@ def _build_temp_ion_el_profiles( """Computes ion and el sources and builds a kwargs dict for SourceProfiles. Args: - dynamic_config_slice: Input config for this time step. Can change from time - step to time step. + dynamic_runtime_params_slice: Input config for this time step. Can change + from time step to time step. geo: Geometry of the torus. core_profiles: Core plasma profiles, either at the start of the time step (if explicit) or the live profiles being evolved during the time step (if @@ -322,11 +334,13 @@ def _build_temp_ion_el_profiles( ) for source_name, source in temp_ion_el_sources.items(): zeros = jnp.zeros(source.output_shape_getter(geo)) - dynamic_source_runtime_params = dynamic_config_slice.sources[source_name] + dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[ + source_name + ] ion_el_profiles[source_name] = jax_utils.select( explicit == dynamic_source_runtime_params.is_explicit, source.get_value( - dynamic_config_slice, + dynamic_runtime_params_slice, dynamic_source_runtime_params, geo, core_profiles, @@ -406,7 +420,7 @@ def sum_sources_temp_el( def calc_and_sum_sources_psi( - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, core_profiles: state.CoreProfiles, source_models: SourceModels, @@ -417,7 +431,7 @@ def calc_and_sum_sources_psi( # expensive source functions that might not jittable (like file-based or # RPC-based sources). psi_profiles = _build_psi_profiles( - dynamic_config_slice, + dynamic_runtime_params_slice, geo, core_profiles, source_models, @@ -426,11 +440,11 @@ def calc_and_sum_sources_psi( total = 0 for key in psi_profiles: total += psi_profiles[key] - dynamic_bootstrap_runtime_params = dynamic_config_slice.sources[ + dynamic_bootstrap_runtime_params = dynamic_runtime_params_slice.sources[ source_models.j_bootstrap_name ] j_bootstrap_profiles = _build_bootstrap_profiles( - dynamic_config_slice=dynamic_config_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, dynamic_source_runtime_params=dynamic_bootstrap_runtime_params, geo=geo, core_profiles=core_profiles, @@ -451,7 +465,7 @@ def calc_and_sum_sources_psi( ], ) def calc_psidot( - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, core_profiles: state.CoreProfiles, source_models: SourceModels, @@ -466,7 +480,7 @@ def calc_psidot( (but abridged) formulation as in sim.calc_coeffs and fvm._calc_c is used here Args: - dynamic_config_slice: Simulation configuration at this timestep + dynamic_runtime_params_slice: Simulation configuration at this timestep geo: Torus geometry core_profiles: Core plasma profiles. source_models: All TORAX source/sinks. @@ -477,14 +491,14 @@ def calc_psidot( consts = constants.CONSTANTS psi_sources, sigma = calc_and_sum_sources_psi( - dynamic_config_slice, + dynamic_runtime_params_slice, geo, core_profiles, source_models, ) toc_psi = ( 1.0 - / dynamic_config_slice.numerics.resistivity_mult + / dynamic_runtime_params_slice.numerics.resistivity_mult * geo.r * sigma * consts.mu0 @@ -504,7 +518,7 @@ def calc_psidot( # OhmicHeatSource is a special case and defined here to avoid circular # dependencies, since it depends on the psi sources def _ohmic_heat_model( - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, core_profiles: state.CoreProfiles, source_models: SourceModels, @@ -516,7 +530,7 @@ def _ohmic_heat_model( ) psidot = calc_psidot( - dynamic_config_slice, + dynamic_runtime_params_slice, geo, core_profiles, source_models, @@ -573,14 +587,14 @@ class OhmicHeatSource(source_lib.SingleProfileSource): def __post_init__(self): # Ignore the model provided above and set it to the function here. def _model_func( - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, core_profiles: state.CoreProfiles, ) -> jnp.ndarray: del dynamic_source_runtime_params return _ohmic_heat_model( - dynamic_config_slice=dynamic_config_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, geo=geo, core_profiles=core_profiles, source_models=self.source_models, diff --git a/torax/sources/tests/bootstrap_current_source.py b/torax/sources/tests/bootstrap_current_source.py index a1fe44c6..a76d421e 100644 --- a/torax/sources/tests/bootstrap_current_source.py +++ b/torax/sources/tests/bootstrap_current_source.py @@ -17,10 +17,10 @@ from absl.testing import absltest import jax.numpy as jnp import numpy as np -from torax import config as config_lib -from torax import config_slice from torax import core_profile_setters from torax import geometry +from torax.config import runtime_params as general_runtime_params +from torax.config import runtime_params_slice from torax.sources import bootstrap_current_source from torax.sources import runtime_params as runtime_params_lib from torax.sources import source as source_lib @@ -44,26 +44,30 @@ def setUpClass(cls): def test_source_value(self): source = bootstrap_current_source.BootstrapCurrentSource() - config = config_lib.Config() - geo = geometry.build_circular_geometry(config) + runtime_params = general_runtime_params.GeneralRuntimeParams() + geo = geometry.build_circular_geometry(runtime_params) source_models = source_models_lib.SourceModels( sources={'j_bootstrap': source} ) - static_config_slice = config_slice.build_static_config_slice(config) - dynamic_config_slice = config_slice.build_dynamic_config_slice( - config, - sources=source_models.runtime_params, + static_runtime_params_slice = ( + runtime_params_slice.build_static_runtime_params_slice(runtime_params) + ) + dynamic_runtime_params_slice = ( + runtime_params_slice.build_dynamic_runtime_params_slice( + runtime_params, + sources=source_models.runtime_params, + ) ) core_profiles = core_profile_setters.initial_core_profiles( - dynamic_config_slice=dynamic_config_slice, - static_config_slice=static_config_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + static_runtime_params_slice=static_runtime_params_slice, geo=geo, source_models=source_models, ) self.assertIsNotNone( source.get_value( - dynamic_config_slice=dynamic_config_slice, - dynamic_source_runtime_params=dynamic_config_slice.sources[ + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ source_models.j_bootstrap_name ], geo=geo, @@ -79,8 +83,8 @@ def test_source_value(self): def test_extraction_of_relevant_profile_from_output(self): """Tests that the relevant profile is extracted from the output.""" source = bootstrap_current_source.BootstrapCurrentSource() - config = config_lib.Config() - geo = geometry.build_circular_geometry(config) + runtime_params = general_runtime_params.GeneralRuntimeParams() + geo = geometry.build_circular_geometry(runtime_params) cell = source_lib.ProfileType.CELL.get_profile_shape(geo) face = source_lib.ProfileType.FACE.get_profile_shape(geo) fake_profile = source_profiles.BootstrapCurrentProfile( diff --git a/torax/sources/tests/external_current_source.py b/torax/sources/tests/external_current_source.py index ce8a02b7..97a9eb9b 100644 --- a/torax/sources/tests/external_current_source.py +++ b/torax/sources/tests/external_current_source.py @@ -18,9 +18,9 @@ import jax import jax.numpy as jnp import numpy as np -from torax import config as config_lib -from torax import config_slice from torax import geometry +from torax.config import runtime_params as general_runtime_params +from torax.config import runtime_params_slice from torax.sources import external_current_source from torax.sources import runtime_params as runtime_params_lib from torax.sources import source as source_lib @@ -43,55 +43,57 @@ def setUpClass(cls): def test_source_value(self): """Tests that a formula-based source provides values.""" source = external_current_source.ExternalCurrentSource() - config = config_lib.Config() - dynamic_slice = config_slice.build_dynamic_config_slice( - config, + runtime_params = general_runtime_params.GeneralRuntimeParams() + dynamic_slice = runtime_params_slice.build_dynamic_runtime_params_slice( + runtime_params, sources={ 'jext': source.runtime_params, }, ) self.assertIsInstance(source, external_current_source.ExternalCurrentSource) # Must be circular for jext_hires call. - geo = geometry.build_circular_geometry(config) + geo = geometry.build_circular_geometry(runtime_params) self.assertIsNotNone( source.get_value( - dynamic_config_slice=dynamic_slice, + dynamic_runtime_params_slice=dynamic_slice, dynamic_source_runtime_params=dynamic_slice.sources['jext'], geo=geo, ) ) self.assertIsNotNone( source.jext_hires( - dynamic_config_slice=dynamic_slice, + dynamic_runtime_params_slice=dynamic_slice, dynamic_source_runtime_params=dynamic_slice.sources['jext'], geo=geo, ) ) def test_invalid_source_types_raise_errors(self): - config = config_lib.Config() - geo = geometry.build_circular_geometry(config) + runtime_params = general_runtime_params.GeneralRuntimeParams() + geo = geometry.build_circular_geometry(runtime_params) source = external_current_source.ExternalCurrentSource() for unsupported_mode in self._unsupported_modes: with self.subTest(unsupported_mode.name): with self.assertRaises(jax.interpreters.xla.xe.XlaRuntimeError): source.runtime_params.mode = unsupported_mode - dynamic_slice = config_slice.build_dynamic_config_slice( - config, - sources={ - 'jext': source.runtime_params, - }, + dynamic_slice = ( + runtime_params_slice.build_dynamic_runtime_params_slice( + runtime_params, + sources={ + 'jext': source.runtime_params, + }, + ) ) source.get_value( - dynamic_config_slice=dynamic_slice, + dynamic_runtime_params_slice=dynamic_slice, dynamic_source_runtime_params=dynamic_slice.sources['jext'], geo=geo, ) def test_extraction_of_relevant_profile_from_output(self): """Tests that the relevant profile is extracted from the output.""" - config = config_lib.Config() - geo = geometry.build_circular_geometry(config) + runtime_params = general_runtime_params.GeneralRuntimeParams() + geo = geometry.build_circular_geometry(runtime_params) source = external_current_source.ExternalCurrentSource() cell = source_lib.ProfileType.CELL.get_profile_shape(geo) fake_profile = (jnp.ones(cell), jnp.zeros(cell)) diff --git a/torax/sources/tests/formulas.py b/torax/sources/tests/formulas.py index 13d23c31..61fd7603 100644 --- a/torax/sources/tests/formulas.py +++ b/torax/sources/tests/formulas.py @@ -16,10 +16,10 @@ from absl.testing import absltest import chex -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib from torax import state as state_lib +from torax.config import runtime_params as general_runtime_params from torax.sources import default_sources from torax.sources import formula_config from torax.sources import formulas @@ -47,13 +47,13 @@ def test_custom_exponential_source_can_replace_puff_source(self): custom_source_name = 'custom_exponential_source' # Copy the test_particle_sources_constant config in here for clarity. - test_particle_sources_constant_config = config_lib.Config( - profile_conditions=config_lib.ProfileConditions( + test_particle_sources_constant_runtime_params = general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( set_pedestal=True, nbar=0.85, nu=0, ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( ion_heat_eq=True, el_heat_eq=True, dens_eq=True, # This is important to be True to test ne sources. @@ -114,7 +114,7 @@ def test_custom_exponential_source_can_replace_puff_source(self): # way that does not trigger recompiles. This way we only trace the code # once. geo = geometry.build_circular_geometry( - test_particle_sources_constant_config + test_particle_sources_constant_runtime_params ) transport_model = constant_transport_model.ConstantTransportModel( runtime_params=constant_transport_model.RuntimeParams( @@ -123,7 +123,7 @@ def test_custom_exponential_source_can_replace_puff_source(self): ) ) sim = sim_lib.build_sim_from_config( - config=test_particle_sources_constant_config, + runtime_params=test_particle_sources_constant_runtime_params, geo=geo, stepper_builder=linear_theta_method.LinearThetaMethodBuilder( runtime_params=linear_theta_method.LinearRuntimeParams( @@ -158,7 +158,7 @@ def test_custom_exponential_source_can_replace_puff_source(self): formula_config.Exponential( total=( S_puff_tot - / test_particle_sources_constant_config.numerics.nref + / test_particle_sources_constant_runtime_params.numerics.nref ), c1=1.0, c2=puff_decay_length, @@ -186,10 +186,10 @@ def _run_sim_and_check( ref_profiles: dict[str, chex.ArrayTree], ref_time: chex.Array, ): - """Runs sim with new dynamic config and checks the profiles vs. expected.""" + """Runs sim with new runtime params and checks the profiles vs. expected.""" torax_outputs = sim_lib.run_simulation( - static_config_slice=sim.static_config_slice, - dynamic_config_slice_provider=sim.dynamic_config_slice_provider, + static_runtime_params_slice=sim.static_runtime_params_slice, + dynamic_runtime_params_slice_provider=sim.dynamic_runtime_params_slice_provider, geometry_provider=sim.geometry_provider, initial_state=sim.initial_state, time_step_calculator=sim.time_step_calculator, diff --git a/torax/sources/tests/fusion_heat_source.py b/torax/sources/tests/fusion_heat_source.py index 5df9c06b..d4bf7260 100644 --- a/torax/sources/tests/fusion_heat_source.py +++ b/torax/sources/tests/fusion_heat_source.py @@ -19,9 +19,9 @@ from absl.testing import absltest from absl.testing import parameterized import numpy as np -from torax import config_slice from torax import constants from torax import core_profile_setters +from torax.config import runtime_params_slice from torax.sources import fusion_heat_source from torax.sources import runtime_params as runtime_params_lib from torax.sources import source @@ -49,7 +49,9 @@ def setUpClass(cls): @parameterized.parameters([ dict(references_getter=torax_refs.circular_references), dict(references_getter=torax_refs.chease_references_Ip_from_chease), - dict(references_getter=torax_refs.chease_references_Ip_from_config), + dict( + references_getter=torax_refs.chease_references_Ip_from_runtime_params + ), ]) def test_calc_fusion( self, references_getter: Callable[[], torax_refs.References] @@ -57,18 +59,22 @@ def test_calc_fusion( """Compare `calc_fusion` function to a reference implementation.""" references = references_getter() - config = references.config + runtime_params = references.runtime_params geo = references.geo - nref = config.numerics.nref + nref = runtime_params.numerics.nref source_models = source_models_lib.SourceModels() - dynamic_config_slice = config_slice.build_dynamic_config_slice( - config, - sources=source_models.runtime_params, + dynamic_runtime_params_slice = ( + runtime_params_slice.build_dynamic_runtime_params_slice( + runtime_params, + sources=source_models.runtime_params, + ) ) core_profiles = core_profile_setters.initial_core_profiles( - static_config_slice=config_slice.build_static_config_slice(config), - dynamic_config_slice=dynamic_config_slice, + static_runtime_params_slice=runtime_params_slice.build_static_runtime_params_slice( + runtime_params + ), + dynamic_runtime_params_slice=dynamic_runtime_params_slice, geo=geo, source_models=source_models, ) @@ -79,7 +85,7 @@ def test_calc_fusion( nref, ) - def calculate_fusion(config, geo, core_profiles): + def calculate_fusion(runtime_params, geo, core_profiles): """Reference implementation from PINT. We still use TORAX state here.""" # PINT doesn't follow Google style # pylint:disable=invalid-name @@ -114,14 +120,14 @@ def calculate_fusion(config, geo, core_profiles): Pfus = ( Efus * 0.25 - * (core_profiles.ni.face_value() * config.numerics.nref) ** 2 + * (core_profiles.ni.face_value() * runtime_params.numerics.nref) ** 2 * sigmav ) # [W/m^3] Ptot = np.trapz(Pfus * geo.vpr_face, geo.r_face) / 1e6 # [MW] return Ptot - fusion_pint = calculate_fusion(config, geo, core_profiles) + fusion_pint = calculate_fusion(runtime_params, geo, core_profiles) np.testing.assert_allclose(fusion_jax, fusion_pint) diff --git a/torax/sources/tests/qei_source.py b/torax/sources/tests/qei_source.py index d6f20a5f..ed75b738 100644 --- a/torax/sources/tests/qei_source.py +++ b/torax/sources/tests/qei_source.py @@ -17,10 +17,10 @@ import dataclasses from absl.testing import absltest import jax -from torax import config as config_lib -from torax import config_slice from torax import core_profile_setters from torax import geometry +from torax.config import runtime_params as general_runtime_params +from torax.config import runtime_params_slice from torax.sources import qei_source from torax.sources import runtime_params as runtime_params_lib from torax.sources import source as source_lib @@ -50,16 +50,18 @@ def test_source_value(self): source_models = source_models_lib.SourceModels( sources={'qei_source': source} ) - config = config_lib.Config() - geo = geometry.build_circular_geometry(config) - static_slice = config_slice.build_static_config_slice(config) - dynamic_slice = config_slice.build_dynamic_config_slice( - config, + runtime_params = general_runtime_params.GeneralRuntimeParams() + geo = geometry.build_circular_geometry(runtime_params) + static_slice = runtime_params_slice.build_static_runtime_params_slice( + runtime_params + ) + dynamic_slice = runtime_params_slice.build_dynamic_runtime_params_slice( + runtime_params, sources=source_models.runtime_params, ) core_profiles = core_profile_setters.initial_core_profiles( - static_config_slice=static_slice, - dynamic_config_slice=dynamic_slice, + static_runtime_params_slice=static_slice, + dynamic_runtime_params_slice=dynamic_slice, geo=geo, source_models=source_models, ) @@ -78,29 +80,35 @@ def test_invalid_source_types_raise_errors(self): source_models = source_models_lib.SourceModels( sources={'qei_source': source} ) - config = config_lib.Config() - geo = geometry.build_circular_geometry(config) - static_slice = config_slice.build_static_config_slice(config) - dynamic_slice = config_slice.build_dynamic_config_slice( - config, + runtime_params = general_runtime_params.GeneralRuntimeParams() + geo = geometry.build_circular_geometry(runtime_params) + static_slice = runtime_params_slice.build_static_runtime_params_slice( + runtime_params + ) + dynamic_slice = runtime_params_slice.build_dynamic_runtime_params_slice( + runtime_params, sources=source_models.runtime_params, ) core_profiles = core_profile_setters.initial_core_profiles( - static_config_slice=static_slice, - dynamic_config_slice=dynamic_slice, + static_runtime_params_slice=static_slice, + dynamic_runtime_params_slice=dynamic_slice, geo=geo, source_models=source_models, ) for unsupported_mode in self._unsupported_modes: with self.subTest(unsupported_mode.name): with self.assertRaises(jax.interpreters.xla.xe.XlaRuntimeError): - dynamic_slice = config_slice.build_dynamic_config_slice( - config, - sources={ - 'qei_source': dataclasses.replace( - source.runtime_params, mode=unsupported_mode - ) - }, + dynamic_slice = ( + runtime_params_slice.build_dynamic_runtime_params_slice( + runtime_params, + sources={ + 'qei_source': ( + dataclasses.replace( + source.runtime_params, mode=unsupported_mode + ) + ) + }, + ) ) source.get_qei( static_slice, diff --git a/torax/sources/tests/source.py b/torax/sources/tests/source.py index 4d82752d..774461bf 100644 --- a/torax/sources/tests/source.py +++ b/torax/sources/tests/source.py @@ -20,10 +20,10 @@ import jax from jax import numpy as jnp import numpy as np -from torax import config as config_lib -from torax import config_slice from torax import core_profile_setters from torax import geometry +from torax.config import runtime_params as general_runtime_params +from torax.config import runtime_params_slice from torax.sources import runtime_params as runtime_params_lib from torax.sources import source as source_lib from torax.sources import source_models as source_models_lib @@ -41,22 +41,28 @@ def test_zero_profile_works_by_default(self): source_models = source_models_lib.SourceModels( sources={'foo': source}, ) - config = config_lib.Config() - geo = geometry.build_circular_geometry(config) - static_config_slice = config_slice.build_static_config_slice(config) - dynamic_config_slice = config_slice.build_dynamic_config_slice( - config, - sources=source_models.runtime_params, + runtime_params = general_runtime_params.GeneralRuntimeParams() + geo = geometry.build_circular_geometry(runtime_params) + static_runtime_params_slice = ( + runtime_params_slice.build_static_runtime_params_slice(runtime_params) + ) + dynamic_runtime_params_slice = ( + runtime_params_slice.build_dynamic_runtime_params_slice( + runtime_params, + sources=source_models.runtime_params, + ) ) core_profiles = core_profile_setters.initial_core_profiles( - static_config_slice=static_config_slice, - dynamic_config_slice=dynamic_config_slice, + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, geo=geo, source_models=source_models, ) profile = source.get_value( - dynamic_config_slice=dynamic_config_slice, - dynamic_source_runtime_params=dynamic_config_slice.sources['foo'], + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ + 'foo' + ], geo=geo, core_profiles=core_profiles, ) @@ -80,24 +86,30 @@ def test_unsupported_modes_raise_errors(self): source_models = source_models_lib.SourceModels( sources={'foo': source}, ) - config = config_lib.Config() - geo = geometry.build_circular_geometry(config) - static_config_slice = config_slice.build_static_config_slice(config) - dynamic_config_slice = config_slice.build_dynamic_config_slice( - config, - sources=source_models.runtime_params, + runtime_params = general_runtime_params.GeneralRuntimeParams() + geo = geometry.build_circular_geometry(runtime_params) + static_runtime_params_slice = ( + runtime_params_slice.build_static_runtime_params_slice(runtime_params) + ) + dynamic_runtime_params_slice = ( + runtime_params_slice.build_dynamic_runtime_params_slice( + runtime_params, + sources=source_models.runtime_params, + ) ) core_profiles = core_profile_setters.initial_core_profiles( - static_config_slice=static_config_slice, - dynamic_config_slice=dynamic_config_slice, + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, geo=geo, source_models=source_models, ) # But calling requesting ZERO shouldn't work. with self.assertRaises(jax.interpreters.xla.xe.XlaRuntimeError): source.get_value( - dynamic_config_slice=dynamic_config_slice, - dynamic_source_runtime_params=dynamic_config_slice.sources['foo'], + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ + 'foo' + ], geo=geo, core_profiles=core_profiles, ) @@ -115,32 +127,42 @@ def test_defaults_output_zeros(self): source_models = source_models_lib.SourceModels( sources={'foo': source}, ) - config = config_lib.Config() - geo = geometry.build_circular_geometry(config) - static_config_slice = config_slice.build_static_config_slice(config) - dynamic_config_slice = config_slice.build_dynamic_config_slice( - config, - sources=source_models.runtime_params, + runtime_params = general_runtime_params.GeneralRuntimeParams() + geo = geometry.build_circular_geometry(runtime_params) + static_runtime_params_slice = ( + runtime_params_slice.build_static_runtime_params_slice(runtime_params) + ) + dynamic_runtime_params_slice = ( + runtime_params_slice.build_dynamic_runtime_params_slice( + runtime_params, + sources=source_models.runtime_params, + ) ) core_profiles = core_profile_setters.initial_core_profiles( - static_config_slice=static_config_slice, - dynamic_config_slice=dynamic_config_slice, + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, geo=geo, source_models=source_models, ) with self.subTest('model_based'): - dynamic_config_slice = config_slice.build_dynamic_config_slice( - config, - sources={ - 'foo': dataclasses.replace( - source.runtime_params, - mode=runtime_params_lib.Mode.MODEL_BASED, - ) - }, + dynamic_runtime_params_slice = ( + runtime_params_slice.build_dynamic_runtime_params_slice( + runtime_params, + sources={ + 'foo': ( + dataclasses.replace( + source.runtime_params, + mode=runtime_params_lib.Mode.MODEL_BASED, + ) + ) + }, + ) ) profile = source.get_value( - dynamic_config_slice=dynamic_config_slice, - dynamic_source_runtime_params=dynamic_config_slice.sources['foo'], + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ + 'foo' + ], geo=geo, core_profiles=core_profiles, ) @@ -149,18 +171,24 @@ def test_defaults_output_zeros(self): source_lib.ProfileType.CELL.get_zero_profile(geo), ) with self.subTest('formula'): - dynamic_config_slice = config_slice.build_dynamic_config_slice( - config, - sources={ - 'foo': dataclasses.replace( - source.runtime_params, - mode=runtime_params_lib.Mode.FORMULA_BASED, - ) - }, + dynamic_runtime_params_slice = ( + runtime_params_slice.build_dynamic_runtime_params_slice( + runtime_params, + sources={ + 'foo': ( + dataclasses.replace( + source.runtime_params, + mode=runtime_params_lib.Mode.FORMULA_BASED, + ) + ) + }, + ) ) profile = source.get_value( - dynamic_config_slice=dynamic_config_slice, - dynamic_source_runtime_params=dynamic_config_slice.sources['foo'], + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ + 'foo' + ], geo=geo, core_profiles=core_profiles, ) @@ -185,22 +213,28 @@ def test_overriding_default_formula(self): source_models = source_models_lib.SourceModels( sources={'foo': source}, ) - config = config_lib.Config() - geo = geometry.build_circular_geometry(config) - static_config_slice = config_slice.build_static_config_slice(config) - dynamic_config_slice = config_slice.build_dynamic_config_slice( - config, - sources=source_models.runtime_params, + runtime_params = general_runtime_params.GeneralRuntimeParams() + geo = geometry.build_circular_geometry(runtime_params) + static_runtime_params_slice = ( + runtime_params_slice.build_static_runtime_params_slice(runtime_params) + ) + dynamic_runtime_params_slice = ( + runtime_params_slice.build_dynamic_runtime_params_slice( + runtime_params, + sources=source_models.runtime_params, + ) ) core_profiles = core_profile_setters.initial_core_profiles( - static_config_slice=static_config_slice, - dynamic_config_slice=dynamic_config_slice, + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, geo=geo, source_models=source_models, ) profile = source.get_value( - dynamic_config_slice=dynamic_config_slice, - dynamic_source_runtime_params=dynamic_config_slice.sources['foo'], + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ + 'foo' + ], geo=geo, core_profiles=core_profiles, ) @@ -223,22 +257,28 @@ def test_overriding_model(self): source_models = source_models_lib.SourceModels( sources={'foo': source}, ) - config = config_lib.Config() - geo = geometry.build_circular_geometry(config) - static_config_slice = config_slice.build_static_config_slice(config) - dynamic_config_slice = config_slice.build_dynamic_config_slice( - config, - sources=source_models.runtime_params, + runtime_params = general_runtime_params.GeneralRuntimeParams() + geo = geometry.build_circular_geometry(runtime_params) + static_runtime_params_slice = ( + runtime_params_slice.build_static_runtime_params_slice(runtime_params) + ) + dynamic_runtime_params_slice = ( + runtime_params_slice.build_dynamic_runtime_params_slice( + runtime_params, + sources=source_models.runtime_params, + ) ) core_profiles = core_profile_setters.initial_core_profiles( - static_config_slice=static_config_slice, - dynamic_config_slice=dynamic_config_slice, + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, geo=geo, source_models=source_models, ) profile = source.get_value( - dynamic_config_slice=dynamic_config_slice, - dynamic_source_runtime_params=dynamic_config_slice.sources['foo'], + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ + 'foo' + ], geo=geo, core_profiles=core_profiles, ) @@ -257,10 +297,10 @@ def test_retrieving_profile_for_affected_state(self): source_lib.AffectedCoreProfile.NE, ), ) - config = config_lib.Config( - numerics=config_lib.Numerics(nr=4), + runtime_params = general_runtime_params.GeneralRuntimeParams( + numerics=general_runtime_params.Numerics(nr=4), ) - geo = geometry.build_circular_geometry(config) + geo = geometry.build_circular_geometry(runtime_params) psi_profile = source.get_source_profile_for_affected_core_profile( profile, source_lib.AffectedCoreProfile.PSI.value, geo ) @@ -284,11 +324,11 @@ class SingleProfileSourceTest(parameterized.TestCase): def test_custom_formula(self): """The user-specified formula should override the default formula.""" - config = config_lib.Config( - numerics=config_lib.Numerics(nr=5), + runtime_params = general_runtime_params.GeneralRuntimeParams( + numerics=general_runtime_params.Numerics(nr=5), ) - geo = geometry.build_circular_geometry(config) - expected_output = jnp.ones(5) # 5 matches config.numerics.nr. + geo = geometry.build_circular_geometry(runtime_params) + expected_output = jnp.ones(5) # 5 matches runtime_params.numerics.nr. source = source_lib.SingleProfileSource( formula=lambda _0, _1, _2, _3: expected_output, affected_core_profiles=(source_lib.AffectedCoreProfile.PSI,), @@ -297,19 +337,25 @@ def test_custom_formula(self): source_models = source_models_lib.SourceModels( sources={'foo': source}, ) - dynamic_config_slice = config_slice.build_dynamic_config_slice( - config, - sources=source_models.runtime_params, + dynamic_runtime_params_slice = ( + runtime_params_slice.build_dynamic_runtime_params_slice( + runtime_params, + sources=source_models.runtime_params, + ) ) core_profiles = core_profile_setters.initial_core_profiles( - static_config_slice=config_slice.build_static_config_slice(config), - dynamic_config_slice=dynamic_config_slice, + static_runtime_params_slice=runtime_params_slice.build_static_runtime_params_slice( + runtime_params + ), + dynamic_runtime_params_slice=dynamic_runtime_params_slice, geo=geo, source_models=source_models, ) profile = source.get_value( - dynamic_config_slice=dynamic_config_slice, - dynamic_source_runtime_params=dynamic_config_slice.sources['foo'], + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ + 'foo' + ], geo=geo, core_profiles=core_profiles, ) @@ -328,26 +374,32 @@ def test_multiple_profiles_raises_error(self): source_models = source_models_lib.SourceModels( sources={'foo': source}, ) - config = config_lib.Config( - numerics=config_lib.Numerics(nr=5), + runtime_params = general_runtime_params.GeneralRuntimeParams( + numerics=general_runtime_params.Numerics(nr=5), + ) + geo = geometry.build_circular_geometry(runtime_params) + static_runtime_params_slice = ( + runtime_params_slice.build_static_runtime_params_slice(runtime_params) ) - geo = geometry.build_circular_geometry(config) - static_config_slice = config_slice.build_static_config_slice(config) - dynamic_config_slice = config_slice.build_dynamic_config_slice( - config, - sources=source_models.runtime_params, + dynamic_runtime_params_slice = ( + runtime_params_slice.build_dynamic_runtime_params_slice( + runtime_params, + sources=source_models.runtime_params, + ) ) core_profiles = core_profile_setters.initial_core_profiles( - static_config_slice=static_config_slice, - dynamic_config_slice=dynamic_config_slice, + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, geo=geo, # defaults are enough for this. source_models=source_models, ) with self.assertRaises(AssertionError): source.get_value( - dynamic_config_slice=dynamic_config_slice, - dynamic_source_runtime_params=dynamic_config_slice.sources['foo'], + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ + 'foo' + ], geo=geo, core_profiles=core_profiles, ) @@ -360,10 +412,10 @@ def test_retrieving_profile_for_affected_state(self): model_func=lambda _0, _1, _2, _3: profile, affected_core_profiles=(source_lib.AffectedCoreProfile.NE,), ) - config = config_lib.Config( - numerics=config_lib.Numerics(nr=4), + runtime_params = general_runtime_params.GeneralRuntimeParams( + numerics=general_runtime_params.Numerics(nr=4), ) - geo = geometry.build_circular_geometry(config) + geo = geometry.build_circular_geometry(runtime_params) psi_profile = source.get_source_profile_for_affected_core_profile( profile, source_lib.AffectedCoreProfile.PSI.value, geo ) diff --git a/torax/sources/tests/source_models.py b/torax/sources/tests/source_models.py index 1c934b0e..6283963f 100644 --- a/torax/sources/tests/source_models.py +++ b/torax/sources/tests/source_models.py @@ -20,9 +20,9 @@ import jax.numpy as jnp import numpy as np import torax # useful for setting up jax properly. -from torax import config_slice from torax import core_profile_setters from torax import geometry +from torax.config import runtime_params_slice from torax.sources import default_sources from torax.sources import runtime_params as runtime_params_lib from torax.sources import source as source_lib @@ -43,30 +43,42 @@ class SourceProfilesTest(parameterized.TestCase): def test_computing_source_profiles_works_with_all_defaults(self): """Tests that you can compute source profiles with all defaults.""" - config = torax.Config() - geo = torax.build_circular_geometry(config) + runtime_params = torax.GeneralRuntimeParams() + geo = torax.build_circular_geometry(runtime_params) source_models = source_models_lib.SourceModels() - dynamic_config_slice = config_slice.build_dynamic_config_slice( - config, - sources=source_models.runtime_params, + dynamic_runtime_params_slice = ( + runtime_params_slice.build_dynamic_runtime_params_slice( + runtime_params, + sources=source_models.runtime_params, + ) ) core_profiles = core_profile_setters.initial_core_profiles( - static_config_slice=config_slice.build_static_config_slice(config), - dynamic_config_slice=dynamic_config_slice, + static_runtime_params_slice=runtime_params_slice.build_static_runtime_params_slice( + runtime_params + ), + dynamic_runtime_params_slice=dynamic_runtime_params_slice, geo=geo, source_models=source_models, ) _ = source_models_lib.build_source_profiles( - dynamic_config_slice, geo, core_profiles, source_models, explicit=True + dynamic_runtime_params_slice, + geo, + core_profiles, + source_models, + explicit=True, ) _ = source_models_lib.build_source_profiles( - dynamic_config_slice, geo, core_profiles, source_models, explicit=False + dynamic_runtime_params_slice, + geo, + core_profiles, + source_models, + explicit=False, ) def test_summed_temp_ion_profiles_dont_change_when_jitting(self): """Test that sum_sources_temp_{ion|el} works with jitting.""" - config = torax.Config() - geo = torax.build_circular_geometry(config) + runtime_params = torax.GeneralRuntimeParams() + geo = torax.build_circular_geometry(runtime_params) # Use the default sources where the generic_ion_el_heat_source, # fusion_heat_source, and ohmic_heat_source are included and produce @@ -144,22 +156,26 @@ def foo_formula( source_models = source_models_lib.SourceModels( sources={source_name: foo_source}, ) - config = torax.Config() - dynamic_config_slice = config_slice.build_dynamic_config_slice( - config, - sources=source_models.runtime_params, + runtime_params = torax.GeneralRuntimeParams() + dynamic_runtime_params_slice = ( + runtime_params_slice.build_dynamic_runtime_params_slice( + runtime_params, + sources=source_models.runtime_params, + ) ) - geo = torax.build_circular_geometry(config) + geo = torax.build_circular_geometry(runtime_params) core_profiles = core_profile_setters.initial_core_profiles( - static_config_slice=config_slice.build_static_config_slice(config), - dynamic_config_slice=dynamic_config_slice, + static_runtime_params_slice=runtime_params_slice.build_static_runtime_params_slice( + runtime_params + ), + dynamic_runtime_params_slice=dynamic_runtime_params_slice, geo=geo, source_models=source_models, ) def compute_and_sum_profiles(): profiles = source_models_lib.build_source_profiles( - dynamic_config_slice=dynamic_config_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, geo=geo, core_profiles=core_profiles, source_models=source_models, diff --git a/torax/sources/tests/test_lib.py b/torax/sources/tests/test_lib.py index 3909bf9e..29adb824 100644 --- a/torax/sources/tests/test_lib.py +++ b/torax/sources/tests/test_lib.py @@ -21,10 +21,10 @@ import jax import jax.numpy as jnp import numpy as np -from torax import config as config_lib -from torax import config_slice from torax import core_profile_setters from torax import geometry +from torax.config import runtime_params as general_runtime_params +from torax.config import runtime_params_slice from torax.sources import runtime_params as runtime_params_lib from torax.sources import source as source_lib from torax.sources import source_models as source_models_lib @@ -82,26 +82,32 @@ def test_source_value(self): source = self._source_class() # pytype: disable=missing-parameter # pylint: enable=missing-kwoa self.assertIsInstance(source, source_lib.SingleProfileSource) - config = config_lib.Config() + runtime_params = general_runtime_params.GeneralRuntimeParams() source.runtime_params.mode = source.supported_modes[0] source_models = source_models_lib.SourceModels( sources={'foo': source}, ) - geo = geometry.build_circular_geometry(config) - static_config_slice = config_slice.build_static_config_slice(config) - dynamic_config_slice = config_slice.build_dynamic_config_slice( - config=config, - sources=source_models.runtime_params, + geo = geometry.build_circular_geometry(runtime_params) + static_runtime_params_slice = ( + runtime_params_slice.build_static_runtime_params_slice(runtime_params) + ) + dynamic_runtime_params_slice = ( + runtime_params_slice.build_dynamic_runtime_params_slice( + runtime_params=runtime_params, + sources=source_models.runtime_params, + ) ) core_profiles = core_profile_setters.initial_core_profiles( - static_config_slice=static_config_slice, - dynamic_config_slice=dynamic_config_slice, + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, geo=geo, source_models=source_models, ) value = source.get_value( - dynamic_config_slice=dynamic_config_slice, - dynamic_source_runtime_params=dynamic_config_slice.sources['foo'], + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ + 'foo' + ], geo=geo, core_profiles=core_profiles, ) @@ -109,8 +115,8 @@ def test_source_value(self): def test_invalid_source_types_raise_errors(self): """Tests that using unsupported types raises an error.""" - config = config_lib.Config() - geo = geometry.build_circular_geometry(config) + runtime_params = general_runtime_params.GeneralRuntimeParams() + geo = geometry.build_circular_geometry(runtime_params) # pylint: disable=missing-kwoa source = self._source_class() # pytype: disable=missing-parameter # pylint: enable=missing-kwoa @@ -118,27 +124,35 @@ def test_invalid_source_types_raise_errors(self): source_models = source_models_lib.SourceModels( sources={'foo': source}, ) - dynamic_config_slice = config_slice.build_dynamic_config_slice( - config=config, - sources=source_models.runtime_params, + dynamic_runtime_params_slice = ( + runtime_params_slice.build_dynamic_runtime_params_slice( + runtime_params=runtime_params, + sources=source_models.runtime_params, + ) ) core_profiles = core_profile_setters.initial_core_profiles( - static_config_slice=config_slice.build_static_config_slice(config), - dynamic_config_slice=dynamic_config_slice, + static_runtime_params_slice=runtime_params_slice.build_static_runtime_params_slice( + runtime_params + ), + dynamic_runtime_params_slice=dynamic_runtime_params_slice, geo=geo, source_models=source_models, ) for unsupported_mode in self._unsupported_modes: source.runtime_params.mode = unsupported_mode - dynamic_config_slice = config_slice.build_dynamic_config_slice( - config=config, - sources=source_models.runtime_params, + dynamic_runtime_params_slice = ( + runtime_params_slice.build_dynamic_runtime_params_slice( + runtime_params=runtime_params, + sources=source_models.runtime_params, + ) ) with self.subTest(unsupported_mode.name): with self.assertRaises(jax.interpreters.xla.xe.XlaRuntimeError): source.get_value( - dynamic_config_slice=dynamic_config_slice, - dynamic_source_runtime_params=dynamic_config_slice.sources['foo'], + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ + 'foo' + ], geo=geo, core_profiles=core_profiles, ) @@ -153,25 +167,31 @@ def test_source_value(self): source = self._source_class() # pytype: disable=missing-parameter # pylint: enable=missing-kwoa self.assertIsInstance(source, source_lib.IonElectronSource) - config = config_lib.Config() - geo = geometry.build_circular_geometry(config) + runtime_params = general_runtime_params.GeneralRuntimeParams() + geo = geometry.build_circular_geometry(runtime_params) source_models = source_models_lib.SourceModels( sources={'foo': source}, ) - static_config_slice = config_slice.build_static_config_slice(config) - dynamic_config_slice = config_slice.build_dynamic_config_slice( - config=config, - sources=source_models.runtime_params, + static_runtime_params_slice = ( + runtime_params_slice.build_static_runtime_params_slice(runtime_params) + ) + dynamic_runtime_params_slice = ( + runtime_params_slice.build_dynamic_runtime_params_slice( + runtime_params=runtime_params, + sources=source_models.runtime_params, + ) ) core_profiles = core_profile_setters.initial_core_profiles( - static_config_slice=static_config_slice, - dynamic_config_slice=dynamic_config_slice, + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, geo=geo, source_models=source_models, ) ion_and_el = source.get_value( - dynamic_config_slice=dynamic_config_slice, - dynamic_source_runtime_params=dynamic_config_slice.sources['foo'], + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ + 'foo' + ], geo=geo, core_profiles=core_profiles, ) @@ -179,8 +199,8 @@ def test_source_value(self): def test_invalid_source_types_raise_errors(self): """Tests that using unsupported types raises an error.""" - config = config_lib.Config() - geo = geometry.build_circular_geometry(config) + runtime_params = general_runtime_params.GeneralRuntimeParams() + geo = geometry.build_circular_geometry(runtime_params) # pylint: disable=missing-kwoa source = self._source_class() # pytype: disable=missing-parameter # pylint: enable=missing-kwoa @@ -188,36 +208,44 @@ def test_invalid_source_types_raise_errors(self): source_models = source_models_lib.SourceModels( sources={'foo': source}, ) - static_config_slice = config_slice.build_static_config_slice(config) - dynamic_config_slice = config_slice.build_dynamic_config_slice( - config=config, - sources=source_models.runtime_params, + static_runtime_params_slice = ( + runtime_params_slice.build_static_runtime_params_slice(runtime_params) + ) + dynamic_runtime_params_slice = ( + runtime_params_slice.build_dynamic_runtime_params_slice( + runtime_params=runtime_params, + sources=source_models.runtime_params, + ) ) core_profiles = core_profile_setters.initial_core_profiles( - static_config_slice=static_config_slice, - dynamic_config_slice=dynamic_config_slice, + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, geo=geo, source_models=source_models, ) for unsupported_mode in self._unsupported_modes: source.runtime_params.mode = unsupported_mode - dynamic_config_slice = config_slice.build_dynamic_config_slice( - config=config, - sources=source_models.runtime_params, + dynamic_runtime_params_slice = ( + runtime_params_slice.build_dynamic_runtime_params_slice( + runtime_params=runtime_params, + sources=source_models.runtime_params, + ) ) with self.subTest(unsupported_mode.name): with self.assertRaises(jax.interpreters.xla.xe.XlaRuntimeError): source.get_value( - dynamic_config_slice=dynamic_config_slice, - dynamic_source_runtime_params=dynamic_config_slice.sources['foo'], + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ + 'foo' + ], geo=geo, core_profiles=core_profiles, ) def test_extraction_of_relevant_profile_from_output(self): """Tests that the relevant profile is extracted from the output.""" - config = config_lib.Config() - geo = geometry.build_circular_geometry(config) + runtime_params = general_runtime_params.GeneralRuntimeParams() + geo = geometry.build_circular_geometry(runtime_params) # pylint: disable=missing-kwoa source = self._source_class() # pytype: disable=missing-parameter # pylint: enable=missing-kwoa diff --git a/torax/spectators/tests/plotting.py b/torax/spectators/tests/plotting.py index 9f492622..98a337ad 100644 --- a/torax/spectators/tests/plotting.py +++ b/torax/spectators/tests/plotting.py @@ -17,8 +17,8 @@ from absl.testing import absltest from absl.testing import parameterized import torax # We want this import to make sure jax gets set to float64 -from torax import config as config_lib from torax import geometry +from torax.config import runtime_params as general_runtime_params from torax.sources import default_sources from torax.spectators import plotting from torax.spectators import spectator @@ -31,12 +31,12 @@ class PlottingTest(parameterized.TestCase): """Tests the plotting library.""" def test_default_plot_config_has_valid_keys(self): - config = config_lib.Config() - geo = geometry.build_circular_geometry(config) + runtime_params = general_runtime_params.GeneralRuntimeParams() + geo = geometry.build_circular_geometry(runtime_params) plot_config = plotting.get_default_plot_config(geo) observer = spectator.InMemoryJaxArraySpectator() - _run_sim(config, geo, observer) + _run_sim(runtime_params, geo, observer) # Make sure all the keys in plot_config are collected by the observer. for plot in plot_config: @@ -44,21 +44,21 @@ def test_default_plot_config_has_valid_keys(self): self.assertIn(key.key, observer.arrays) def test_plot_observer_runs_with_sim(self): - config = config_lib.Config() - geo = geometry.build_circular_geometry(config) + runtime_params = general_runtime_params.GeneralRuntimeParams() + geo = geometry.build_circular_geometry(runtime_params) observer = plotting.PlotSpectator( plots=plotting.get_default_plot_config(geo), ) - _run_sim(config, geo, observer) + _run_sim(runtime_params, geo, observer) def _run_sim( - config: config_lib.Config, + runtime_params: general_runtime_params.GeneralRuntimeParams, geo: geometry.Geometry, observer: spectator.Spectator, ): torax.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, stepper_builder=linear_theta_method.LinearThetaMethodBuilder(), transport_model=constant_transport_model.ConstantTransportModel(), diff --git a/torax/state.py b/torax/state.py index f4886a47..b4062942 100644 --- a/torax/state.py +++ b/torax/state.py @@ -22,8 +22,8 @@ import chex import jax from jax import numpy as jnp -from torax import config from torax import geometry +from torax.config import config_args from torax.fvm import cell_variable from torax.sources import source_profiles @@ -107,7 +107,7 @@ def index(self, i: int) -> CoreProfiles: history_vars = ["temp_ion", "temp_el", "psi", "psidot", "ne", "ni"] history_replace = {"history": None} replace_dict = {var: history_replace for var in history_vars} - state = config.recursive_replace(state, **replace_dict) + state = config_args.recursive_replace(state, **replace_dict) return state def sanity_check(self): diff --git a/torax/stepper/linear_theta_method.py b/torax/stepper/linear_theta_method.py index 40d16fe6..cb656095 100644 --- a/torax/stepper/linear_theta_method.py +++ b/torax/stepper/linear_theta_method.py @@ -18,11 +18,11 @@ import dataclasses from typing import Type import jax -from torax import config_slice from torax import fvm from torax import geometry from torax import sim from torax import state +from torax.config import runtime_params_slice from torax.sources import source_models as source_models_lib from torax.sources import source_profiles from torax.stepper import predictor_corrector_method @@ -46,9 +46,9 @@ def __init__( def _x_new( self, dt: jax.Array, - static_config_slice: config_slice.StaticConfigSlice, - dynamic_config_slice_t: config_slice.DynamicConfigSlice, - dynamic_config_slice_t_plus_dt: config_slice.DynamicConfigSlice, + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice_t: runtime_params_slice.DynamicRuntimeParamsSlice, + dynamic_runtime_params_slice_t_plus_dt: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, core_profiles_t: state.CoreProfiles, core_profiles_t_plus_dt: state.CoreProfiles, @@ -69,7 +69,7 @@ def _x_new( # Instantiate coeffs_callback class coeffs_callback = self.callback_class( - static_config_slice=static_config_slice, + static_runtime_params_slice=static_runtime_params_slice, geo=geo, core_profiles_t=core_profiles_t, core_profiles_t_plus_dt=core_profiles_t_plus_dt, @@ -82,11 +82,15 @@ def _x_new( # Compute the explicit coeffs based on the core profiles at time t and all # runtime parameters at time t. coeffs_exp = coeffs_callback( - dynamic_config_slice_t, x_old, allow_pereverzev=True, explicit_call=True + dynamic_runtime_params_slice_t, + x_old, + allow_pereverzev=True, + explicit_call=True, ) # Calculate x_new with the predictor corrector method. Reverts to a - # standard linear solve if static_config_slice.predictor_corrector=False. + # standard linear solve if + # static_runtime_params_slice.predictor_corrector=False. # init_val is the initialization for the predictor_corrector loop. # Neither value impacts the final result, but needs to be the correct # type. x_new initialization (index 0) input is x_old for correct typing. @@ -106,8 +110,8 @@ def _x_new( x_new, (core_sources, core_transport) = ( predictor_corrector_method.predictor_corrector_method( dt=dt, - static_config_slice=static_config_slice, - dynamic_config_slice_t_plus_dt=dynamic_config_slice_t_plus_dt, + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt, x_old=x_old, init_val=init_val, coeffs_exp=coeffs_exp, diff --git a/torax/stepper/nonlinear_theta_method.py b/torax/stepper/nonlinear_theta_method.py index 8dd5e478..e02dd9d3 100644 --- a/torax/stepper/nonlinear_theta_method.py +++ b/torax/stepper/nonlinear_theta_method.py @@ -23,14 +23,14 @@ import chex import jax -from torax import config_slice from torax import fvm from torax import geometry from torax import sim from torax import state +from torax.config import config_args +from torax.config import runtime_params_slice from torax.fvm import newton_raphson_solve_block from torax.fvm import optimizer_solve_block -from torax.runtime_params import config_slice_args from torax.sources import source_models as source_models_lib from torax.sources import source_profiles from torax.stepper import runtime_params as runtime_params_lib @@ -65,9 +65,9 @@ def __init__( def _x_new( self, dt: jax.Array, - static_config_slice: config_slice.StaticConfigSlice, - dynamic_config_slice_t: config_slice.DynamicConfigSlice, - dynamic_config_slice_t_plus_dt: config_slice.DynamicConfigSlice, + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice_t: runtime_params_slice.DynamicRuntimeParamsSlice, + dynamic_runtime_params_slice_t_plus_dt: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, core_profiles_t: state.CoreProfiles, core_profiles_t_plus_dt: state.CoreProfiles, @@ -82,7 +82,7 @@ def _x_new( """See Stepper._x_new docstring.""" coeffs_callback = self.callback_class( - static_config_slice=static_config_slice, + static_runtime_params_slice=static_runtime_params_slice, geo=geo, core_profiles_t=core_profiles_t, core_profiles_t_plus_dt=core_profiles_t_plus_dt, @@ -93,9 +93,9 @@ def _x_new( ) x_new, core_sources, core_transport, error = self._x_new_helper( dt=dt, - static_config_slice=static_config_slice, - dynamic_config_slice_t=dynamic_config_slice_t, - dynamic_config_slice_t_plus_dt=dynamic_config_slice_t_plus_dt, + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice_t=dynamic_runtime_params_slice_t, + dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt, geo=geo, core_profiles_t=core_profiles_t, core_profiles_t_plus_dt=core_profiles_t_plus_dt, @@ -110,9 +110,9 @@ def _x_new( def _x_new_helper( self, dt: jax.Array, - static_config_slice: config_slice.StaticConfigSlice, - dynamic_config_slice_t: config_slice.DynamicConfigSlice, - dynamic_config_slice_t_plus_dt: config_slice.DynamicConfigSlice, + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice_t: runtime_params_slice.DynamicRuntimeParamsSlice, + dynamic_runtime_params_slice_t_plus_dt: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, core_profiles_t: state.CoreProfiles, core_profiles_t_plus_dt: state.CoreProfiles, @@ -163,9 +163,9 @@ def __init__( def _x_new_helper( self, dt: jax.Array, - static_config_slice: config_slice.StaticConfigSlice, - dynamic_config_slice_t: config_slice.DynamicConfigSlice, - dynamic_config_slice_t_plus_dt: config_slice.DynamicConfigSlice, + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice_t: runtime_params_slice.DynamicRuntimeParamsSlice, + dynamic_runtime_params_slice_t_plus_dt: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, core_profiles_t: state.CoreProfiles, core_profiles_t_plus_dt: state.CoreProfiles, @@ -183,9 +183,9 @@ def _x_new_helper( x_new, error, (core_sources, core_transport) = ( optimizer_solve_block.optimizer_solve_block( dt=dt, - static_config_slice=static_config_slice, - dynamic_config_slice_t=dynamic_config_slice_t, - dynamic_config_slice_t_plus_dt=dynamic_config_slice_t_plus_dt, + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice_t=dynamic_runtime_params_slice_t, + dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt, geo=geo, x_old=tuple([core_profiles_t[name] for name in evolving_names]), core_profiles_t_plus_dt=core_profiles_t_plus_dt, @@ -250,7 +250,7 @@ def build_dynamic_params( self, t: chex.Numeric ) -> DynamicNewtonRaphsonRuntimeParams: return DynamicNewtonRaphsonRuntimeParams( - **config_slice_args.get_init_kwargs( + **config_args.get_init_kwargs( input_config=self, output_type=DynamicNewtonRaphsonRuntimeParams, t=t, @@ -302,9 +302,9 @@ def __init__( def _x_new_helper( self, dt: jax.Array, - static_config_slice: config_slice.StaticConfigSlice, - dynamic_config_slice_t: config_slice.DynamicConfigSlice, - dynamic_config_slice_t_plus_dt: config_slice.DynamicConfigSlice, + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice_t: runtime_params_slice.DynamicRuntimeParamsSlice, + dynamic_runtime_params_slice_t_plus_dt: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, core_profiles_t: state.CoreProfiles, core_profiles_t_plus_dt: state.CoreProfiles, @@ -319,7 +319,8 @@ def _x_new_helper( ]: """Final implementation of x_new after callback has been created etc.""" assert isinstance( - dynamic_config_slice_t.stepper, DynamicNewtonRaphsonRuntimeParams + dynamic_runtime_params_slice_t.stepper, + DynamicNewtonRaphsonRuntimeParams, ) # disable error checking in residual, since Newton-Raphson routine has # error checking based on result of each linear step @@ -328,9 +329,9 @@ def _x_new_helper( x_new, error, (core_sources, core_transport) = ( newton_raphson_solve_block.newton_raphson_solve_block( dt=dt, - static_config_slice=static_config_slice, - dynamic_config_slice_t=dynamic_config_slice_t, - dynamic_config_slice_t_plus_dt=dynamic_config_slice_t_plus_dt, + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice_t=dynamic_runtime_params_slice_t, + dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt, geo=geo, x_old=tuple([core_profiles_t[name] for name in evolving_names]), core_profiles_t_plus_dt=core_profiles_t_plus_dt, @@ -339,7 +340,7 @@ def _x_new_helper( source_models=self.source_models, coeffs_callback=coeffs_callback, evolving_names=evolving_names, - log_iterations=dynamic_config_slice_t.stepper.log_iterations, + log_iterations=dynamic_runtime_params_slice_t.stepper.log_iterations, initial_guess_mode=self.initial_guess_mode, maxiter=self.maxiter, tol=self.tol, diff --git a/torax/stepper/predictor_corrector_method.py b/torax/stepper/predictor_corrector_method.py index 6641fc6f..2d7f51e4 100644 --- a/torax/stepper/predictor_corrector_method.py +++ b/torax/stepper/predictor_corrector_method.py @@ -15,26 +15,24 @@ """Carries out the predictor corrector method for the PDE solution. Picard iterations to approximate the nonlinear solution. If -static_config_slice.stepper.predictor_corrector is False, reverts to a +static_runtime_params_slice.stepper.predictor_corrector is False, reverts to a standard linear solution. """ from typing import Any import chex import jax -from torax import config_slice from torax import fvm from torax import jax_utils +from torax.config import runtime_params_slice from torax.fvm import implicit_solve_block def predictor_corrector_method( dt: jax.Array, - static_config_slice: config_slice.StaticConfigSlice, - dynamic_config_slice_t_plus_dt: config_slice.DynamicConfigSlice, + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice_t_plus_dt: runtime_params_slice.DynamicRuntimeParamsSlice, x_old: tuple[fvm.CellVariable, ...], - init_val: tuple[ - tuple[fvm.CellVariable, ...], chex.ArrayTree - ], + init_val: tuple[tuple[fvm.CellVariable, ...], chex.ArrayTree], coeffs_exp: fvm.block_1d_coeffs.Block1DCoeffs, coeffs_callback: fvm.block_1d_coeffs.Block1DCoeffsCallback, ) -> tuple[tuple[fvm.CellVariable, ...], Any]: @@ -42,10 +40,11 @@ def predictor_corrector_method( Args: dt: current timestep - static_config_slice: General input parameters which are fixed through a - simulation run, and if changed, would trigger a recompile. - dynamic_config_slice_t_plus_dt: Dynamic config parameters corresponding to - the next time step, needed for the implicit PDE coefficients + static_runtime_params_slice: General input parameters which are fixed + through a simulation run, and if changed, would trigger a recompile. + dynamic_runtime_params_slice_t_plus_dt: Dynamic runtime parameters + corresponding to the next time step, needed for the implicit PDE + coefficients x_old: Tuple of CellVariables correspond to the evolving core profiles at time t. init_val: Initial guess for the predictor corrector output. @@ -64,7 +63,7 @@ def loop_body(i, val): # pylint: disable=unused-argument x_new_guess = val[0] coeffs_new = coeffs_callback( - dynamic_config_slice_t_plus_dt, + dynamic_runtime_params_slice_t_plus_dt, x_new_guess, allow_pereverzev=True, ) @@ -75,12 +74,12 @@ def loop_body(i, val): # pylint: disable=unused-argument x_new_guess=x_new_guess, coeffs_old=coeffs_exp, coeffs_new=coeffs_new, - theta_imp=static_config_slice.stepper.theta_imp, + theta_imp=static_runtime_params_slice.stepper.theta_imp, convection_dirichlet_mode=( - static_config_slice.stepper.convection_dirichlet_mode + static_runtime_params_slice.stepper.convection_dirichlet_mode ), convection_neumann_mode=( - static_config_slice.stepper.convection_neumann_mode + static_runtime_params_slice.stepper.convection_neumann_mode ), ) @@ -91,10 +90,10 @@ def loop_body(i, val): # pylint: disable=unused-argument # TORAX_COMPILATION_ENABLED=False. This logic is in jax.utils_py_fori_loop. # If the static predictor_corrector=False, then compilation is faster, so # we maintain this option. - if static_config_slice.stepper.predictor_corrector: + if static_runtime_params_slice.stepper.predictor_corrector: return jax_utils.py_fori_loop( 0, - dynamic_config_slice_t_plus_dt.stepper.corrector_steps + 1, + dynamic_runtime_params_slice_t_plus_dt.stepper.corrector_steps + 1, loop_body, init_val, ) diff --git a/torax/stepper/runtime_params.py b/torax/stepper/runtime_params.py index d773812d..1e3c1e58 100644 --- a/torax/stepper/runtime_params.py +++ b/torax/stepper/runtime_params.py @@ -21,11 +21,12 @@ import chex from torax import interpolated_param -from torax.runtime_params import config_slice_args +from torax.config import config_args # Type-alias for clarity. While the InterpolatedParams can vary across any -# field, in Config, we mainly use it to handle time-dependent parameters. +# field, in runtime_params, we mainly use it to handle time-dependent +# parameters. TimeDependentField = interpolated_param.InterpParamOrInterpParamInput @@ -83,7 +84,7 @@ def __post_init__(self): def build_dynamic_params(self, t: chex.Numeric) -> DynamicRuntimeParams: return DynamicRuntimeParams( - **config_slice_args.get_init_kwargs( + **config_args.get_init_kwargs( input_config=self, output_type=DynamicRuntimeParams, t=t, @@ -92,7 +93,7 @@ def build_dynamic_params(self, t: chex.Numeric) -> DynamicRuntimeParams: def build_static_params(self) -> StaticRuntimeParams: return StaticRuntimeParams( - **config_slice_args.get_init_kwargs( + **config_args.get_init_kwargs( input_config=self, output_type=StaticRuntimeParams, ) diff --git a/torax/stepper/stepper.py b/torax/stepper/stepper.py index 8cd8b9c4..3394bbb5 100644 --- a/torax/stepper/stepper.py +++ b/torax/stepper/stepper.py @@ -21,11 +21,11 @@ import dataclasses import jax -from torax import config_slice from torax import core_profile_setters from torax import fvm from torax import geometry from torax import state +from torax.config import runtime_params_slice from torax.sources import source_models as source_models_lib from torax.sources import source_profiles from torax.stepper import runtime_params as runtime_params_lib @@ -56,9 +56,9 @@ def __init__( def __call__( self, dt: jax.Array, - static_config_slice: config_slice.StaticConfigSlice, - dynamic_config_slice_t: config_slice.DynamicConfigSlice, - dynamic_config_slice_t_plus_dt: config_slice.DynamicConfigSlice, + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice_t: runtime_params_slice.DynamicRuntimeParamsSlice, + dynamic_runtime_params_slice_t_plus_dt: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, core_profiles_t: state.CoreProfiles, core_profiles_t_plus_dt: state.CoreProfiles, @@ -73,14 +73,14 @@ def __call__( Args: dt: Time step duration. - static_config_slice: Input params that trigger recompilation when they - change. These don't have to be JAX-friendly types and can be used in - control-flow logic. - dynamic_config_slice_t: Runtime configuration for time t (the start time - of the step). These config params can change from step to step without - triggering a recompilation. - dynamic_config_slice_t_plus_dt: Runtime configuration for time t + dt, - used for implicit calculations in the solver. + static_runtime_params_slice: Input params that trigger recompilation when + they change. These don't have to be JAX-friendly types and can be used + in control-flow logic. + dynamic_runtime_params_slice_t: Runtime parameters for time t (the start + time of the step). These runtime params can change from step to step + without triggering a recompilation. + dynamic_runtime_params_slice_t_plus_dt: Runtime parameters for time t + + dt, used for implicit calculations in the solver. geo: Geometry of the torus. core_profiles_t: Core plasma profiles at the beginning of the time step. core_profiles_t_plus_dt: Core plasma profiles which contain all available @@ -88,7 +88,7 @@ def __call__( evolving boundary conditions and prescribed time-dependent profiles that are not being evolved by the PDE system. explicit_source_profiles: Source profiles of all explicit sources (as - configured by the input config). All implicit source's profiles will be + configured by the input params). All implicit source's profiles will be set to 0 in this object. These explicit source profiles were calculated either based on the original core profiles at the start of the time step or were independent of the core profiles. Because they were calculated @@ -111,15 +111,15 @@ def __call__( # This base class method can be completely overriden by a subclass, but # most can make use of the boilerplate here and just implement `_x_new`. - # Use config to determine which variables to evolve + # Use runtime params to determine which variables to evolve evolving_names = [] - if static_config_slice.ion_heat_eq: + if static_runtime_params_slice.ion_heat_eq: evolving_names.append('temp_ion') - if static_config_slice.el_heat_eq: + if static_runtime_params_slice.el_heat_eq: evolving_names.append('temp_el') - if static_config_slice.current_eq: + if static_runtime_params_slice.current_eq: evolving_names.append('psi') - if static_config_slice.dens_eq: + if static_runtime_params_slice.dens_eq: evolving_names.append('ne') evolving_names = tuple(evolving_names) @@ -127,9 +127,9 @@ def __call__( if evolving_names: x_new, core_sources, core_transport, error = self._x_new( dt=dt, - static_config_slice=static_config_slice, - dynamic_config_slice_t=dynamic_config_slice_t, - dynamic_config_slice_t_plus_dt=dynamic_config_slice_t_plus_dt, + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice_t=dynamic_runtime_params_slice_t, + dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt, geo=geo, core_profiles_t=core_profiles_t, core_profiles_t_plus_dt=core_profiles_t_plus_dt, @@ -148,7 +148,7 @@ def __call__( core_profiles_t_plus_dt = ( core_profile_setters.update_evolving_core_profiles( x_new, - dynamic_config_slice_t_plus_dt, + dynamic_runtime_params_slice_t_plus_dt, core_profiles_t_plus_dt, evolving_names, ) @@ -164,9 +164,9 @@ def __call__( def _x_new( self, dt: jax.Array, - static_config_slice: config_slice.StaticConfigSlice, - dynamic_config_slice_t: config_slice.DynamicConfigSlice, - dynamic_config_slice_t_plus_dt: config_slice.DynamicConfigSlice, + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice_t: runtime_params_slice.DynamicRuntimeParamsSlice, + dynamic_runtime_params_slice_t_plus_dt: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, core_profiles_t: state.CoreProfiles, core_profiles_t_plus_dt: state.CoreProfiles, @@ -185,14 +185,14 @@ def _x_new( Args: dt: Time step duration. - static_config_slice: Input params that trigger recompilation when they - change. These don't have to be JAX-friendly types and can be used in - control-flow logic. - dynamic_config_slice_t: Runtime configuration for time t (the start time - of the step). These config params can change from step to step without - triggering a recompilation. - dynamic_config_slice_t_plus_dt: Runtime configuration for time t + dt, - used for implicit calculations in the solver. + static_runtime_params_slice: Input params that trigger recompilation when + they change. These don't have to be JAX-friendly types and can be used + in control-flow logic. + dynamic_runtime_params_slice_t: Runtime parameters for time t (the start + time of the step). These runtime params can change from step to step + without triggering a recompilation. + dynamic_runtime_params_slice_t_plus_dt: Runtime parameters for time t + + dt, used for implicit calculations in the solver. geo: Geometry of the torus. core_profiles_t: Core plasma profiles at the beginning of the time step. core_profiles_t_plus_dt: Core plasma profiles which contain all available diff --git a/torax/tests/boundary_conditions.py b/torax/tests/boundary_conditions.py index a1991c30..8dc80117 100644 --- a/torax/tests/boundary_conditions.py +++ b/torax/tests/boundary_conditions.py @@ -17,11 +17,12 @@ from absl.testing import absltest import numpy as np -from torax import config as config_lib -from torax import config_slice from torax import constants from torax import core_profile_setters from torax import geometry +from torax.config import config_args +from torax.config import runtime_params as general_runtime_params +from torax.config import runtime_params_slice from torax.sources import source_models as source_models_lib @@ -32,43 +33,49 @@ def test_setting_boundary_conditions(self): """Tests that setting boundary conditions works.""" # Boundary conditions can be time-dependent, but when creating the initial # state, we want to grab the boundary condition params at time 0. - config = config_lib.Config( - profile_conditions=config_lib.ProfileConditions( + runtime_params = general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( Ti_bound_right=27.7, Te_bound_right={0.0: 42.0, 1.0: 0.0}, - ne_bound_right=config_lib.InterpolationParam( + ne_bound_right=general_runtime_params.InterpolationParam( {0.0: 0.1, 0.1: 2.0}, - interpolation_mode=config_lib.InterpolationMode.STEP, + interpolation_mode=general_runtime_params.InterpolationMode.STEP, ), Ip={0.0: 5, 1.0: 7}, ), ) - geo = geometry.build_circular_geometry(config) + geo = geometry.build_circular_geometry(runtime_params) source_models = source_models_lib.SourceModels() - static_config_slice = config_slice.build_static_config_slice(config) - initial_dynamic_config_slice = config_slice.build_dynamic_config_slice( - config, - sources=source_models.runtime_params, + static_runtime_params_slice = ( + runtime_params_slice.build_static_runtime_params_slice(runtime_params) + ) + initial_dynamic_runtime_params_slice = ( + runtime_params_slice.build_dynamic_runtime_params_slice( + runtime_params, + sources=source_models.runtime_params, + ) ) core_profiles = core_profile_setters.initial_core_profiles( - static_config_slice, - initial_dynamic_config_slice, + static_runtime_params_slice, + initial_dynamic_runtime_params_slice, geo, source_models=source_models, ) - dynamic_config_slice = config_slice.build_dynamic_config_slice( - config, - sources=source_models.runtime_params, - t=0.5, + dynamic_runtime_params_slice = ( + runtime_params_slice.build_dynamic_runtime_params_slice( + runtime_params, + sources=source_models.runtime_params, + t=0.5, + ) ) bc = core_profile_setters.compute_boundary_conditions( - dynamic_config_slice, + dynamic_runtime_params_slice, geo, ) - updated = config_lib.recursive_replace(core_profiles, **bc) + updated = config_args.recursive_replace(core_profiles, **bc) psi_constraint = 6e6 * constants.CONSTANTS.mu0 / geo.G2_face[-1] * geo.rmax np.testing.assert_allclose(updated.temp_ion.right_face_constraint, 27.7) diff --git a/torax/tests/geometry.py b/torax/tests/geometry.py index d46d7442..fb421cae 100644 --- a/torax/tests/geometry.py +++ b/torax/tests/geometry.py @@ -21,8 +21,8 @@ import jax from jax import numpy as jnp import numpy as np -from torax import config as config_lib from torax import geometry +from torax.config import runtime_params as general_runtime_params class GeometryTest(parameterized.TestCase): @@ -52,25 +52,27 @@ def test_face_to_cell(self, nr, seed): def test_frozen(self): """Test that the Geometry class is frozen.""" - config = config_lib.Config() - geo = geometry.build_circular_geometry(config) + runtime_params = general_runtime_params.GeneralRuntimeParams() + geo = geometry.build_circular_geometry(runtime_params) with self.assertRaises(dataclasses.FrozenInstanceError): geo.dr = 1.0 def test_geometry_can_be_input_to_jitted_function(self): """Test that the Geometry class can be input to a jitted function.""" + def foo(geo: geometry.Geometry): _ = geo # do nothing. + foo_jitted = jax.jit(foo) - config = config_lib.Config() + runtime_params = general_runtime_params.GeneralRuntimeParams() with self.subTest('CircularGeometry'): - geo = geometry.build_circular_geometry(config) + geo = geometry.build_circular_geometry(runtime_params) # Make sure you can call the function with geo as an arg. foo_jitted(geo) with self.subTest('CHEASEGeometry'): - geo = geometry.build_chease_geometry(config) + geo = geometry.build_chease_geometry(runtime_params) # Make sure you can call the function with geo as an arg. foo_jitted(geo) diff --git a/torax/tests/physics.py b/torax/tests/physics.py index 54235015..d870d68f 100644 --- a/torax/tests/physics.py +++ b/torax/tests/physics.py @@ -19,11 +19,11 @@ from absl.testing import parameterized from jax import numpy as jnp import numpy as np -from torax import config_slice from torax import constants from torax import core_profile_setters from torax import geometry from torax import physics +from torax.config import runtime_params_slice from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib from torax.tests.test_lib import torax_refs @@ -35,7 +35,9 @@ class PhysicsTest(torax_refs.ReferenceValueTest): @parameterized.parameters([ dict(references_getter=torax_refs.circular_references), dict(references_getter=torax_refs.chease_references_Ip_from_chease), - dict(references_getter=torax_refs.chease_references_Ip_from_config), + dict( + references_getter=torax_refs.chease_references_Ip_from_runtime_params + ), ]) def test_calc_q_from_psi( self, references_getter: Callable[[], torax_refs.References] @@ -43,26 +45,28 @@ def test_calc_q_from_psi( """Compare `calc_q_from_psi` function to a reference implementation.""" references = references_getter() - config = references.config - dynamic_config_slice = config_slice.build_dynamic_config_slice(config) + runtime_params = references.runtime_params + dynamic_runtime_params_slice = ( + runtime_params_slice.build_dynamic_runtime_params_slice(runtime_params) + ) geo = references.geo # Dummy value for jtot for unit testing purposes. - jtot = jnp.ones(config.numerics.nr) + jtot = jnp.ones(runtime_params.numerics.nr) q_face_jax, q_cell_jax = physics.calc_q_from_jtot_psi( geo, references.psi, jtot, - dynamic_config_slice.numerics.q_correction_factor, + dynamic_runtime_params_slice.numerics.q_correction_factor, ) # Make ground truth - def calc_q_from_psi(config, geo): + def calc_q_from_psi(runtime_params, geo): """Reference implementation from PINT.""" consts = constants.CONSTANTS - iota = np.zeros(config.numerics.nr + 1) # on face grid - q = np.zeros(config.numerics.nr + 1) # on face grid + iota = np.zeros(runtime_params.numerics.nr + 1) # on face grid + q = np.zeros(runtime_params.numerics.nr + 1) # on face grid # We use the reference value of psi here because the original code # for calculating psi depends on FiPy, and we don't want to install that iota[1:] = np.abs( @@ -75,17 +79,17 @@ def calc_q_from_psi(config, geo): q[0] = ( 2 * geo.B0 / (consts.mu0 * jtot[0] * geo.Rmaj) ) # use on-axis definition of q (Wesson 2004, Eq 3.48) - q *= config.numerics.q_correction_factor + q *= runtime_params.numerics.q_correction_factor - def face_to_cell(config, face): - cell = np.zeros(config.numerics.nr) + def face_to_cell(runtime_params, face): + cell = np.zeros(runtime_params.numerics.nr) cell[:] = 0.5 * (face[1:] + face[:-1]) return cell - q_cell = face_to_cell(config, q) + q_cell = face_to_cell(runtime_params, q) return q, q_cell - q_face_np, q_cell_np = calc_q_from_psi(config, geo) + q_face_np, q_cell_np = calc_q_from_psi(runtime_params, geo) np.testing.assert_allclose(q_face_jax, q_face_np) np.testing.assert_allclose(q_cell_jax, q_cell_np) @@ -93,7 +97,9 @@ def face_to_cell(config, face): @parameterized.parameters([ dict(references_getter=torax_refs.circular_references), dict(references_getter=torax_refs.chease_references_Ip_from_chease), - dict(references_getter=torax_refs.chease_references_Ip_from_config), + dict( + references_getter=torax_refs.chease_references_Ip_from_runtime_params + ), ]) def test_update_psi_from_j( self, references_getter: Callable[[], torax_refs.References] @@ -101,26 +107,28 @@ def test_update_psi_from_j( """Compare `update_psi_from_j` function to a reference implementation.""" references = references_getter() - config = references.config + runtime_params = references.runtime_params source_models = source_models_lib.SourceModels() # Turn on the external current source. source_models.jext.runtime_params.mode = ( source_runtime_params.Mode.FORMULA_BASED ) - dynamic_config_slice = config_slice.build_dynamic_config_slice( - config, sources=source_models.runtime_params + dynamic_runtime_params_slice = ( + runtime_params_slice.build_dynamic_runtime_params_slice( + runtime_params, sources=source_models.runtime_params + ) ) geo = references.geo # pylint: disable=protected-access if isinstance(geo, geometry.CircularGeometry): currents = core_profile_setters._prescribe_currents_no_bootstrap( - dynamic_config_slice, + dynamic_runtime_params_slice, geo, source_models=source_models, ) psi = core_profile_setters._update_psi_from_j( - dynamic_config_slice, geo, currents + dynamic_runtime_params_slice, geo, currents ).value elif isinstance(geo, geometry.CHEASEGeometry): psi = geo.psi_from_chease_Ip @@ -133,7 +141,9 @@ def test_update_psi_from_j( @parameterized.parameters([ dict(references_getter=torax_refs.circular_references), dict(references_getter=torax_refs.chease_references_Ip_from_chease), - dict(references_getter=torax_refs.chease_references_Ip_from_config), + dict( + references_getter=torax_refs.chease_references_Ip_from_runtime_params + ), ]) def test_calc_jtot_from_psi( self, references_getter: Callable[[], torax_refs.References] @@ -151,7 +161,9 @@ def test_calc_jtot_from_psi( @parameterized.parameters([ dict(references_getter=torax_refs.circular_references), dict(references_getter=torax_refs.chease_references_Ip_from_chease), - dict(references_getter=torax_refs.chease_references_Ip_from_config), + dict( + references_getter=torax_refs.chease_references_Ip_from_runtime_params + ), ]) def test_calc_s_from_psi( self, references_getter: Callable[[], torax_refs.References] diff --git a/torax/tests/sim.py b/torax/tests/sim.py index 85b23aa3..b5c88260 100644 --- a/torax/tests/sim.py +++ b/torax/tests/sim.py @@ -401,8 +401,8 @@ def test_fail(self): def test_no_op(self): """Tests that running the stepper with all equations off is a no-op.""" - config = torax.config.Config( - numerics=torax.config.Numerics( + runtime_params = torax.general_runtime_params.GeneralRuntimeParams( + numerics=torax.general_runtime_params.Numerics( t_final=0.1, ion_heat_eq=False, el_heat_eq=False, @@ -411,10 +411,10 @@ def test_no_op(self): ) time_step_calculator = chi_time_step_calculator.ChiTimeStepCalculator() - geo = torax.build_circular_geometry(config) + geo = torax.build_circular_geometry(runtime_params) sim = sim_lib.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, stepper_builder=linear_theta_method.LinearThetaMethodBuilder(), transport_model=constant_transport_model.ConstantTransportModel(), @@ -429,7 +429,7 @@ def test_no_op(self): chex.assert_rank(t, 1) history_length = state_history.temp_ion.value.shape[0] self.assertEqual(history_length, t.shape[0]) - self.assertGreater(t[-1], config.numerics.t_final) + self.assertGreater(t[-1], runtime_params.numerics.t_final) for torax_profile in _ALL_PROFILES: profile_history = state_history[torax_profile] @@ -471,13 +471,13 @@ def test_observers_update_during_runs(self, stepper_builder_constructor): stepper_builder = stepper_builder_constructor() # Load config structure. config_module = self._get_config_module('test_explicit.py') - config = config_module.get_config() - geo = config_module.get_geometry(config) + runtime_params = config_module.get_runtime_params() + geo = config_module.get_geometry(runtime_params) time_step_calculator = chi_time_step_calculator.ChiTimeStepCalculator() spectator = spectator_lib.InMemoryJaxArraySpectator() sim = sim_lib.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, stepper_builder=stepper_builder, transport_model=config_module.get_transport_model(), diff --git a/torax/tests/sim_custom_sources.py b/torax/tests/sim_custom_sources.py index e5cb26ba..7b9e8445 100644 --- a/torax/tests/sim_custom_sources.py +++ b/torax/tests/sim_custom_sources.py @@ -20,12 +20,12 @@ from absl.testing import absltest import chex -from torax import config as config_lib -from torax import config_slice from torax import geometry from torax import sim as sim_lib from torax import state as state_lib -from torax.runtime_params import config_slice_args +from torax.config import config_args +from torax.config import runtime_params as general_runtime_params +from torax.config import runtime_params_slice from torax.sources import default_sources from torax.sources import electron_density_sources from torax.sources import runtime_params as runtime_params_lib @@ -49,7 +49,7 @@ def test_custom_ne_source_can_replace_defaults(self): custom_source_name = 'custom_ne_source' def custom_source_formula( - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, unused_state: state_lib.CoreProfiles | None, @@ -83,17 +83,17 @@ def custom_source_formula( # pylint: disable=protected-access return ( electron_density_sources._calc_puff_source( - dynamic_config_slice=dynamic_config_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, dynamic_source_runtime_params=puff_params, geo=geo, ) + electron_density_sources._calc_nbi_source( - dynamic_config_slice=dynamic_config_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, dynamic_source_runtime_params=nbi_params, geo=geo, ) + electron_density_sources._calc_pellet_source( - dynamic_config_slice=dynamic_config_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, dynamic_source_runtime_params=pellet_params, geo=geo, ) @@ -155,13 +155,13 @@ def custom_source_formula( # Copy the test_particle_sources_constant config in here for clarity. # These are the common kwargs without any of the sources. - test_particle_sources_constant_config = config_lib.Config( - profile_conditions=config_lib.ProfileConditions( + test_particle_sources_constant_runtime_params = general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( set_pedestal=True, nbar=0.85, nu=0, ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( ion_heat_eq=True, el_heat_eq=True, dens_eq=True, # This is important to be True to test ne sources. @@ -176,10 +176,10 @@ def custom_source_formula( 'test_particle_sources_constant.h5', _ALL_PROFILES ) geo = geometry.build_circular_geometry( - test_particle_sources_constant_config + test_particle_sources_constant_runtime_params ) sim = sim_lib.build_sim_from_config( - config=test_particle_sources_constant_config, + runtime_params=test_particle_sources_constant_runtime_params, geo=geo, stepper_builder=linear_theta_method.LinearThetaMethodBuilder( runtime_params=linear_theta_method.LinearRuntimeParams( @@ -234,13 +234,13 @@ def _run_sim_and_check( ref_profiles: dict[str, chex.ArrayTree], ref_time: chex.Array, ): - """Runs sim with new dynamic config and checks the profiles vs. expected.""" + """Runs sim with new runtime params and checks the profiles vs. expected.""" torax_outputs = sim_lib.run_simulation( initial_state=sim.initial_state, step_fn=sim.step_fn, geometry_provider=sim.geometry_provider, - dynamic_config_slice_provider=sim.dynamic_config_slice_provider, - static_config_slice=sim.static_config_slice, + dynamic_runtime_params_slice_provider=sim.dynamic_runtime_params_slice_provider, + static_runtime_params_slice=sim.static_runtime_params_slice, time_step_calculator=sim.time_step_calculator, ) state_history, _, _ = state_lib.build_history_from_states(torax_outputs) @@ -275,7 +275,7 @@ def build_dynamic_params( self, t: chex.Numeric ) -> _CustomSourceDynamicRuntimeParams: return _CustomSourceDynamicRuntimeParams( - **config_slice_args.get_init_kwargs( + **config_args.get_init_kwargs( input_config=self, output_type=_CustomSourceDynamicRuntimeParams, t=t, diff --git a/torax/tests/sim_output_source_profiles.py b/torax/tests/sim_output_source_profiles.py index bf58fed3..b5832970 100644 --- a/torax/tests/sim_output_source_profiles.py +++ b/torax/tests/sim_output_source_profiles.py @@ -26,14 +26,14 @@ import chex from jax import numpy as jnp import numpy as np -from torax import config as config_lib -from torax import config_slice from torax import core_profile_setters from torax import geometry from torax import sim as sim_lib from torax import state as state_module +from torax.config import config_args +from torax.config import runtime_params as general_runtime_params +from torax.config import runtime_params_slice from torax.fvm import cell_variable -from torax.runtime_params import config_slice_args from torax.sources import default_sources from torax.sources import runtime_params as runtime_params_lib from torax.sources import source @@ -54,14 +54,18 @@ class SimOutputSourceProfilesTest(sim_test_case.SimTestCase): def test_merging_source_profiles(self): """Tests that the implicit and explicit source profiles merge correctly.""" - config = config_lib.Config() - geo = geometry.build_circular_geometry(config) + runtime_params = general_runtime_params.GeneralRuntimeParams() + geo = geometry.build_circular_geometry(runtime_params) source_models = default_sources.get_default_sources() - dynamic_config_slice = config_slice.build_dynamic_config_slice( - config, - sources=source_models.runtime_params, + dynamic_runtime_params_slice = ( + runtime_params_slice.build_dynamic_runtime_params_slice( + runtime_params, + sources=source_models.runtime_params, + ) + ) + static_runtime_params_slice = ( + runtime_params_slice.build_static_runtime_params_slice(runtime_params) ) - static_config_slice = config_slice.build_static_config_slice(config) # Technically, the _merge_source_profiles() function should be called with # source profiles where, for every source, only one of the implicit or # explicit profiles has non-zero values. That is what makes the summing @@ -80,8 +84,8 @@ def test_merging_source_profiles(self): value=2.0, ) qei_core_profiles = core_profile_setters.initial_core_profiles( - dynamic_config_slice=dynamic_config_slice, - static_config_slice=static_config_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + static_runtime_params_slice=static_runtime_params_slice, geo=geo, source_models=source_models, ) @@ -123,7 +127,7 @@ def test_first_and_last_source_profiles(self): # Create custom sources whose output profiles depend on Tiped. # This is not physically realistic, just for testing purposes. def custom_source_formula( - unused_dynamic_config, + unused_dynamic_runtime_params, source_conf, geo, unused_state, @@ -159,31 +163,35 @@ def custom_source_formula( ), } ) - config = config_lib.Config() - geo = geometry.build_circular_geometry(config) + runtime_params = general_runtime_params.GeneralRuntimeParams() + geo = geometry.build_circular_geometry(runtime_params) time_stepper = _FakeTimeStepCalculator() step_fn = _FakeSimulationStepFn(time_stepper, source_models) - dynamic_config_slice_provider = config_slice.DynamicConfigSliceProvider( - config=config, - transport_getter=constant_transport_model.RuntimeParams, - sources_getter=lambda: source_models.runtime_params, - stepper_getter=stepper_runtime_params.RuntimeParams, + dynamic_runtime_params_slice_provider = ( + runtime_params_slice.DynamicRuntimeParamsSliceProvider( + runtime_params=runtime_params, + transport_getter=constant_transport_model.RuntimeParams, + sources_getter=lambda: source_models.runtime_params, + stepper_getter=stepper_runtime_params.RuntimeParams, + ) + ) + initial_dcs = dynamic_runtime_params_slice_provider(0.0) + static_runtime_params_slice = ( + runtime_params_slice.build_static_runtime_params_slice(runtime_params) ) - initial_dcs = dynamic_config_slice_provider(0.0) - static_config_slice = config_slice.build_static_config_slice(config) sim_states = sim_lib.run_simulation( initial_state=sim_lib.get_initial_state( - static_config_slice=static_config_slice, - dynamic_config_slice=initial_dcs, + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice=initial_dcs, geo=geo, time_step_calculator=time_stepper, source_models=source_models, ), step_fn=step_fn, geometry_provider=sim_lib.ConstantGeometryProvider(geo), - dynamic_config_slice_provider=dynamic_config_slice_provider, - static_config_slice=static_config_slice, + dynamic_runtime_params_slice_provider=dynamic_runtime_params_slice_provider, + static_runtime_params_slice=static_runtime_params_slice, time_step_calculator=time_stepper, ) @@ -240,14 +248,14 @@ def initial_state(self): def not_done( self, t: float | jnp.ndarray, - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, state, ) -> bool | jnp.ndarray: return t < 2 def next_dt( self, - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, core_profiles: state_module.CoreProfiles, time_step_calculator_state, @@ -264,7 +272,7 @@ def build_dynamic_params( self, t: chex.Numeric ) -> _FakeSourceDynamicRuntimeParams: return _FakeSourceDynamicRuntimeParams( - **config_slice_args.get_init_kwargs( + **config_args.get_init_kwargs( input_config=self, output_type=_FakeSourceDynamicRuntimeParams, t=t, @@ -298,14 +306,16 @@ def stepper(self): def __call__( self, - static_config_slice: config_slice.StaticConfigSlice, - dynamic_config_slice_provider: config_slice.DynamicConfigSliceProvider, + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice_provider: runtime_params_slice.DynamicRuntimeParamsSliceProvider, geo: geometry.Geometry, input_state: state_module.ToraxSimState, explicit_source_profiles: source_profiles_lib.SourceProfiles, ) -> state_module.ToraxSimState: dt, ts_state = self._time_step_calculator.next_dt( - dynamic_config_slice=dynamic_config_slice_provider(input_state.t), + dynamic_runtime_params_slice=dynamic_runtime_params_slice_provider( + input_state.t + ), geo=geo, core_profiles=input_state.core_profiles, time_step_calculator_state=input_state.time_step_calculator_state, @@ -319,7 +329,9 @@ def __call__( time_step_calculator_state=ts_state, # The returned source profiles include only the implicit sources. core_sources=source_models_lib.build_source_profiles( - dynamic_config_slice=dynamic_config_slice_provider(new_t), + dynamic_runtime_params_slice=dynamic_runtime_params_slice_provider( + new_t + ), geo=geo, core_profiles=input_state.core_profiles, # no state evolution. source_models=self.stepper.source_models, diff --git a/torax/tests/sim_time_dependence.py b/torax/tests/sim_time_dependence.py index 8c0c8a79..af9441c9 100644 --- a/torax/tests/sim_time_dependence.py +++ b/torax/tests/sim_time_dependence.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests torax.sim for handling time dependent input config params.""" +"""Tests torax.sim for handling time dependent input runtime params.""" import dataclasses @@ -21,11 +21,11 @@ import jax import jax.numpy as jnp import numpy as np -from torax import config as config_lib -from torax import config_slice from torax import geometry from torax import sim as sim_lib from torax import state +from torax.config import runtime_params as general_runtime_params +from torax.config import runtime_params_slice from torax.sources import source_models as source_models_lib from torax.sources import source_profiles from torax.stepper import runtime_params as stepper_runtime_params @@ -36,7 +36,7 @@ class SimWithTimeDependeceTest(parameterized.TestCase): - """Integration tests for torax.sim with time-dependent config params.""" + """Integration tests for torax.sim with time-dependent runtime params.""" @parameterized.named_parameters( ('with_adaptive_dt', True, 3, 0, 2.44444444444), @@ -50,17 +50,17 @@ def test_time_dependent_params_update_in_adaptive_dt( expected_combined_value: float, ): """Tests the SimulationStepFn's adaptive dt uses time-dependent params.""" - config = config_lib.Config( - profile_conditions=config_lib.ProfileConditions( + runtime_params = general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( Ti_bound_right={0.0: 1.0, 1.0: 2.0, 10.0: 11.0}, ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( adaptive_dt=adaptive_dt, fixed_dt=1.0, # 1 time step in, the Ti_bound_right will be 2.0 dt_reduction_factor=1.5, ), ) - geo = geometry.build_circular_geometry(config) + geo = geometry.build_circular_geometry(runtime_params) transport = FakeTransportModel() source_models = source_models_lib.SourceModels() # max combined value of Ti_bound_right should be 2.5. Higher will make the @@ -77,30 +77,36 @@ def test_time_dependent_params_update_in_adaptive_dt( time_calculator, transport_model=transport, ) - dynamic_config_slice_provider = config_slice.DynamicConfigSliceProvider( - config=config, - transport_getter=lambda: transport.runtime_params, - sources_getter=lambda: source_models.runtime_params, - stepper_getter=stepper_runtime_params.RuntimeParams, + dynamic_runtime_params_slice_provider = ( + runtime_params_slice.DynamicRuntimeParamsSliceProvider( + runtime_params=runtime_params, + transport_getter=lambda: transport.runtime_params, + sources_getter=lambda: source_models.runtime_params, + stepper_getter=stepper_runtime_params.RuntimeParams, + ) ) - initial_dynamic_config_slice = dynamic_config_slice_provider( - config.numerics.t_initial + initial_dynamic_runtime_params_slice = ( + dynamic_runtime_params_slice_provider(runtime_params.numerics.t_initial) ) input_state = sim_lib.get_initial_state( - static_config_slice=config_slice.build_static_config_slice(config), - dynamic_config_slice=initial_dynamic_config_slice, + static_runtime_params_slice=runtime_params_slice.build_static_runtime_params_slice( + runtime_params + ), + dynamic_runtime_params_slice=initial_dynamic_runtime_params_slice, geo=geo, time_step_calculator=time_calculator, source_models=source_models, ) output_state = sim_step_fn( - static_config_slice=config_slice.build_static_config_slice(config), - dynamic_config_slice_provider=dynamic_config_slice_provider, + static_runtime_params_slice=runtime_params_slice.build_static_runtime_params_slice( + runtime_params + ), + dynamic_runtime_params_slice_provider=dynamic_runtime_params_slice_provider, geo=geo, input_state=input_state, explicit_source_profiles=source_models_lib.build_source_profiles( source_models=source_models, - dynamic_config_slice=initial_dynamic_config_slice, + dynamic_runtime_params_slice=initial_dynamic_runtime_params_slice, geo=geo, core_profiles=input_state.core_profiles, explicit=True, @@ -121,7 +127,8 @@ def test_time_dependent_params_update_in_adaptive_dt( class FakeStepper(stepper_lib.Stepper): """Fake stepper that allows us to hook into the error logic. - Given the name of a time-dependent param in the config, and a max value for + Given the name of a time-dependent param in the runtime_params, and a max + value for that param, this stepper returns a successful state if the config values for that param in the config at time t and config at time t+dt sum to less than max value. @@ -146,9 +153,9 @@ def __init__( def __call__( self, dt: jax.Array, - static_config_slice: config_slice.StaticConfigSlice, - dynamic_config_slice_t: config_slice.DynamicConfigSlice, - dynamic_config_slice_t_plus_dt: config_slice.DynamicConfigSlice, + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice_t: runtime_params_slice.DynamicRuntimeParamsSlice, + dynamic_runtime_params_slice_t_plus_dt: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, core_profiles_t: state.CoreProfiles, core_profiles_t_plus_dt: state.CoreProfiles, @@ -160,10 +167,12 @@ def __call__( int, ]: combined = getattr( - dynamic_config_slice_t.profile_conditions, self._param - ) + getattr(dynamic_config_slice_t_plus_dt.profile_conditions, self._param) + dynamic_runtime_params_slice_t.profile_conditions, self._param + ) + getattr( + dynamic_runtime_params_slice_t_plus_dt.profile_conditions, self._param + ) transport = self.transport_model( - dynamic_config_slice_t, geo, core_profiles_t + dynamic_runtime_params_slice_t, geo, core_profiles_t ) # Use Qei as a hacky way to extract what the combined value was. core_sources = source_models_lib.build_all_zero_profiles( @@ -198,7 +207,7 @@ def runtime_params( def _call_implementation( self, - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, core_profiles: state.CoreProfiles, ) -> state.CoreTransport: diff --git a/torax/tests/state.py b/torax/tests/state.py index 9faa4bec..48a222a0 100644 --- a/torax/tests/state.py +++ b/torax/tests/state.py @@ -22,11 +22,12 @@ import jax from jax import numpy as jnp import numpy as np -from torax import config as config_lib -from torax import config_slice from torax import core_profile_setters from torax import geometry from torax import state +from torax.config import config_args +from torax.config import runtime_params as general_runtime_params +from torax.config import runtime_params_slice from torax.sources import source_models as source_models_lib from torax.tests.test_lib import torax_refs @@ -41,21 +42,23 @@ def setUp(self): self.history_length = 2 source_models = source_models_lib.SourceModels() - def make_hist(config, geo): + def make_hist(runtime_params, geo): initial_counter = jnp.array(0) def scan_f(counter: jax.Array, _) -> tuple[jax.Array, state.CoreProfiles]: core_profiles = core_profile_setters.initial_core_profiles( - static_config_slice=config_slice.build_static_config_slice(config), - dynamic_config_slice=config_slice.build_dynamic_config_slice( - config, sources=source_models.runtime_params + static_runtime_params_slice=runtime_params_slice.build_static_runtime_params_slice( + runtime_params + ), + dynamic_runtime_params_slice=runtime_params_slice.build_dynamic_runtime_params_slice( + runtime_params, sources=source_models.runtime_params ), geo=geo, source_models=source_models, ) # Make one variable in the history track the value of the counter value = jnp.ones_like(core_profiles.temp_ion.value) * counter - core_profiles = config_lib.recursive_replace( + core_profiles = config_args.recursive_replace( core_profiles, temp_ion={'value': value} ) return counter + 1, core_profiles.history_elem() @@ -68,9 +71,9 @@ def scan_f(counter: jax.Array, _) -> tuple[jax.Array, state.CoreProfiles]: ) return history - def make_history(config, geo): + def make_history(runtime_params, geo): # Bind non-JAX arguments so it can be jitted - bound = functools.partial(make_hist, config, geo) + bound = functools.partial(make_hist, runtime_params, geo) return jax.jit(bound)() self._make_history = make_history @@ -78,7 +81,9 @@ def make_history(config, geo): @parameterized.parameters([ dict(references_getter=torax_refs.circular_references), dict(references_getter=torax_refs.chease_references_Ip_from_chease), - dict(references_getter=torax_refs.chease_references_Ip_from_config), + dict( + references_getter=torax_refs.chease_references_Ip_from_runtime_params + ), ]) def test_sanity_check( self, @@ -88,11 +93,11 @@ def test_sanity_check( references = references_getter() source_models = source_models_lib.SourceModels() basic_core_profiles = core_profile_setters.initial_core_profiles( - static_config_slice=config_slice.build_static_config_slice( - references.config + static_runtime_params_slice=runtime_params_slice.build_static_runtime_params_slice( + references.runtime_params ), - dynamic_config_slice=config_slice.build_dynamic_config_slice( - references.config, sources=source_models.runtime_params + dynamic_runtime_params_slice=runtime_params_slice.build_dynamic_runtime_params_slice( + references.runtime_params, sources=source_models.runtime_params ), geo=references.geo, source_models=source_models, @@ -102,7 +107,9 @@ def test_sanity_check( @parameterized.parameters([ dict(references_getter=torax_refs.circular_references), dict(references_getter=torax_refs.chease_references_Ip_from_chease), - dict(references_getter=torax_refs.chease_references_Ip_from_config), + dict( + references_getter=torax_refs.chease_references_Ip_from_runtime_params + ), ]) def test_index( self, @@ -110,7 +117,7 @@ def test_index( ): """Test State.index.""" references = references_getter() - history = self._make_history(references.config, references.geo) + history = self._make_history(references.runtime_params, references.geo) for i in range(self.history_length): self.assertEqual(i, history.index(i).temp_ion.value[0]) @@ -118,7 +125,9 @@ def test_index( @parameterized.parameters([ dict(references_getter=torax_refs.circular_references), dict(references_getter=torax_refs.chease_references_Ip_from_chease), - dict(references_getter=torax_refs.chease_references_Ip_from_config), + dict( + references_getter=torax_refs.chease_references_Ip_from_runtime_params + ), ]) def test_project( self, @@ -126,7 +135,7 @@ def test_project( ): """Test State.project.""" references = references_getter() - history = self._make_history(references.config, references.geo) + history = self._make_history(references.runtime_params, references.geo) seed = 20230421 rng_state = jax.random.PRNGKey(seed) @@ -150,23 +159,25 @@ def test_initial_boundary_condition_from_time_dependent_params(self): """Tests that the initial boundary conditions are set from the config.""" # Boundary conditions can be time-dependent, but when creating the initial # core profiles, we want to grab the boundary condition params at time 0. - config = config_lib.Config( - profile_conditions=config_lib.ProfileConditions( + runtime_params = general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( Ti_bound_right=27.7, Te_bound_right={0.0: 42.0, 1.0: 0.0}, - ne_bound_right=config_lib.InterpolationParam( + ne_bound_right=general_runtime_params.InterpolationParam( {0.0: 0.1, 1.0: 2.0}, - interpolation_mode=config_lib.InterpolationMode.STEP, + interpolation_mode=general_runtime_params.InterpolationMode.STEP, ), ), ) source_models = source_models_lib.SourceModels() core_profiles = core_profile_setters.initial_core_profiles( - static_config_slice=config_slice.build_static_config_slice(config), - dynamic_config_slice=config_slice.build_dynamic_config_slice( - config, sources=source_models.runtime_params + static_runtime_params_slice=runtime_params_slice.build_static_runtime_params_slice( + runtime_params + ), + dynamic_runtime_params_slice=runtime_params_slice.build_dynamic_runtime_params_slice( + runtime_params, sources=source_models.runtime_params ), - geo=geometry.build_circular_geometry(config), + geo=geometry.build_circular_geometry(runtime_params), source_models=source_models, ) np.testing.assert_allclose( @@ -182,33 +193,36 @@ def test_initial_boundary_condition_from_time_dependent_params(self): dict(geo_builder=geometry.build_chease_geometry), ]) def test_initial_psi_from_j( - self, geo_builder: Callable[[config_lib.Config], geometry.Geometry] + self, + geo_builder: Callable[ + [general_runtime_params.GeneralRuntimeParams], geometry.Geometry + ], ): """Tests expected behaviour of initial psi and current options.""" - config1 = config_lib.Config( - profile_conditions=config_lib.ProfileConditions( + config1 = general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( initial_j_is_total_current=True, initial_psi_from_j=True, nu=2, ), ) - config2 = config_lib.Config( - profile_conditions=config_lib.ProfileConditions( + config2 = general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( initial_j_is_total_current=False, initial_psi_from_j=True, nu=2, ), ) - config3 = config_lib.Config( - profile_conditions=config_lib.ProfileConditions( + config3 = general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( initial_j_is_total_current=False, initial_psi_from_j=True, nu=2, ), ) # Needed to generate psi for bootstrap calculation - config3_helper = config_lib.Config( - profile_conditions=config_lib.ProfileConditions( + config3_helper = general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( initial_j_is_total_current=True, initial_psi_from_j=True, nu=2, @@ -217,46 +231,52 @@ def test_initial_psi_from_j( geo = geo_builder(config1) source_models = source_models_lib.SourceModels() source_models.j_bootstrap.runtime_params.bootstrap_mult = 0.0 - dcs1 = config_slice.build_dynamic_config_slice( + dcs1 = runtime_params_slice.build_dynamic_runtime_params_slice( config1, sources=source_models.runtime_params ) core_profiles1 = core_profile_setters.initial_core_profiles( - static_config_slice=config_slice.build_static_config_slice(config1), - dynamic_config_slice=dcs1, + static_runtime_params_slice=runtime_params_slice.build_static_runtime_params_slice( + config1 + ), + dynamic_runtime_params_slice=dcs1, geo=geo, source_models=source_models, ) source_models.j_bootstrap.runtime_params.bootstrap_mult = 0.0 - dcs2 = config_slice.build_dynamic_config_slice( + dcs2 = runtime_params_slice.build_dynamic_runtime_params_slice( config2, sources=source_models.runtime_params ) core_profiles2 = core_profile_setters.initial_core_profiles( - static_config_slice=config_slice.build_static_config_slice(config2), - dynamic_config_slice=dcs2, + static_runtime_params_slice=runtime_params_slice.build_static_runtime_params_slice( + config2 + ), + dynamic_runtime_params_slice=dcs2, geo=geo, source_models=source_models, ) source_models.j_bootstrap.runtime_params.bootstrap_mult = 1.0 source_models.jext.runtime_params.fext = 0.0 - dcs3 = config_slice.build_dynamic_config_slice( + dcs3 = runtime_params_slice.build_dynamic_runtime_params_slice( config3, sources=source_models.runtime_params ) core_profiles3 = core_profile_setters.initial_core_profiles( - static_config_slice=config_slice.build_static_config_slice(config3), - dynamic_config_slice=dcs3, + static_runtime_params_slice=runtime_params_slice.build_static_runtime_params_slice( + config3 + ), + dynamic_runtime_params_slice=dcs3, geo=geo, source_models=source_models, ) source_models.j_bootstrap.runtime_params.bootstrap_mult = 0.0 source_models.jext.runtime_params.fext = 0.0 - dcs3_helper = config_slice.build_dynamic_config_slice( + dcs3_helper = runtime_params_slice.build_dynamic_runtime_params_slice( config3_helper, sources=source_models.runtime_params ) core_profiles3_helper = core_profile_setters.initial_core_profiles( - static_config_slice=config_slice.build_static_config_slice( + static_runtime_params_slice=runtime_params_slice.build_static_runtime_params_slice( config3_helper ), - dynamic_config_slice=dcs3_helper, + dynamic_runtime_params_slice=dcs3_helper, geo=geo, source_models=source_models, ) @@ -275,7 +295,7 @@ def test_initial_psi_from_j( # Calculate bootstrap current for config3 which doesn't zero it out source_models = source_models_lib.SourceModels() bootstrap_profile = source_models.j_bootstrap.get_value( - dynamic_config_slice=dcs3, + dynamic_runtime_params_slice=dcs3, dynamic_source_runtime_params=dcs3.sources[ source_models.j_bootstrap_name ], @@ -332,31 +352,35 @@ def test_initial_psi_from_j( def test_initial_psi_from_geo_noop_circular(self): """Tests expected behaviour of initial psi and current options.""" source_models = source_models_lib.SourceModels() - config1 = config_lib.Config( - profile_conditions=config_lib.ProfileConditions( + config1 = general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( initial_psi_from_j=False, ), ) - dcs1 = config_slice.build_dynamic_config_slice( + dcs1 = runtime_params_slice.build_dynamic_runtime_params_slice( config1, sources=source_models.runtime_params ) - config2 = config_lib.Config( - profile_conditions=config_lib.ProfileConditions( + config2 = general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( initial_psi_from_j=True, ), ) - dcs2 = config_slice.build_dynamic_config_slice( + dcs2 = runtime_params_slice.build_dynamic_runtime_params_slice( config2, sources=source_models.runtime_params ) core_profiles1 = core_profile_setters.initial_core_profiles( - static_config_slice=config_slice.build_static_config_slice(config1), - dynamic_config_slice=dcs1, + static_runtime_params_slice=runtime_params_slice.build_static_runtime_params_slice( + config1 + ), + dynamic_runtime_params_slice=dcs1, geo=geometry.build_circular_geometry(config1), source_models=source_models, ) core_profiles2 = core_profile_setters.initial_core_profiles( - static_config_slice=config_slice.build_static_config_slice(config2), - dynamic_config_slice=dcs2, + static_runtime_params_slice=runtime_params_slice.build_static_runtime_params_slice( + config2 + ), + dynamic_runtime_params_slice=dcs2, geo=geometry.build_circular_geometry(config2), source_models=source_models, ) diff --git a/torax/tests/test_data/compilation_benchmark.py b/torax/tests/test_data/compilation_benchmark.py index 7e0b4360..b85b7e39 100644 --- a/torax/tests/test_data/compilation_benchmark.py +++ b/torax/tests/test_data/compilation_benchmark.py @@ -21,9 +21,9 @@ is infeasible. """ -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.config import runtime_params as general_runtime_params from torax.sources import default_sources from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib @@ -31,18 +31,18 @@ from torax.transport_model import qlknn_wrapper -def get_config() -> config_lib.Config: +def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: # This config based approach is deprecated. # Over time more will be built with pure Python constructors in `get_sim`. - return config_lib.Config( - profile_conditions=config_lib.ProfileConditions( + return general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( set_pedestal=True, nbar=0.85, # initial density (Greenwald fraction units) ne_bound_right=0.2, neped=1.0, nu=0, ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( ion_heat_eq=True, el_heat_eq=True, dens_eq=True, @@ -53,8 +53,10 @@ def get_config() -> config_lib.Config: ) -def get_geometry(config: config_lib.Config) -> geometry.Geometry: - return geometry.build_circular_geometry(config) +def get_geometry( + runtime_params: general_runtime_params.GeneralRuntimeParams, +) -> geometry.Geometry: + return geometry.build_circular_geometry(runtime_params) def get_transport_model() -> qlknn_wrapper.QLKNNTransportModel: @@ -106,10 +108,10 @@ def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most # config taking place via constructor args in this function. - config = get_config() + runtime_params = get_runtime_params() return sim_lib.build_sim_from_config( - config=config, - geo=get_geometry(config), + runtime_params=runtime_params, + geo=get_geometry(runtime_params), stepper_builder=get_stepper_builder(), source_models=get_sources(), transport_model=get_transport_model(), diff --git a/torax/tests/test_data/default_config.py b/torax/tests/test_data/default_config.py index f3049e55..2be06e21 100644 --- a/torax/tests/test_data/default_config.py +++ b/torax/tests/test_data/default_config.py @@ -12,25 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Run with the default config_lib.""" +"""Run with the default general_runtime_params.""" -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.config import runtime_params as general_runtime_params from torax.sources import default_sources from torax.sources import source_models as source_models_lib from torax.stepper import linear_theta_method from torax.transport_model import constant as constant_transport_model -def get_config() -> config_lib.Config: +def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: # This config based approach is deprecated. # Over time more will be built with pure Python constructors in `get_sim`. - return config_lib.Config() + return general_runtime_params.GeneralRuntimeParams() -def get_geometry(config: config_lib.Config) -> geometry.Geometry: - return geometry.build_circular_geometry(config) +def get_geometry( + runtime_params: general_runtime_params.GeneralRuntimeParams, +) -> geometry.Geometry: + return geometry.build_circular_geometry(runtime_params) def get_transport_model() -> constant_transport_model.ConstantTransportModel: @@ -53,10 +55,10 @@ def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most # config taking place via constructor args in this function. - config = get_config() - geo = get_geometry(config) + runtime_params = get_runtime_params() + geo = get_geometry(runtime_params) return sim_lib.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, stepper_builder=get_stepper_builder(), source_models=get_sources(), diff --git a/torax/tests/test_data/test_absolute_jext.py b/torax/tests/test_data/test_absolute_jext.py index 242a5bc1..984a629c 100644 --- a/torax/tests/test_data/test_absolute_jext.py +++ b/torax/tests/test_data/test_absolute_jext.py @@ -21,9 +21,9 @@ Result should be the same as test_psi_and_heat since fext=0 is ignored. """ -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.config import runtime_params as general_runtime_params from torax.sources import default_sources from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib @@ -32,16 +32,16 @@ from torax.transport_model import qlknn_wrapper -def get_config() -> config_lib.Config: - return config_lib.Config( - profile_conditions=config_lib.ProfileConditions( +def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: + return general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( Ti_bound_left=8, Te_bound_left=8, # set flat Ohmic current to provide larger range of current evolution # for test nu=0, ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( current_eq=True, resistivity_mult=100, # to shorten current diffusion time t_final=2, @@ -49,8 +49,10 @@ def get_config() -> config_lib.Config: ) -def get_geometry(config: config_lib.Config) -> geometry.Geometry: - return geometry.build_circular_geometry(config) +def get_geometry( + runtime_params: general_runtime_params.GeneralRuntimeParams, +) -> geometry.Geometry: + return geometry.build_circular_geometry(runtime_params) def get_transport_model() -> qlknn_wrapper.QLKNNTransportModel: @@ -89,10 +91,10 @@ def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most # config taking place via constructor args in this function. - config = get_config() - geo = get_geometry(config) + runtime_params = get_runtime_params() + geo = get_geometry(runtime_params) return sim_lib.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, stepper_builder=get_stepper_builder(), source_models=get_sources(), diff --git a/torax/tests/test_data/test_all_transport_crank_nicolson.py b/torax/tests/test_data/test_all_transport_crank_nicolson.py index 3be5db0a..c889ce5c 100644 --- a/torax/tests/test_data/test_all_transport_crank_nicolson.py +++ b/torax/tests/test_data/test_all_transport_crank_nicolson.py @@ -19,9 +19,9 @@ Veff model. """ -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.config import runtime_params as general_runtime_params from torax.sources import default_sources from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib @@ -30,10 +30,10 @@ from torax.transport_model import qlknn_wrapper -def get_config() -> config_lib.Config: - return config_lib.Config( - plasma_composition=config_lib.PlasmaComposition(), - profile_conditions=config_lib.ProfileConditions( +def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: + return general_runtime_params.GeneralRuntimeParams( + plasma_composition=general_runtime_params.PlasmaComposition(), + profile_conditions=general_runtime_params.ProfileConditions( set_pedestal=True, nbar=0.85, # initial density (using Greenwald fraction default) ne_bound_right=0.2, @@ -42,7 +42,7 @@ def get_config() -> config_lib.Config: # for test nu=0, ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( ion_heat_eq=True, el_heat_eq=True, dens_eq=True, @@ -54,9 +54,11 @@ def get_config() -> config_lib.Config: ) -def get_geometry(config: config_lib.Config) -> geometry.Geometry: +def get_geometry( + runtime_params: general_runtime_params.GeneralRuntimeParams, +) -> geometry.Geometry: return geometry.build_chease_geometry( - config, + runtime_params, geometry_file="ITER_hybrid_citrin_equil_cheasedata.mat2cols", Ip_from_parameters=True, ) @@ -113,10 +115,10 @@ def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most # config taking place via constructor args in this function. - config = get_config() - geo = get_geometry(config) + runtime_params = get_runtime_params() + geo = get_geometry(runtime_params) return sim_lib.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, stepper_builder=get_stepper_builder(), source_models=get_sources(), diff --git a/torax/tests/test_data/test_all_transport_fusion_qlknn.py b/torax/tests/test_data/test_all_transport_fusion_qlknn.py index 15d783b0..57780027 100644 --- a/torax/tests/test_data/test_all_transport_fusion_qlknn.py +++ b/torax/tests/test_data/test_all_transport_fusion_qlknn.py @@ -18,9 +18,9 @@ density. D_e scaled from chi_e """ -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.config import runtime_params as general_runtime_params from torax.sources import default_sources from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib @@ -29,10 +29,10 @@ from torax.transport_model import qlknn_wrapper -def get_config() -> config_lib.Config: - return config_lib.Config( +def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: + return general_runtime_params.GeneralRuntimeParams( # Like test16 but with fusion power - profile_conditions=config_lib.ProfileConditions( + profile_conditions=general_runtime_params.ProfileConditions( set_pedestal=True, nbar=0.85, # initial density (Greenwald fraction units) ne_bound_right=0.2, @@ -41,7 +41,7 @@ def get_config() -> config_lib.Config: # for test nu=0, ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( ion_heat_eq=True, el_heat_eq=True, dens_eq=True, @@ -52,8 +52,10 @@ def get_config() -> config_lib.Config: ) -def get_geometry(config: config_lib.Config) -> geometry.Geometry: - return geometry.build_circular_geometry(config) +def get_geometry( + runtime_params: general_runtime_params.GeneralRuntimeParams, +) -> geometry.Geometry: + return geometry.build_circular_geometry(runtime_params) def get_transport_model() -> qlknn_wrapper.QLKNNTransportModel: @@ -101,10 +103,10 @@ def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most # config taking place via constructor args in this function. - config = get_config() - geo = get_geometry(config) + runtime_params = get_runtime_params() + geo = get_geometry(runtime_params) return sim_lib.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, stepper_builder=get_stepper_builder(), source_models=get_sources(), diff --git a/torax/tests/test_data/test_bootstrap.py b/torax/tests/test_data/test_bootstrap.py index 06152311..1a9c3d3c 100644 --- a/torax/tests/test_data/test_bootstrap.py +++ b/torax/tests/test_data/test_bootstrap.py @@ -17,9 +17,9 @@ Constant transport coefficient model """ -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.config import runtime_params as general_runtime_params from torax.sources import default_sources from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib @@ -28,16 +28,16 @@ from torax.transport_model import constant as constant_transport_model -def get_config() -> config_lib.Config: - return config_lib.Config( - profile_conditions=config_lib.ProfileConditions( +def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: + return general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( set_pedestal=False, nbar=0.85, # initial density (in Greenwald fraction units) # set flat Ohmic current to provide larger range of current evolution # for test nu=0, ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( ion_heat_eq=True, el_heat_eq=True, dens_eq=True, @@ -48,8 +48,10 @@ def get_config() -> config_lib.Config: ) -def get_geometry(config: config_lib.Config) -> geometry.Geometry: - return geometry.build_circular_geometry(config) +def get_geometry( + runtime_params: general_runtime_params.GeneralRuntimeParams, +) -> geometry.Geometry: + return geometry.build_circular_geometry(runtime_params) def get_transport_model() -> constant_transport_model.ConstantTransportModel: @@ -98,10 +100,10 @@ def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most # config taking place via constructor args in this function. - config = get_config() - geo = get_geometry(config) + runtime_params = get_runtime_params() + geo = get_geometry(runtime_params) return sim_lib.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, stepper_builder=get_stepper_builder(), source_models=get_sources(), diff --git a/torax/tests/test_data/test_cgmheat.py b/torax/tests/test_data/test_cgmheat.py index 60c5f5eb..90f5ee55 100644 --- a/torax/tests/test_data/test_cgmheat.py +++ b/torax/tests/test_data/test_cgmheat.py @@ -18,9 +18,9 @@ CGM. """ -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.config import runtime_params as general_runtime_params from torax.sources import default_sources from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib @@ -29,16 +29,18 @@ from torax.transport_model import critical_gradient as cgm_transport_model -def get_config() -> config_lib.Config: - return config_lib.Config( - numerics=config_lib.Numerics( +def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: + return general_runtime_params.GeneralRuntimeParams( + numerics=general_runtime_params.Numerics( t_final=1, ), ) -def get_geometry(config: config_lib.Config) -> geometry.Geometry: - return geometry.build_circular_geometry(config) +def get_geometry( + runtime_params: general_runtime_params.GeneralRuntimeParams, +) -> geometry.Geometry: + return geometry.build_circular_geometry(runtime_params) def get_transport_model() -> cgm_transport_model.CriticalGradientModel: @@ -74,10 +76,10 @@ def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most # config taking place via constructor args in this function. - config = get_config() - geo = get_geometry(config) + runtime_params = get_runtime_params() + geo = get_geometry(runtime_params) return sim_lib.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, stepper_builder=get_stepper_builder(), source_models=get_sources(), diff --git a/torax/tests/test_data/test_chease.py b/torax/tests/test_data/test_chease.py index 03233495..353b825b 100644 --- a/torax/tests/test_data/test_chease.py +++ b/torax/tests/test_data/test_chease.py @@ -17,9 +17,9 @@ Ip from parameters. implicit, Ti+Te, no Pei, no pedestal, constant chi. """ -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.config import runtime_params as general_runtime_params from torax.sources import default_sources from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib @@ -28,21 +28,23 @@ from torax.transport_model import constant as constant_transport_model -def get_config() -> config_lib.Config: - return config_lib.Config( - profile_conditions=config_lib.ProfileConditions( +def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: + return general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( Ip=15, set_pedestal=False, ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( t_final=1, ), ) -def get_geometry(config: config_lib.Config) -> geometry.Geometry: +def get_geometry( + runtime_params: general_runtime_params.GeneralRuntimeParams, +) -> geometry.Geometry: return geometry.build_chease_geometry( - config, + runtime_params, geometry_file="ITER_hybrid_citrin_equil_cheasedata.mat2cols", Ip_from_parameters=True, ) @@ -92,10 +94,10 @@ def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most # config taking place via constructor args in this function. - config = get_config() - geo = get_geometry(config) + runtime_params = get_runtime_params() + geo = get_geometry(runtime_params) return sim_lib.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, stepper_builder=get_stepper_builder(), source_models=get_sources(), diff --git a/torax/tests/test_data/test_crank_nicolson.py b/torax/tests/test_data/test_crank_nicolson.py index a62a31ad..2f351046 100644 --- a/torax/tests/test_data/test_crank_nicolson.py +++ b/torax/tests/test_data/test_crank_nicolson.py @@ -18,9 +18,9 @@ just check that Crank-Nicolson doesn't deviate too far from that. """ -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.config import runtime_params as general_runtime_params from torax.sources import default_sources from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib @@ -29,21 +29,23 @@ from torax.transport_model import constant as constant_transport_model -def get_config() -> config_lib.Config: +def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: # This config based approach is deprecated. # Over time more will be built with pure Python constructors in `get_sim`. - return config_lib.Config( - profile_conditions=config_lib.ProfileConditions( + return general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( set_pedestal=False, ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( t_final=1, ), ) -def get_geometry(config: config_lib.Config) -> geometry.Geometry: - return geometry.build_circular_geometry(config) +def get_geometry( + runtime_params: general_runtime_params.GeneralRuntimeParams, +) -> geometry.Geometry: + return geometry.build_circular_geometry(runtime_params) def get_transport_model() -> constant_transport_model.ConstantTransportModel: @@ -81,10 +83,10 @@ def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most # config taking place via constructor args in this function. - config = get_config() - geo = get_geometry(config) + runtime_params = get_runtime_params() + geo = get_geometry(runtime_params) return sim_lib.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, stepper_builder=get_stepper_builder(), source_models=get_sources(), diff --git a/torax/tests/test_data/test_exact_finaltime.py b/torax/tests/test_data/test_exact_finaltime.py index 2d09c4b8..3f14e614 100644 --- a/torax/tests/test_data/test_exact_finaltime.py +++ b/torax/tests/test_data/test_exact_finaltime.py @@ -14,9 +14,9 @@ """test_exact_t_final: tests deterministic t_final with exact_t_final = True.""" -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.config import runtime_params as general_runtime_params from torax.sources import default_sources from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib @@ -25,16 +25,16 @@ from torax.transport_model import qlknn_wrapper -def get_config() -> config_lib.Config: - return config_lib.Config( - profile_conditions=config_lib.ProfileConditions( +def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: + return general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( Ti_bound_left=8, Te_bound_left=8, # set flat Ohmic current to provide larger range of current evolution # for test nu=0, ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( current_eq=True, resistivity_mult=100, # to shorten current diffusion time t_final=2, @@ -43,8 +43,10 @@ def get_config() -> config_lib.Config: ) -def get_geometry(config: config_lib.Config) -> geometry.Geometry: - return geometry.build_circular_geometry(config) +def get_geometry( + runtime_params: general_runtime_params.GeneralRuntimeParams, +) -> geometry.Geometry: + return geometry.build_circular_geometry(runtime_params) def get_transport_model() -> qlknn_wrapper.QLKNNTransportModel: @@ -80,10 +82,10 @@ def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most # config taking place via constructor args in this function. - config = get_config() - geo = get_geometry(config) + runtime_params = get_runtime_params() + geo = get_geometry(runtime_params) return sim_lib.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, stepper_builder=get_stepper_builder(), source_models=get_sources(), diff --git a/torax/tests/test_data/test_explicit.py b/torax/tests/test_data/test_explicit.py index 6b22cd09..8a64a740 100644 --- a/torax/tests/test_data/test_explicit.py +++ b/torax/tests/test_data/test_explicit.py @@ -15,9 +15,9 @@ """Config for test_explicit. Basic test of explicit linear solver.""" import dataclasses -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.config import runtime_params as general_runtime_params from torax.sources import default_sources from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib @@ -26,14 +26,14 @@ from torax.transport_model import constant as constant_transport_model -def get_config() -> config_lib.Config: +def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: # This config based approach is deprecated. # Over time more will be built with pure Python constructors in `get_sim`. - return config_lib.Config( - profile_conditions=config_lib.ProfileConditions( + return general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( set_pedestal=False, ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( dtmult=0.9, t_final=0.1, ion_heat_eq=True, @@ -42,8 +42,10 @@ def get_config() -> config_lib.Config: ) -def get_geometry(config: config_lib.Config) -> geometry.Geometry: - return geometry.build_circular_geometry(config) +def get_geometry( + runtime_params: general_runtime_params.GeneralRuntimeParams, +) -> geometry.Geometry: + return geometry.build_circular_geometry(runtime_params) def get_transport_model() -> constant_transport_model.ConstantTransportModel: @@ -89,10 +91,10 @@ def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most # config taking place via constructor args in this function. - config = get_config() - geo = get_geometry(config) + runtime_params = get_runtime_params() + geo = get_geometry(runtime_params) return sim_lib.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, source_models=get_sources(), transport_model=get_transport_model(), diff --git a/torax/tests/test_data/test_fixed_dt.py b/torax/tests/test_data/test_fixed_dt.py index 5a768b5b..bae6d960 100644 --- a/torax/tests/test_data/test_fixed_dt.py +++ b/torax/tests/test_data/test_fixed_dt.py @@ -14,9 +14,9 @@ """Config for testing fixed timestep.""" -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.config import runtime_params as general_runtime_params from torax.sources import default_sources from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib @@ -26,11 +26,11 @@ from torax.transport_model import qlknn_wrapper -def get_config() -> config_lib.Config: +def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: # This config based approach is deprecated. # Over time more will be built with pure Python constructors in `get_sim`. - return config_lib.Config( - numerics=config_lib.Numerics( + return general_runtime_params.GeneralRuntimeParams( + numerics=general_runtime_params.Numerics( t_final=2, use_fixed_dt=True, fixed_dt=2e-2, @@ -38,8 +38,10 @@ def get_config() -> config_lib.Config: ) -def get_geometry(config: config_lib.Config) -> geometry.Geometry: - return geometry.build_circular_geometry(config) +def get_geometry( + runtime_params: general_runtime_params.GeneralRuntimeParams, +) -> geometry.Geometry: + return geometry.build_circular_geometry(runtime_params) def get_transport_model() -> qlknn_wrapper.QLKNNTransportModel: @@ -75,14 +77,14 @@ def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most # config taking place via constructor args in this function. - sim_config = get_config() - geo = get_geometry(sim_config) - if sim_config.numerics.use_fixed_dt: + runtime_params = get_runtime_params() + geo = get_geometry(runtime_params) + if runtime_params.numerics.use_fixed_dt: time_step_calculator = fixed_time_step_calculator.FixedTimeStepCalculator() else: time_step_calculator = None return sim_lib.build_sim_from_config( - config=sim_config, + runtime_params=runtime_params, geo=geo, stepper_builder=get_stepper_builder(), source_models=get_sources(), diff --git a/torax/tests/test_data/test_frozen_newton_raphson.py b/torax/tests/test_data/test_frozen_newton_raphson.py index 396736ec..800180df 100644 --- a/torax/tests/test_data/test_frozen_newton_raphson.py +++ b/torax/tests/test_data/test_frozen_newton_raphson.py @@ -19,9 +19,9 @@ import functools -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.config import runtime_params as general_runtime_params from torax.sources import default_sources from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib @@ -30,19 +30,21 @@ from torax.transport_model import constant as constant_transport_model -def get_config() -> config_lib.Config: - return config_lib.Config( - profile_conditions=config_lib.ProfileConditions( +def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: + return general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( set_pedestal=False, ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( t_final=1, ), ) -def get_geometry(config: config_lib.Config) -> geometry.Geometry: - return geometry.build_circular_geometry(config) +def get_geometry( + runtime_params: general_runtime_params.GeneralRuntimeParams, +) -> geometry.Geometry: + return geometry.build_circular_geometry(runtime_params) def get_sources() -> source_models_lib.SourceModels: @@ -76,15 +78,15 @@ def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most # config taking place via constructor args in this function. - config = get_config() - geo = get_geometry(config) + runtime_params = get_runtime_params() + geo = get_geometry(runtime_params) stepper_builder = get_stepper_builder() stepper_builder.builder = functools.partial( sim_test_case.make_frozen_newton_raphson_stepper, - config=config, + runtime_params=runtime_params, ) return sim_lib.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, stepper_builder=stepper_builder, transport_model=constant_transport_model.ConstantTransportModel(), diff --git a/torax/tests/test_data/test_frozen_optimizer.py b/torax/tests/test_data/test_frozen_optimizer.py index 980212f0..1386cfc2 100644 --- a/torax/tests/test_data/test_frozen_optimizer.py +++ b/torax/tests/test_data/test_frozen_optimizer.py @@ -16,9 +16,9 @@ import functools -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.config import runtime_params as general_runtime_params from torax.sources import default_sources from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib @@ -28,19 +28,21 @@ from torax.transport_model import constant as constant_transport_model -def get_config() -> config_lib.Config: - return config_lib.Config( - profile_conditions=config_lib.ProfileConditions( +def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: + return general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( set_pedestal=False, ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( t_final=1, ), ) -def get_geometry(config: config_lib.Config) -> geometry.Geometry: - return geometry.build_circular_geometry(config) +def get_geometry( + runtime_params: general_runtime_params.GeneralRuntimeParams, +) -> geometry.Geometry: + return geometry.build_circular_geometry(runtime_params) def get_transport_model() -> constant_transport_model.ConstantTransportModel: @@ -79,17 +81,17 @@ def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most # config taking place via constructor args in this function. - config = get_config() - geo = get_geometry(config) + runtime_params = get_runtime_params() + geo = get_geometry(runtime_params) transport_model = get_transport_model() stepper_builder = get_stepper_builder() stepper_builder.builder = functools.partial( sim_test_case.make_frozen_optimizer_stepper, - config=config, + runtime_params=runtime_params, transport_params=transport_model.runtime_params, ) return sim_lib.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, stepper_builder=stepper_builder, transport_model=transport_model, diff --git a/torax/tests/test_data/test_fusion_power.py b/torax/tests/test_data/test_fusion_power.py index a98f48bb..8c7137f0 100644 --- a/torax/tests/test_data/test_fusion_power.py +++ b/torax/tests/test_data/test_fusion_power.py @@ -18,9 +18,9 @@ fusion power. """ -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.config import runtime_params as general_runtime_params from torax.sources import default_sources from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib @@ -29,10 +29,10 @@ from torax.transport_model import critical_gradient as cgm_transport_model -def get_config() -> config_lib.Config: - return config_lib.Config( +def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: + return general_runtime_params.GeneralRuntimeParams( # (Like test15, but with fusion power) - profile_conditions=config_lib.ProfileConditions( + profile_conditions=general_runtime_params.ProfileConditions( set_pedestal=True, nbar=0.85, # initial density (in Greenwald fraction units) ne_bound_right=0.2, @@ -41,7 +41,7 @@ def get_config() -> config_lib.Config: # for test nu=0, ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( ion_heat_eq=True, el_heat_eq=True, dens_eq=True, @@ -52,8 +52,10 @@ def get_config() -> config_lib.Config: ) -def get_geometry(config: config_lib.Config) -> geometry.Geometry: - return geometry.build_circular_geometry(config) +def get_geometry( + runtime_params: general_runtime_params.GeneralRuntimeParams, +) -> geometry.Geometry: + return geometry.build_circular_geometry(runtime_params) def get_transport_model() -> cgm_transport_model.CriticalGradientModel: @@ -103,10 +105,10 @@ def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most # config taking place via constructor args in this function. - config = get_config() - geo = get_geometry(config) + runtime_params = get_runtime_params() + geo = get_geometry(runtime_params) return sim_lib.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, stepper_builder=get_stepper_builder(), source_models=get_sources(), diff --git a/torax/tests/test_data/test_implicit.py b/torax/tests/test_data/test_implicit.py index 1c11e1c1..cb760c57 100644 --- a/torax/tests/test_data/test_implicit.py +++ b/torax/tests/test_data/test_implicit.py @@ -14,9 +14,9 @@ """test_implicit: implicit, Ti+Te, no Pei, no pedestal, constant chi.""" -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.config import runtime_params as general_runtime_params from torax.sources import default_sources from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib @@ -25,19 +25,21 @@ from torax.transport_model import constant as constant_transport_model -def get_config() -> config_lib.Config: - return config_lib.Config( - profile_conditions=config_lib.ProfileConditions( +def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: + return general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( set_pedestal=False, ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( t_final=1, ), ) -def get_geometry(config: config_lib.Config) -> geometry.Geometry: - return geometry.build_circular_geometry(config) +def get_geometry( + runtime_params: general_runtime_params.GeneralRuntimeParams, +) -> geometry.Geometry: + return geometry.build_circular_geometry(runtime_params) def get_transport_model() -> constant_transport_model.ConstantTransportModel: @@ -75,10 +77,10 @@ def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most # config taking place via constructor args in this function. - config = get_config() - geo = get_geometry(config) + runtime_params = get_runtime_params() + geo = get_geometry(runtime_params) return sim_lib.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, stepper_builder=get_stepper_builder(), source_models=get_sources(), diff --git a/torax/tests/test_data/test_implicit_short_optimizer.py b/torax/tests/test_data/test_implicit_short_optimizer.py index 81a5abae..fc872a0a 100644 --- a/torax/tests/test_data/test_implicit_short_optimizer.py +++ b/torax/tests/test_data/test_implicit_short_optimizer.py @@ -20,9 +20,9 @@ import functools -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.config import runtime_params as general_runtime_params from torax.sources import default_sources from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib @@ -32,21 +32,23 @@ from torax.transport_model import constant as constant_transport_model -def get_config() -> config_lib.Config: +def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: # This config based approach is deprecated. # Over time more will be built with pure Python constructors in `get_sim`. - return config_lib.Config( - profile_conditions=config_lib.ProfileConditions( + return general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( set_pedestal=False, ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( t_final=0.1, ), ) -def get_geometry(config: config_lib.Config) -> geometry.Geometry: - return geometry.build_circular_geometry(config) +def get_geometry( + runtime_params: general_runtime_params.GeneralRuntimeParams, +) -> geometry.Geometry: + return geometry.build_circular_geometry(runtime_params) def get_transport_model() -> constant_transport_model.ConstantTransportModel: @@ -85,17 +87,17 @@ def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most # config taking place via constructor args in this function. - config = get_config() - geo = get_geometry(config) + runtime_params = get_runtime_params() + geo = get_geometry(runtime_params) transport_model = get_transport_model() stepper_builder = get_stepper_builder() stepper_builder.builder = functools.partial( sim_test_case.make_frozen_optimizer_stepper, - config=config, + runtime_params=runtime_params, transport_params=transport_model.runtime_params, ) return sim_lib.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, stepper_builder=stepper_builder, transport_model=transport_model, diff --git a/torax/tests/test_data/test_iterbaseline_mockup.py b/torax/tests/test_data/test_iterbaseline_mockup.py index 81dcc66f..98c40914 100644 --- a/torax/tests/test_data/test_iterbaseline_mockup.py +++ b/torax/tests/test_data/test_iterbaseline_mockup.py @@ -15,9 +15,9 @@ """ITER baseline approximately based on Mantica PPCF 2021.""" import dataclasses -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.config import runtime_params as general_runtime_params from torax.sources import default_sources from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib @@ -26,16 +26,16 @@ from torax.transport_model import qlknn_wrapper -def get_config() -> config_lib.Config: - return config_lib.Config( - plasma_composition=config_lib.PlasmaComposition( +def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: + return general_runtime_params.GeneralRuntimeParams( + plasma_composition=general_runtime_params.PlasmaComposition( # physical inputs Ai=2.5, # amu of main ion (if multiple isotope, make average) Zeff=1.74, # needed for qlknn and fusion power # effective impurity charge state assumed for matching dilution=0.862 Zimp=6.3623, ), - profile_conditions=config_lib.ProfileConditions( + profile_conditions=general_runtime_params.ProfileConditions( Ip=15, # total plasma current in MA # boundary + initial conditions for T and n Ti_bound_left=15, # initial condition ion temperature for r=0 @@ -56,7 +56,7 @@ def get_config() -> config_lib.Config: neped=0.68, # pedestal top electron density in units of nref Ped_top=0.93, # set ped top location in normalized radius ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( # simulation control t_final=10, # length of simulation time in seconds # 1/multiplication factor for sigma (conductivity) to reduce current @@ -82,9 +82,11 @@ def get_config() -> config_lib.Config: ) -def get_geometry(config: config_lib.Config) -> geometry.Geometry: +def get_geometry( + runtime_params: general_runtime_params.GeneralRuntimeParams, +) -> geometry.Geometry: return geometry.build_chease_geometry( - config, + runtime_params, geometry_file='ITER_hybrid_citrin_equil_cheasedata.mat2cols', Ip_from_parameters=True, Rmaj=6.2, # major radius (R) in meters @@ -212,10 +214,10 @@ def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most # config taking place via constructor args in this function. - config = get_config() - geo = get_geometry(config) + runtime_params = get_runtime_params() + geo = get_geometry(runtime_params) return sim_lib.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, stepper_builder=get_stepper_builder(), source_models=get_sources(), diff --git a/torax/tests/test_data/test_iterhybrid_mockup.py b/torax/tests/test_data/test_iterhybrid_mockup.py index 97c6c081..21d7175f 100644 --- a/torax/tests/test_data/test_iterhybrid_mockup.py +++ b/torax/tests/test_data/test_iterhybrid_mockup.py @@ -15,9 +15,9 @@ """ITER hybrid scenario approximately based on van Mulders NF 2021.""" import dataclasses -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.config import runtime_params as general_runtime_params from torax.sources import default_sources from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib @@ -26,19 +26,19 @@ from torax.transport_model import qlknn_wrapper -def get_config() -> config_lib.Config: +def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: # NOTE: This approach to building the config is changing. Over time more # parts of this config will be built with pure Python constructors in # `get_sim()`. - return config_lib.Config( - plasma_composition=config_lib.PlasmaComposition( + return general_runtime_params.GeneralRuntimeParams( + plasma_composition=general_runtime_params.PlasmaComposition( # physical inputs Ai=2.5, # amu of main ion (if multiple isotope, make average) Zeff=1.6, # needed for qlknn and fusion power # effective impurity charge state assumed for matching dilution=0.862. Zimp=10, ), - profile_conditions=config_lib.ProfileConditions( + profile_conditions=general_runtime_params.ProfileConditions( Ip=10.5, # total plasma current in MA # boundary + initial conditions for T and n Ti_bound_left=15, # initial condition ion temperature for r=0 @@ -58,7 +58,7 @@ def get_config() -> config_lib.Config: neped=0.62, # pedestal top electron density in units of nref Ped_top=0.9, # set ped top location in normalized radius ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( # simulation control t_final=5, # length of simulation time in seconds # 1/multiplication factor for sigma (conductivity) to reduce current @@ -85,9 +85,11 @@ def get_config() -> config_lib.Config: ) -def get_geometry(config: config_lib.Config) -> geometry.Geometry: +def get_geometry( + runtime_params: general_runtime_params.GeneralRuntimeParams, +) -> geometry.Geometry: return geometry.build_chease_geometry( - config, + runtime_params, geometry_file='ITER_hybrid_citrin_equil_cheasedata.mat2cols', Ip_from_parameters=True, Rmaj=6.2, # major radius (R) in meters @@ -219,10 +221,10 @@ def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most # config taking place via constructor args in this function. - config = get_config() - geo = get_geometry(config) + runtime_params = get_runtime_params() + geo = get_geometry(runtime_params) return sim_lib.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, stepper_builder=get_stepper_builder(), source_models=get_sources(), diff --git a/torax/tests/test_data/test_iterhybrid_newton.py b/torax/tests/test_data/test_iterhybrid_newton.py index 8272e7a1..171cedab 100644 --- a/torax/tests/test_data/test_iterhybrid_newton.py +++ b/torax/tests/test_data/test_iterhybrid_newton.py @@ -19,9 +19,9 @@ """ import dataclasses -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.config import runtime_params as general_runtime_params from torax.sources import default_sources from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib @@ -29,19 +29,19 @@ from torax.transport_model import qlknn_wrapper -def get_config() -> config_lib.Config: +def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: # NOTE: This approach to building the config is changing. Over time more # parts of this config will be built with pure Python constructors in # `get_sim()`. - return config_lib.Config( - plasma_composition=config_lib.PlasmaComposition( + return general_runtime_params.GeneralRuntimeParams( + plasma_composition=general_runtime_params.PlasmaComposition( # physical inputs Ai=2.5, # amu of main ion (if multiple isotope, make average) Zeff=1.6, # needed for qlknn and fusion power # effective impurity charge state assumed for matching dilution=0.862. Zimp=10, ), - profile_conditions=config_lib.ProfileConditions( + profile_conditions=general_runtime_params.ProfileConditions( Ip=10.5, # total plasma current in MA # boundary + initial conditions for T and n Ti_bound_left=15, # initial condition ion temperature for r=0 @@ -61,7 +61,7 @@ def get_config() -> config_lib.Config: neped=0.62, # pedestal top electron density in units of nref Ped_top=0.9, # set ped top location in normalized radius ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( # simulation control t_final=1, # length of simulation time in seconds # 1/multiplication factor for sigma (conductivity) to reduce current @@ -88,9 +88,11 @@ def get_config() -> config_lib.Config: ) -def get_geometry(config: config_lib.Config) -> geometry.Geometry: +def get_geometry( + runtime_params: general_runtime_params.GeneralRuntimeParams, +) -> geometry.Geometry: return geometry.build_chease_geometry( - config, + runtime_params, geometry_file='ITER_hybrid_citrin_equil_cheasedata.mat2cols', Ip_from_parameters=True, Rmaj=6.2, # major radius (R) in meters @@ -236,10 +238,10 @@ def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most # config taking place via constructor args in this function. - config = get_config() - geo = get_geometry(config) + runtime_params = get_runtime_params() + geo = get_geometry(runtime_params) return sim_lib.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, stepper_builder=get_stepper_builder(), source_models=get_sources(), diff --git a/torax/tests/test_data/test_iterhybrid_predictor_corrector.py b/torax/tests/test_data/test_iterhybrid_predictor_corrector.py index cd3d1889..dd915310 100644 --- a/torax/tests/test_data/test_iterhybrid_predictor_corrector.py +++ b/torax/tests/test_data/test_iterhybrid_predictor_corrector.py @@ -15,9 +15,9 @@ """ITER hybrid scenario based (roughly) on van Mulders NF 2021.""" import dataclasses -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.config import runtime_params as general_runtime_params from torax.sources import default_sources from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib @@ -26,19 +26,19 @@ from torax.transport_model import qlknn_wrapper -def get_config() -> config_lib.Config: +def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: # NOTE: This approach to building the config is changing. Over time more # parts of this config will be built with pure Python constructors in # `get_sim()`. - return config_lib.Config( - plasma_composition=config_lib.PlasmaComposition( + return general_runtime_params.GeneralRuntimeParams( + plasma_composition=general_runtime_params.PlasmaComposition( # physical inputs Ai=2.5, # amu of main ion (if multiple isotope, make average) Zeff=1.6, # needed for qlknn and fusion power # effective impurity charge state assumed for matching dilution=0.862. Zimp=10, ), - profile_conditions=config_lib.ProfileConditions( + profile_conditions=general_runtime_params.ProfileConditions( Ip=10.5, # total plasma current in MA # boundary + initial conditions for T and n Ti_bound_left=15, # initial condition ion temperature for r=0 @@ -58,7 +58,7 @@ def get_config() -> config_lib.Config: neped=0.62, # pedestal top electron density in units of nref Ped_top=0.9, # set ped top location in normalized radius ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( # simulation control t_final=5, # length of simulation time in seconds # 1/multiplication factor for sigma (conductivity) to reduce current @@ -85,9 +85,11 @@ def get_config() -> config_lib.Config: ) -def get_geometry(config: config_lib.Config) -> geometry.Geometry: +def get_geometry( + runtime_params: general_runtime_params.GeneralRuntimeParams, +) -> geometry.Geometry: return geometry.build_chease_geometry( - config, + runtime_params, geometry_file='ITER_hybrid_citrin_equil_cheasedata.mat2cols', Ip_from_parameters=True, Rmaj=6.2, # major radius (R) in meters @@ -220,10 +222,10 @@ def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most # config taking place via constructor args in this function. - config = get_config() - geo = get_geometry(config) + runtime_params = get_runtime_params() + geo = get_geometry(runtime_params) return sim_lib.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, stepper_builder=get_stepper_builder(), source_models=get_sources(), diff --git a/torax/tests/test_data/test_iterhybrid_rampup.py b/torax/tests/test_data/test_iterhybrid_rampup.py index 869ccbf4..55b079c2 100644 --- a/torax/tests/test_data/test_iterhybrid_rampup.py +++ b/torax/tests/test_data/test_iterhybrid_rampup.py @@ -19,9 +19,9 @@ """ import dataclasses -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.config import runtime_params as general_runtime_params from torax.sources import default_sources from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib @@ -30,19 +30,19 @@ from torax.transport_model import qlknn_wrapper -def get_config() -> config_lib.Config: +def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: # NOTE: This approach to building the config is changing. Over time more # parts of this config will be built with pure Python constructors in # `get_sim()`. - return config_lib.Config( - plasma_composition=config_lib.PlasmaComposition( + return general_runtime_params.GeneralRuntimeParams( + plasma_composition=general_runtime_params.PlasmaComposition( # physical inputs Ai=2.5, # amu of main ion (if multiple isotope, make average) Zeff=1.6, # needed for qlknn and fusion power # effective impurity charge state assumed for matching dilution=0.862. Zimp=10, ), - profile_conditions=config_lib.ProfileConditions( + profile_conditions=general_runtime_params.ProfileConditions( Ip={0: 3, 80: 10.5}, # total plasma current in MA # boundary + initial conditions for T and n Ti_bound_left=6, # initial condition ion temperature for r=0 @@ -66,7 +66,7 @@ def get_config() -> config_lib.Config: neped={0: 0.3, 80: 0.7}, Ped_top=0.9, # set ped top location in normalized radius ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( # simulation control t_final=80, # length of simulation time in seconds fixed_dt=2, @@ -95,9 +95,11 @@ def get_config() -> config_lib.Config: ) -def get_geometry(config: config_lib.Config) -> geometry.Geometry: +def get_geometry( + runtime_params: general_runtime_params.GeneralRuntimeParams, +) -> geometry.Geometry: return geometry.build_chease_geometry( - config, + runtime_params, geometry_file='ITER_hybrid_citrin_equil_cheasedata.mat2cols', Ip_from_parameters=True, Rmaj=6.2, # major radius (R) in meters @@ -238,10 +240,10 @@ def get_stepper_builder() -> ( def get_sim() -> sim_lib.Sim: - config = get_config() - geo = get_geometry(config) + runtime_params = get_runtime_params() + geo = get_geometry(runtime_params) return sim_lib.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, stepper_builder=get_stepper_builder(), source_models=get_sources(), diff --git a/torax/tests/test_data/test_ne_qlknn_deff_veff.py b/torax/tests/test_data/test_ne_qlknn_deff_veff.py index a906a45f..99b6de30 100644 --- a/torax/tests/test_data/test_ne_qlknn_deff_veff.py +++ b/torax/tests/test_data/test_ne_qlknn_deff_veff.py @@ -19,9 +19,9 @@ Veff model. """ -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.config import runtime_params as general_runtime_params from torax.sources import default_sources from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib @@ -30,9 +30,9 @@ from torax.transport_model import qlknn_wrapper -def get_config() -> config_lib.Config: - return config_lib.Config( - profile_conditions=config_lib.ProfileConditions( +def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: + return general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( set_pedestal=True, nbar=0.85, # initial density (in Greenwald fraction units) ne_bound_right=0.2, @@ -41,7 +41,7 @@ def get_config() -> config_lib.Config: # for test nu=0, ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( ion_heat_eq=True, el_heat_eq=True, dens_eq=True, @@ -52,9 +52,11 @@ def get_config() -> config_lib.Config: ) -def get_geometry(config: config_lib.Config) -> geometry.Geometry: +def get_geometry( + runtime_params: general_runtime_params.GeneralRuntimeParams, +) -> geometry.Geometry: return geometry.build_chease_geometry( - config, + runtime_params, geometry_file="ITER_hybrid_citrin_equil_cheasedata.mat2cols", Ip_from_parameters=True, ) @@ -108,10 +110,10 @@ def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most # config taking place via constructor args in this function. - config = get_config() - geo = get_geometry(config) + runtime_params = get_runtime_params() + geo = get_geometry(runtime_params) return sim_lib.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, stepper_builder=get_stepper_builder(), source_models=get_sources(), diff --git a/torax/tests/test_data/test_ne_qlknn_defromchie.py b/torax/tests/test_data/test_ne_qlknn_defromchie.py index 914f7f56..209c1856 100644 --- a/torax/tests/test_data/test_ne_qlknn_defromchie.py +++ b/torax/tests/test_data/test_ne_qlknn_defromchie.py @@ -19,9 +19,9 @@ scaled from chi_e """ -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.config import runtime_params as general_runtime_params from torax.sources import default_sources from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib @@ -30,9 +30,9 @@ from torax.transport_model import qlknn_wrapper -def get_config() -> config_lib.Config: - return config_lib.Config( - profile_conditions=config_lib.ProfileConditions( +def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: + return general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( set_pedestal=True, nbar=0.85, # initial density (using default Greenwald fraction) ne_bound_right=0.2, @@ -41,7 +41,7 @@ def get_config() -> config_lib.Config: # for test nu=0, ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( ion_heat_eq=True, el_heat_eq=True, dens_eq=True, @@ -52,9 +52,11 @@ def get_config() -> config_lib.Config: ) -def get_geometry(config: config_lib.Config) -> geometry.Geometry: +def get_geometry( + runtime_params: general_runtime_params.GeneralRuntimeParams, +) -> geometry.Geometry: return geometry.build_chease_geometry( - config, + runtime_params, geometry_file="ITER_hybrid_citrin_equil_cheasedata.mat2cols", Ip_from_parameters=True, ) @@ -108,10 +110,10 @@ def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most # config taking place via constructor args in this function. - config = get_config() - geo = get_geometry(config) + runtime_params = get_runtime_params() + geo = get_geometry(runtime_params) return sim_lib.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, stepper_builder=get_stepper_builder(), source_models=get_sources(), diff --git a/torax/tests/test_data/test_newton_raphson_zeroiter.py b/torax/tests/test_data/test_newton_raphson_zeroiter.py index 2001c367..32a09b23 100644 --- a/torax/tests/test_data/test_newton_raphson_zeroiter.py +++ b/torax/tests/test_data/test_newton_raphson_zeroiter.py @@ -25,9 +25,9 @@ to tiny timesteps (and test timeouts) """ -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.config import runtime_params as general_runtime_params from torax.fvm import enums from torax.sources import default_sources from torax.sources import runtime_params as source_runtime_params @@ -66,18 +66,18 @@ def make_linear_newton_raphson_stepper( ) -def get_config() -> config_lib.Config: +def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: # This config based approach is deprecated. # Over time more will be built with pure Python constructors in `get_sim`. - return config_lib.Config( - profile_conditions=config_lib.ProfileConditions( + return general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( Ti_bound_left=8, Te_bound_left=8, # set flat Ohmic current to provide larger range of current evolution # for test nu=0, ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( current_eq=True, adaptive_dt=False, # to shorten current diffusion time for the test @@ -87,8 +87,10 @@ def get_config() -> config_lib.Config: ) -def get_geometry(config: config_lib.Config) -> geometry.Geometry: - return geometry.build_circular_geometry(config) +def get_geometry( + runtime_params: general_runtime_params.GeneralRuntimeParams, +) -> geometry.Geometry: + return geometry.build_circular_geometry(runtime_params) def get_transport_model() -> qlknn_wrapper.QLKNNTransportModel: @@ -127,10 +129,10 @@ def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most # config taking place via constructor args in this function. - config = get_config() - geo = get_geometry(config) + runtime_params = get_runtime_params() + geo = get_geometry(runtime_params) return sim_lib.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, source_models=get_sources(), stepper_builder=get_stepper_builder(), diff --git a/torax/tests/test_data/test_ohmic_power.py b/torax/tests/test_data/test_ohmic_power.py index 3b90e76d..968a193d 100644 --- a/torax/tests/test_data/test_ohmic_power.py +++ b/torax/tests/test_data/test_ohmic_power.py @@ -17,9 +17,9 @@ Ip from parameters. implicit, Ti+Te, Pei low dens, no pedestal, constant chi """ -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.config import runtime_params as general_runtime_params from torax.sources import default_sources from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib @@ -28,23 +28,25 @@ from torax.transport_model import constant as constant_transport_model -def get_config() -> config_lib.Config: - return config_lib.Config( - profile_conditions=config_lib.ProfileConditions( +def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: + return general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( set_pedestal=False, nbar_is_fGW=True, nbar=0.5, ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( t_final=1, resistivity_mult=100, ), ) -def get_geometry(config: config_lib.Config) -> geometry.Geometry: +def get_geometry( + runtime_params: general_runtime_params.GeneralRuntimeParams, +) -> geometry.Geometry: return geometry.build_chease_geometry( - config, + runtime_params, geometry_file="ITER_hybrid_citrin_equil_cheasedata.mat2cols", Ip_from_parameters=True, ) @@ -89,10 +91,10 @@ def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most # config taking place via constructor args in this function. - config = get_config() - geo = get_geometry(config) + runtime_params = get_runtime_params() + geo = get_geometry(runtime_params) return sim_lib.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, stepper_builder=get_stepper_builder(), source_models=get_sources(), diff --git a/torax/tests/test_data/test_optimizer_zeroiter.py b/torax/tests/test_data/test_optimizer_zeroiter.py index 58e725e1..b344786d 100644 --- a/torax/tests/test_data/test_optimizer_zeroiter.py +++ b/torax/tests/test_data/test_optimizer_zeroiter.py @@ -19,9 +19,9 @@ using 0 iterations and an initial guess based on the linear solver. """ -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.config import runtime_params as general_runtime_params from torax.fvm import enums from torax.sources import default_sources from torax.sources import runtime_params as source_runtime_params @@ -60,18 +60,18 @@ def make_linear_optimizer_stepper( ) -def get_config() -> config_lib.Config: +def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: # This config based approach is deprecated. # Over time more will be built with pure Python constructors in `get_sim`. - return config_lib.Config( - profile_conditions=config_lib.ProfileConditions( + return general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( Ti_bound_left=8, Te_bound_left=8, # set flat Ohmic current to provide larger range of current evolution # for test nu=0, ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( current_eq=True, adaptive_dt=False, # to shorten current diffusion time for the test @@ -81,8 +81,10 @@ def get_config() -> config_lib.Config: ) -def get_geometry(config: config_lib.Config) -> geometry.Geometry: - return geometry.build_circular_geometry(config) +def get_geometry( + runtime_params: general_runtime_params.GeneralRuntimeParams, +) -> geometry.Geometry: + return geometry.build_circular_geometry(runtime_params) def get_transport_model() -> qlknn_wrapper.QLKNNTransportModel: @@ -119,10 +121,10 @@ def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most # config taking place via constructor args in this function. - config = get_config() - geo = get_geometry(config) + runtime_params = get_runtime_params() + geo = get_geometry(runtime_params) return sim_lib.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, source_models=get_sources(), stepper_builder=get_stepper_builder(), diff --git a/torax/tests/test_data/test_particle_sources_cgm.py b/torax/tests/test_data/test_particle_sources_cgm.py index d010bfdc..2f5dd9c0 100644 --- a/torax/tests/test_data/test_particle_sources_cgm.py +++ b/torax/tests/test_data/test_particle_sources_cgm.py @@ -17,9 +17,9 @@ CGM transport model. Pedestal. Particle sources including NBI """ -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.config import runtime_params as general_runtime_params from torax.sources import default_sources from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib @@ -28,9 +28,9 @@ from torax.transport_model import critical_gradient as cgm_transport_model -def get_config() -> config_lib.Config: - return config_lib.Config( - profile_conditions=config_lib.ProfileConditions( +def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: + return general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( set_pedestal=True, nbar=0.85, # initial density (in Greenwald fraction units) ne_bound_right=0.2, @@ -39,7 +39,7 @@ def get_config() -> config_lib.Config: # for test nu=0, ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( ion_heat_eq=True, el_heat_eq=True, dens_eq=True, @@ -50,8 +50,10 @@ def get_config() -> config_lib.Config: ) -def get_geometry(config: config_lib.Config) -> geometry.Geometry: - return geometry.build_circular_geometry(config) +def get_geometry( + runtime_params: general_runtime_params.GeneralRuntimeParams, +) -> geometry.Geometry: + return geometry.build_circular_geometry(runtime_params) def get_transport_model() -> cgm_transport_model.CriticalGradientModel: @@ -100,10 +102,10 @@ def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most # config taking place via constructor args in this function. - config = get_config() - geo = get_geometry(config) + runtime_params = get_runtime_params() + geo = get_geometry(runtime_params) return sim_lib.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, stepper_builder=get_stepper_builder(), source_models=get_sources(), diff --git a/torax/tests/test_data/test_particle_sources_constant.py b/torax/tests/test_data/test_particle_sources_constant.py index da1c9b5f..bec1049f 100644 --- a/torax/tests/test_data/test_particle_sources_constant.py +++ b/torax/tests/test_data/test_particle_sources_constant.py @@ -17,9 +17,9 @@ Constant transport coefficient model. Pedestal. Particle sources """ -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.config import runtime_params as general_runtime_params from torax.sources import default_sources from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib @@ -28,16 +28,16 @@ from torax.transport_model import constant as constant_transport_model -def get_config() -> config_lib.Config: - return config_lib.Config( - profile_conditions=config_lib.ProfileConditions( +def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: + return general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( set_pedestal=True, nbar=0.85, # initial density (Greenwald fraction units) # set flat Ohmic current to provide larger range of current evolution # for test nu=0, ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( ion_heat_eq=True, el_heat_eq=True, dens_eq=True, @@ -49,8 +49,10 @@ def get_config() -> config_lib.Config: ) -def get_geometry(config: config_lib.Config) -> geometry.Geometry: - return geometry.build_circular_geometry(config) +def get_geometry( + runtime_params: general_runtime_params.GeneralRuntimeParams, +) -> geometry.Geometry: + return geometry.build_circular_geometry(runtime_params) def get_transport_model() -> constant_transport_model.ConstantTransportModel: @@ -99,10 +101,10 @@ def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most # config taking place via constructor args in this function. - config = get_config() - geo = get_geometry(config) + runtime_params = get_runtime_params() + geo = get_geometry(runtime_params) return sim_lib.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, stepper_builder=get_stepper_builder(), source_models=get_sources(), diff --git a/torax/tests/test_data/test_pc_method_ne.py b/torax/tests/test_data/test_pc_method_ne.py index da10f747..435aa1c4 100644 --- a/torax/tests/test_data/test_pc_method_ne.py +++ b/torax/tests/test_data/test_pc_method_ne.py @@ -20,9 +20,9 @@ """ import dataclasses -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.config import runtime_params as general_runtime_params from torax.sources import default_sources from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib @@ -31,11 +31,11 @@ from torax.transport_model import qlknn_wrapper -def get_config() -> config_lib.Config: - return config_lib.Config( +def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: + return general_runtime_params.GeneralRuntimeParams( # DVeff = False, leads to a numerical instability in the particle channel # here - profile_conditions=config_lib.ProfileConditions( + profile_conditions=general_runtime_params.ProfileConditions( set_pedestal=True, nbar=0.85, # initial density (in Greenwald fraction units) ne_bound_right=0.2, @@ -44,7 +44,7 @@ def get_config() -> config_lib.Config: # for test nu=0, ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( ion_heat_eq=True, el_heat_eq=True, dens_eq=True, @@ -56,9 +56,11 @@ def get_config() -> config_lib.Config: ) -def get_geometry(config: config_lib.Config) -> geometry.Geometry: +def get_geometry( + runtime_params: general_runtime_params.GeneralRuntimeParams, +) -> geometry.Geometry: return geometry.build_chease_geometry( - config, + runtime_params, geometry_file="ITER_hybrid_citrin_equil_cheasedata.mat2cols", Ip_from_parameters=True, ) @@ -116,10 +118,10 @@ def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most # config taking place via constructor args in this function. - config = get_config() - geo = get_geometry(config) + runtime_params = get_runtime_params() + geo = get_geometry(runtime_params) return sim_lib.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, stepper_builder=get_stepper_builder(), source_models=get_sources(), diff --git a/torax/tests/test_data/test_pedestal.py b/torax/tests/test_data/test_pedestal.py index 52260c16..510de9e4 100644 --- a/torax/tests/test_data/test_pedestal.py +++ b/torax/tests/test_data/test_pedestal.py @@ -17,9 +17,9 @@ Implicit solver, Ti+Te, Pei standard dens, pedestal, constant chi. """ -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.config import runtime_params as general_runtime_params from torax.sources import default_sources from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib @@ -28,16 +28,18 @@ from torax.transport_model import constant as constant_transport_model -def get_config() -> config_lib.Config: - return config_lib.Config( - numerics=config_lib.Numerics( +def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: + return general_runtime_params.GeneralRuntimeParams( + numerics=general_runtime_params.Numerics( t_final=1, ), ) -def get_geometry(config: config_lib.Config) -> geometry.Geometry: - return geometry.build_circular_geometry(config) +def get_geometry( + runtime_params: general_runtime_params.GeneralRuntimeParams, +) -> geometry.Geometry: + return geometry.build_circular_geometry(runtime_params) def get_transport_model() -> constant_transport_model.ConstantTransportModel: @@ -72,10 +74,10 @@ def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most # config taking place via constructor args in this function. - config = get_config() - geo = get_geometry(config) + runtime_params = get_runtime_params() + geo = get_geometry(runtime_params) return sim_lib.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, stepper_builder=get_stepper_builder(), source_models=get_sources(), diff --git a/torax/tests/test_data/test_prescribed_timedependent_ne.py b/torax/tests/test_data/test_prescribed_timedependent_ne.py index 8af245ea..bb1b33a0 100644 --- a/torax/tests/test_data/test_prescribed_timedependent_ne.py +++ b/torax/tests/test_data/test_prescribed_timedependent_ne.py @@ -20,9 +20,9 @@ """ import dataclasses -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.config import runtime_params as general_runtime_params from torax.sources import default_sources from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib @@ -31,16 +31,16 @@ from torax.transport_model import qlknn_wrapper -def get_config() -> config_lib.Config: - return config_lib.Config( - profile_conditions=config_lib.ProfileConditions( +def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: + return general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( Ti_bound_left=10, Te_bound_left=10, Ip={0: 5, 4: 15, 6: 12, 8: 12}, Tiped={0: 2, 4: 2, 6: 5, 8: 4}, Teped={0: 2, 4: 2, 6: 5, 8: 4}, ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( current_eq=True, resistivity_mult=50, # to shorten current diffusion time for the test dtmult=150, @@ -51,9 +51,11 @@ def get_config() -> config_lib.Config: ) -def get_geometry(config: config_lib.Config) -> geometry.Geometry: +def get_geometry( + runtime_params: general_runtime_params.GeneralRuntimeParams, +) -> geometry.Geometry: return geometry.build_chease_geometry( - config, + runtime_params, geometry_file="ITER_hybrid_citrin_equil_cheasedata.mat2cols", Ip_from_parameters=True, ) @@ -120,10 +122,10 @@ def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most # config taking place via constructor args in this function. - config = get_config() - geo = get_geometry(config) + runtime_params = get_runtime_params() + geo = get_geometry(runtime_params) return sim_lib.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, stepper_builder=get_stepper_builder(), source_models=get_sources(), diff --git a/torax/tests/test_data/test_psi_and_heat.py b/torax/tests/test_data/test_psi_and_heat.py index 52a175a4..b002fd73 100644 --- a/torax/tests/test_data/test_psi_and_heat.py +++ b/torax/tests/test_data/test_psi_and_heat.py @@ -18,9 +18,9 @@ pedestal, chi from qlknn. """ -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.config import runtime_params as general_runtime_params from torax.sources import default_sources from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib @@ -29,16 +29,16 @@ from torax.transport_model import qlknn_wrapper -def get_config() -> config_lib.Config: - return config_lib.Config( - profile_conditions=config_lib.ProfileConditions( +def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: + return general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( Ti_bound_left=8, Te_bound_left=8, # set flat Ohmic current to provide larger range of current evolution # for test nu=0, ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( current_eq=True, resistivity_mult=100, # to shorten current diffusion time t_final=2, @@ -46,8 +46,10 @@ def get_config() -> config_lib.Config: ) -def get_geometry(config: config_lib.Config) -> geometry.Geometry: - return geometry.build_circular_geometry(config) +def get_geometry( + runtime_params: general_runtime_params.GeneralRuntimeParams, +) -> geometry.Geometry: + return geometry.build_circular_geometry(runtime_params) def get_transport_model() -> qlknn_wrapper.QLKNNTransportModel: @@ -83,10 +85,10 @@ def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most # config taking place via constructor args in this function. - config = get_config() - geo = get_geometry(config) + runtime_params = get_runtime_params() + geo = get_geometry(runtime_params) return sim_lib.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, stepper_builder=get_stepper_builder(), source_models=get_sources(), diff --git a/torax/tests/test_data/test_psi_heat_dens.py b/torax/tests/test_data/test_psi_heat_dens.py index 9c9265a6..a1384418 100644 --- a/torax/tests/test_data/test_psi_heat_dens.py +++ b/torax/tests/test_data/test_psi_heat_dens.py @@ -17,9 +17,9 @@ Constant transport coefficient model. Pedestal """ -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.config import runtime_params as general_runtime_params from torax.sources import default_sources from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib @@ -28,16 +28,16 @@ from torax.transport_model import constant as constant_transport_model -def get_config() -> config_lib.Config: - return config_lib.Config( - profile_conditions=config_lib.ProfileConditions( +def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: + return general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( set_pedestal=True, nbar=0.85, # initial density (in Greenwald fraction units) # set flat Ohmic current to provide larger range of current evolution # for test nu=0, ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( ion_heat_eq=True, el_heat_eq=True, dens_eq=True, @@ -48,8 +48,10 @@ def get_config() -> config_lib.Config: ) -def get_geometry(config: config_lib.Config) -> geometry.Geometry: - return geometry.build_circular_geometry(config) +def get_geometry( + runtime_params: general_runtime_params.GeneralRuntimeParams, +) -> geometry.Geometry: + return geometry.build_circular_geometry(runtime_params) def get_transport_model() -> constant_transport_model.ConstantTransportModel: @@ -98,10 +100,10 @@ def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most # config taking place via constructor args in this function. - config = get_config() - geo = get_geometry(config) + runtime_params = get_runtime_params() + geo = get_geometry(runtime_params) return sim_lib.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, stepper_builder=get_stepper_builder(), source_models=get_sources(), diff --git a/torax/tests/test_data/test_psichease_ip_chease.py b/torax/tests/test_data/test_psichease_ip_chease.py index b02724dc..3b9b60f6 100644 --- a/torax/tests/test_data/test_psichease_ip_chease.py +++ b/torax/tests/test_data/test_psichease_ip_chease.py @@ -17,9 +17,9 @@ Ip from CHEASE. implicit, psi (current diffusion) only """ -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.config import runtime_params as general_runtime_params from torax.sources import default_sources from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib @@ -28,12 +28,12 @@ from torax.transport_model import constant as constant_transport_model -def get_config() -> config_lib.Config: - return config_lib.Config( - profile_conditions=config_lib.ProfileConditions( +def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: + return general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( set_pedestal=False, ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( ion_heat_eq=False, el_heat_eq=False, current_eq=True, @@ -43,9 +43,11 @@ def get_config() -> config_lib.Config: ) -def get_geometry(config: config_lib.Config) -> geometry.Geometry: +def get_geometry( + runtime_params: general_runtime_params.GeneralRuntimeParams, +) -> geometry.Geometry: return geometry.build_chease_geometry( - config, + runtime_params, geometry_file="ITER_hybrid_citrin_equil_cheasedata.mat2cols", Ip_from_parameters=False, ) @@ -95,10 +97,10 @@ def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most # config taking place via constructor args in this function. - config = get_config() - geo = get_geometry(config) + runtime_params = get_runtime_params() + geo = get_geometry(runtime_params) return sim_lib.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, stepper_builder=get_stepper_builder(), source_models=get_sources(), diff --git a/torax/tests/test_data/test_psichease_ip_parameters.py b/torax/tests/test_data/test_psichease_ip_parameters.py index 2c6e31a8..e5760cfc 100644 --- a/torax/tests/test_data/test_psichease_ip_parameters.py +++ b/torax/tests/test_data/test_psichease_ip_parameters.py @@ -17,9 +17,9 @@ Ip from parameters. implicit, psi (current diffusion) only """ -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.config import runtime_params as general_runtime_params from torax.sources import default_sources from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib @@ -28,12 +28,12 @@ from torax.transport_model import constant as constant_transport_model -def get_config() -> config_lib.Config: - return config_lib.Config( - profile_conditions=config_lib.ProfileConditions( +def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: + return general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( set_pedestal=False, ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( ion_heat_eq=False, el_heat_eq=False, current_eq=True, @@ -43,9 +43,11 @@ def get_config() -> config_lib.Config: ) -def get_geometry(config: config_lib.Config) -> geometry.Geometry: +def get_geometry( + runtime_params: general_runtime_params.GeneralRuntimeParams, +) -> geometry.Geometry: return geometry.build_chease_geometry( - config, + runtime_params, geometry_file="ITER_hybrid_citrin_equil_cheasedata.mat2cols", Ip_from_parameters=True, ) @@ -95,10 +97,10 @@ def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most # config taking place via constructor args in this function. - config = get_config() - geo = get_geometry(config) + runtime_params = get_runtime_params() + geo = get_geometry(runtime_params) return sim_lib.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, stepper_builder=get_stepper_builder(), source_models=get_sources(), diff --git a/torax/tests/test_data/test_psichease_prescribed_johm.py b/torax/tests/test_data/test_psichease_prescribed_johm.py index 61909ae4..0aaa2cb3 100644 --- a/torax/tests/test_data/test_psichease_prescribed_johm.py +++ b/torax/tests/test_data/test_psichease_prescribed_johm.py @@ -17,9 +17,9 @@ Ip from parameters. implicit, psi (current diffusion) only """ -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.config import runtime_params as general_runtime_params from torax.sources import default_sources from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib @@ -28,15 +28,15 @@ from torax.transport_model import constant as constant_transport_model -def get_config() -> config_lib.Config: - return config_lib.Config( - profile_conditions=config_lib.ProfileConditions( +def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: + return general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( set_pedestal=False, initial_psi_from_j=True, initial_j_is_total_current=False, nu=2, ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( ion_heat_eq=False, el_heat_eq=False, current_eq=True, @@ -46,9 +46,11 @@ def get_config() -> config_lib.Config: ) -def get_geometry(config: config_lib.Config) -> geometry.Geometry: +def get_geometry( + runtime_params: general_runtime_params.GeneralRuntimeParams, +) -> geometry.Geometry: return geometry.build_chease_geometry( - config, + runtime_params, geometry_file="ITER_hybrid_citrin_equil_cheasedata.mat2cols", Ip_from_parameters=True, ) @@ -98,10 +100,10 @@ def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most # config taking place via constructor args in this function. - config = get_config() - geo = get_geometry(config) + runtime_params = get_runtime_params() + geo = get_geometry(runtime_params) return sim_lib.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, stepper_builder=get_stepper_builder(), source_models=get_sources(), diff --git a/torax/tests/test_data/test_psichease_prescribed_jtot.py b/torax/tests/test_data/test_psichease_prescribed_jtot.py index 62f0061f..0606bcd7 100644 --- a/torax/tests/test_data/test_psichease_prescribed_jtot.py +++ b/torax/tests/test_data/test_psichease_prescribed_jtot.py @@ -17,9 +17,9 @@ Ip from parameters. implicit, psi (current diffusion) only """ -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.config import runtime_params as general_runtime_params from torax.sources import default_sources from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib @@ -28,15 +28,15 @@ from torax.transport_model import constant as constant_transport_model -def get_config() -> config_lib.Config: - return config_lib.Config( - profile_conditions=config_lib.ProfileConditions( +def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: + return general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( set_pedestal=False, initial_psi_from_j=True, initial_j_is_total_current=True, nu=2, ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( ion_heat_eq=False, el_heat_eq=False, current_eq=True, @@ -46,9 +46,11 @@ def get_config() -> config_lib.Config: ) -def get_geometry(config: config_lib.Config) -> geometry.Geometry: +def get_geometry( + runtime_params: general_runtime_params.GeneralRuntimeParams, +) -> geometry.Geometry: return geometry.build_chease_geometry( - config, + runtime_params, geometry_file="ITER_hybrid_citrin_equil_cheasedata.mat2cols", Ip_from_parameters=True, ) @@ -98,10 +100,10 @@ def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most # config taking place via constructor args in this function. - config = get_config() - geo = get_geometry(config) + runtime_params = get_runtime_params() + geo = get_geometry(runtime_params) return sim_lib.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, stepper_builder=get_stepper_builder(), source_models=get_sources(), diff --git a/torax/tests/test_data/test_psiequation.py b/torax/tests/test_data/test_psiequation.py index fcdadc8a..4fa55ea7 100644 --- a/torax/tests/test_data/test_psiequation.py +++ b/torax/tests/test_data/test_psiequation.py @@ -14,9 +14,9 @@ """Tests current diffusion implementation.""" -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.config import runtime_params as general_runtime_params from torax.sources import default_sources from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib @@ -25,15 +25,15 @@ from torax.transport_model import constant as constant_transport_model -def get_config() -> config_lib.Config: - return config_lib.Config( - profile_conditions=config_lib.ProfileConditions( +def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: + return general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( set_pedestal=False, # set flat Ohmic current to provide larger range of current evolution # for test nu=0, ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( ion_heat_eq=False, el_heat_eq=False, current_eq=True, @@ -43,8 +43,10 @@ def get_config() -> config_lib.Config: ) -def get_geometry(config: config_lib.Config) -> geometry.Geometry: - return geometry.build_circular_geometry(config) +def get_geometry( + runtime_params: general_runtime_params.GeneralRuntimeParams, +) -> geometry.Geometry: + return geometry.build_circular_geometry(runtime_params) def get_transport_model() -> constant_transport_model.ConstantTransportModel: @@ -81,10 +83,10 @@ def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most # config taking place via constructor args in this function. - config = get_config() - geo = get_geometry(config) + runtime_params = get_runtime_params() + geo = get_geometry(runtime_params) return sim_lib.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, stepper_builder=get_stepper_builder(), source_models=get_sources(), diff --git a/torax/tests/test_data/test_qei.py b/torax/tests/test_data/test_qei.py index 7c0d568a..65889a9a 100644 --- a/torax/tests/test_data/test_qei.py +++ b/torax/tests/test_data/test_qei.py @@ -17,9 +17,9 @@ Implicit, Ti+Te, Pei low dens, no pedestal, constant chi. """ -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.config import runtime_params as general_runtime_params from torax.sources import default_sources from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib @@ -28,20 +28,22 @@ from torax.transport_model import constant as constant_transport_model -def get_config() -> config_lib.Config: - return config_lib.Config( - profile_conditions=config_lib.ProfileConditions( +def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: + return general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( set_pedestal=False, nbar=0.5, # Initial density in Greenwald fraction units ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( t_final=1, ), ) -def get_geometry(config: config_lib.Config) -> geometry.Geometry: - return geometry.build_circular_geometry(config) +def get_geometry( + runtime_params: general_runtime_params.GeneralRuntimeParams, +) -> geometry.Geometry: + return geometry.build_circular_geometry(runtime_params) def get_transport_model() -> constant_transport_model.ConstantTransportModel: @@ -76,10 +78,10 @@ def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most # config taking place via constructor args in this function. - config = get_config() - geo = get_geometry(config) + runtime_params = get_runtime_params() + geo = get_geometry(runtime_params) return sim_lib.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, stepper_builder=get_stepper_builder(), source_models=get_sources(), diff --git a/torax/tests/test_data/test_qei_chease_highdens.py b/torax/tests/test_data/test_qei_chease_highdens.py index 086b67c3..3d24255e 100644 --- a/torax/tests/test_data/test_qei_chease_highdens.py +++ b/torax/tests/test_data/test_qei_chease_highdens.py @@ -17,9 +17,9 @@ Ip from parameters. implicit, Ti+Te, Pei high dens, no pedestal, constant chi """ -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.config import runtime_params as general_runtime_params from torax.sources import default_sources from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib @@ -28,21 +28,23 @@ from torax.transport_model import constant as constant_transport_model -def get_config() -> config_lib.Config: - return config_lib.Config( - profile_conditions=config_lib.ProfileConditions( +def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: + return general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( set_pedestal=False, nbar=1.0, # Initial density in Greenwald fraction units ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( t_final=1, ), ) -def get_geometry(config: config_lib.Config) -> geometry.Geometry: +def get_geometry( + runtime_params: general_runtime_params.GeneralRuntimeParams, +) -> geometry.Geometry: return geometry.build_chease_geometry( - config, + runtime_params, geometry_file="ITER_hybrid_citrin_equil_cheasedata.mat2cols", Ip_from_parameters=True, ) @@ -90,10 +92,10 @@ def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most # config taking place via constructor args in this function. - config = get_config() - geo = get_geometry(config) + runtime_params = get_runtime_params() + geo = get_geometry(runtime_params) return sim_lib.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, stepper_builder=get_stepper_builder(), source_models=get_sources(), diff --git a/torax/tests/test_data/test_qlknnheat.py b/torax/tests/test_data/test_qlknnheat.py index c5412c43..d140ae4f 100644 --- a/torax/tests/test_data/test_qlknnheat.py +++ b/torax/tests/test_data/test_qlknnheat.py @@ -18,9 +18,9 @@ QLKNN. """ -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.config import runtime_params as general_runtime_params from torax.sources import default_sources from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib @@ -29,16 +29,18 @@ from torax.transport_model import qlknn_wrapper -def get_config() -> config_lib.Config: - return config_lib.Config( - numerics=config_lib.Numerics( +def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: + return general_runtime_params.GeneralRuntimeParams( + numerics=general_runtime_params.Numerics( t_final=2, ), ) -def get_geometry(config: config_lib.Config) -> geometry.Geometry: - return geometry.build_circular_geometry(config) +def get_geometry( + runtime_params: general_runtime_params.GeneralRuntimeParams, +) -> geometry.Geometry: + return geometry.build_circular_geometry(runtime_params) def get_transport_model() -> qlknn_wrapper.QLKNNTransportModel: @@ -74,10 +76,10 @@ def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most # config taking place via constructor args in this function. - config = get_config() - geo = get_geometry(config) + runtime_params = get_runtime_params() + geo = get_geometry(runtime_params) return sim_lib.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, stepper_builder=get_stepper_builder(), source_models=get_sources(), diff --git a/torax/tests/test_data/test_semiimplicit_convection.py b/torax/tests/test_data/test_semiimplicit_convection.py index 0dc5a1d2..803a7003 100644 --- a/torax/tests/test_data/test_semiimplicit_convection.py +++ b/torax/tests/test_data/test_semiimplicit_convection.py @@ -19,9 +19,9 @@ Pei standard dens, chi from CGM. """ -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.config import runtime_params as general_runtime_params from torax.sources import default_sources from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib @@ -30,23 +30,25 @@ from torax.transport_model import critical_gradient as cgm_transport_model -def get_config() -> config_lib.Config: - return config_lib.Config( +def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: + return general_runtime_params.GeneralRuntimeParams( # This is test_cgm_heat but modified to not use the pedestal feature, # to exercise the convection term at the boundary. This causes FiPy to # explode. The time was reduced compared to test_cgm_heat to avoid test # time bottlenecks - profile_conditions=config_lib.ProfileConditions( + profile_conditions=general_runtime_params.ProfileConditions( set_pedestal=False, ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( t_final=0.5, ), ) -def get_geometry(config: config_lib.Config) -> geometry.Geometry: - return geometry.build_circular_geometry(config) +def get_geometry( + runtime_params: general_runtime_params.GeneralRuntimeParams, +) -> geometry.Geometry: + return geometry.build_circular_geometry(runtime_params) def get_transport_model() -> cgm_transport_model.CriticalGradientModel: @@ -86,10 +88,10 @@ def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most # config taking place via constructor args in this function. - config = get_config() - geo = get_geometry(config) + runtime_params = get_runtime_params() + geo = get_geometry(runtime_params) return sim_lib.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, stepper_builder=get_stepper_builder(), source_models=get_sources(), diff --git a/torax/tests/test_data/test_timedependence.py b/torax/tests/test_data/test_timedependence.py index 2af0c8b7..37a427c5 100644 --- a/torax/tests/test_data/test_timedependence.py +++ b/torax/tests/test_data/test_timedependence.py @@ -19,9 +19,9 @@ pedestal, mocking up current-overshoot and an LH transition """ -from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.config import runtime_params as general_runtime_params from torax.sources import default_sources from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib @@ -30,16 +30,16 @@ from torax.transport_model import qlknn_wrapper -def get_config() -> config_lib.Config: - return config_lib.Config( - profile_conditions=config_lib.ProfileConditions( +def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: + return general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( Ti_bound_left=10, Te_bound_left=10, Ip={0: 5, 4: 15, 6: 12, 8: 12}, Tiped={0: 2, 4: 2, 6: 5, 8: 4}, Teped={0: 2, 4: 2, 6: 5, 8: 4}, ), - numerics=config_lib.Numerics( + numerics=general_runtime_params.Numerics( current_eq=True, resistivity_mult=50, # to shorten current diffusion time for the test dtmult=150, @@ -50,9 +50,11 @@ def get_config() -> config_lib.Config: ) -def get_geometry(config: config_lib.Config) -> geometry.Geometry: +def get_geometry( + runtime_params: general_runtime_params.GeneralRuntimeParams, +) -> geometry.Geometry: return geometry.build_chease_geometry( - config, + runtime_params, geometry_file="ITER_hybrid_citrin_equil_cheasedata.mat2cols", Ip_from_parameters=True, ) @@ -114,10 +116,10 @@ def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most # config taking place via constructor args in this function. - config = get_config() - geo = get_geometry(config) + runtime_params = get_runtime_params() + geo = get_geometry(runtime_params) return sim_lib.build_sim_from_config( - config=config, + runtime_params=runtime_params, geo=geo, stepper_builder=get_stepper_builder(), source_models=get_sources(), diff --git a/torax/tests/test_lib/explicit_stepper.py b/torax/tests/test_lib/explicit_stepper.py index 4b245feb..30c949cc 100644 --- a/torax/tests/test_lib/explicit_stepper.py +++ b/torax/tests/test_lib/explicit_stepper.py @@ -23,13 +23,13 @@ import jax from jax import numpy as jnp -from torax import config_slice from torax import constants from torax import core_profile_setters from torax import fvm from torax import geometry from torax import physics from torax import state +from torax.config import runtime_params_slice from torax.sources import source_models from torax.sources import source_profiles from torax.stepper import stepper as stepper_lib @@ -51,9 +51,9 @@ class ExplicitStepper(stepper_lib.Stepper): def __call__( self, dt: jax.Array, - static_config_slice: config_slice.StaticConfigSlice, - dynamic_config_slice_t: config_slice.DynamicConfigSlice, - dynamic_config_slice_t_plus_dt: config_slice.DynamicConfigSlice, + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice_t: runtime_params_slice.DynamicRuntimeParamsSlice, + dynamic_runtime_params_slice_t_plus_dt: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, core_profiles_t: state.CoreProfiles, core_profiles_t_plus_dt: state.CoreProfiles, @@ -73,14 +73,14 @@ def __call__( # The explicit method is for testing purposes and # only implemented for ion heat. # Ensure that this is what the user requested. - assert static_config_slice.ion_heat_eq - assert not static_config_slice.el_heat_eq - assert not static_config_slice.dens_eq - assert not static_config_slice.current_eq + assert static_runtime_params_slice.ion_heat_eq + assert not static_runtime_params_slice.el_heat_eq + assert not static_runtime_params_slice.dens_eq + assert not static_runtime_params_slice.current_eq consts = constants.CONSTANTS - nref = dynamic_config_slice_t.numerics.nref + nref = dynamic_runtime_params_slice_t.numerics.nref true_ni = core_profiles_t.ni.value * nref true_ni_face = core_profiles_t.ni.face_value() * nref @@ -90,14 +90,14 @@ def __call__( # Diffusion term coefficient assert isinstance( - dynamic_config_slice_t.transport, + dynamic_runtime_params_slice_t.transport, constant_transport_model.DynamicRuntimeParams, ) d_face_ion = ( geo.g1_over_vpr_face * true_ni_face * consts.keV2J - * dynamic_config_slice_t.transport.chii_const + * dynamic_runtime_params_slice_t.transport.chii_const / geo.rmax**2 ) @@ -119,7 +119,7 @@ def __call__( # Update the potentially time-dependent boundary conditions as well. updated_boundary_conditions = ( core_profile_setters.compute_boundary_conditions( - dynamic_config_slice_t_plus_dt, + dynamic_runtime_params_slice_t_plus_dt, geo, ) ) @@ -133,7 +133,7 @@ def __call__( geo=geo, psi=core_profiles_t.psi, jtot_face=core_profiles_t.currents.jtot, - q_correction_factor=dynamic_config_slice_t.numerics.q_correction_factor, + q_correction_factor=dynamic_runtime_params_slice_t.numerics.q_correction_factor, ) s_face = physics.calc_s_from_psi(geo, core_profiles_t.psi) diff --git a/torax/tests/test_lib/sim_test_case.py b/torax/tests/test_lib/sim_test_case.py index 2d992e0b..b4c21662 100644 --- a/torax/tests/test_lib/sim_test_case.py +++ b/torax/tests/test_lib/sim_test_case.py @@ -25,11 +25,11 @@ import jax.numpy as jnp import numpy as np import torax -from torax import config as config_lib -from torax import config_slice from torax import geometry from torax import sim as sim_lib from torax import state as state_lib +from torax.config import runtime_params as general_runtime_params +from torax.config import runtime_params_slice from torax.sources import source_models as source_models_lib from torax.stepper import nonlinear_theta_method from torax.stepper import stepper as stepper_lib @@ -79,10 +79,10 @@ def _get_config_module( def _get_config( self, config_name: str, - ) -> config_lib.Config: + ) -> general_runtime_params.GeneralRuntimeParams: """Returns an input Config from the name given.""" config_module = self._get_config_module(config_name) - return config_module.get_config() + return config_module.get_runtime_params() def _get_geometry( self, @@ -90,8 +90,8 @@ def _get_geometry( ) -> geometry.Geometry: """Returns an input Config from the name given.""" config_module = self._get_config_module(config_name) - config = config_module.get_config() - return config_module.get_geometry(config) + runtime_params = config_module.get_runtime_params() + return config_module.get_geometry(runtime_params) def _get_sim(self, config_name: str) -> sim_lib.Sim: """Returns a Sim given the name of a py file to build it.""" @@ -254,8 +254,8 @@ def _test_torax_sim( time_step_calculator=time_step_calculator, initial_state=sim.initial_state, geometry_provider=sim.geometry_provider, - dynamic_config_slice_provider=sim.dynamic_config_slice_provider, - static_config_slice=sim.static_config_slice, + dynamic_runtime_params_slice_provider=sim.dynamic_runtime_params_slice_provider, + static_runtime_params_slice=sim.static_runtime_params_slice, stepper=sim.stepper, transport_model=sim.transport_model, step_fn=sim.step_fn, @@ -278,7 +278,7 @@ def _test_torax_sim( def make_frozen_optimizer_stepper( transport_model: transport_model_lib.TransportModel, source_models: source_models_lib.SourceModels, - config: config_lib.Config, + runtime_params: general_runtime_params.GeneralRuntimeParams, transport_params: transport_params_lib.RuntimeParams, ) -> stepper_lib.Stepper: """Makes an optimizer stepper with frozen coefficients. @@ -290,21 +290,23 @@ def make_frozen_optimizer_stepper( transport_model: Transport model. source_models: TORAX sources/sinks used to compute profile terms in the state evolution equations. - config: General TORAX config. + runtime_params: General TORAX runtime input parameters. transport_params: Runtime params for the transport model. Returns: Stepper: the stepper. """ - # Get the dynamic config for the start of the simulation. - dynamic_config_slice = config_slice.build_dynamic_config_slice( - config=config, - transport=transport_params, - sources=source_models.runtime_params, + # Get the dynamic runtime params for the start of the simulation. + dynamic_runtime_params_slice = ( + runtime_params_slice.build_dynamic_runtime_params_slice( + runtime_params=runtime_params, + transport=transport_params, + sources=source_models.runtime_params, + ) ) callback_builder = functools.partial( sim_lib.FrozenCoeffsCallback, - dynamic_config_slice=dynamic_config_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, ) return nonlinear_theta_method.OptimizerThetaMethod( transport_model, @@ -316,7 +318,7 @@ def make_frozen_optimizer_stepper( def make_frozen_newton_raphson_stepper( transport_model: transport_model_lib.TransportModel, source_models: source_models_lib.SourceModels, - config: config_lib.Config, + runtime_params: general_runtime_params.GeneralRuntimeParams, ) -> stepper_lib.Stepper: """Makes a Newton Raphson stepper with frozen coefficients. @@ -328,18 +330,20 @@ def make_frozen_newton_raphson_stepper( transport_model: Transport model. source_models: TORAX sources/sinks used to compute profile terms in the state evolution equations. - config: General TORAX config. + runtime_params: General TORAX runtime input parameters. Returns: Stepper: the stepper. """ - # Get the dynamic config for the start of the simulation. - dynamic_config_slice = config_slice.build_dynamic_config_slice(config) + # Get the dynamic runtime params for the start of the simulation. + dynamic_runtime_params_slice = ( + runtime_params_slice.build_dynamic_runtime_params_slice(runtime_params) + ) callback_builder = functools.partial( sim_lib.FrozenCoeffsCallback, - dynamic_config_slice=dynamic_config_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, ) - functools.partial(sim_lib.FrozenCoeffsCallback, config=config) + functools.partial(sim_lib.FrozenCoeffsCallback, runtime_params=runtime_params) return nonlinear_theta_method.NewtonRaphsonThetaMethod( transport_model, source_models=source_models, diff --git a/torax/tests/test_lib/torax_refs.py b/torax/tests/test_lib/torax_refs.py index a08a1710..b225c763 100644 --- a/torax/tests/test_lib/torax_refs.py +++ b/torax/tests/test_lib/torax_refs.py @@ -21,12 +21,13 @@ from jax import numpy as jnp import numpy as np import torax +from torax.config import config_args +from torax.config import runtime_params as general_runtime_params _GEO_DIRECTORY = 'torax/data/third_party/geo' # It's best to import the parent `torax` package because that has the # __init__ file that configures jax to float64 -config_lib = torax.config fvm = torax.fvm geometry = torax.geometry @@ -35,7 +36,7 @@ class References: """Collection of reference values useful for unit tests.""" - config: config_lib.Config + runtime_params: general_runtime_params.GeneralRuntimeParams geo: geometry.Geometry psi: fvm.CellVariable psi_face_grad: np.ndarray @@ -47,9 +48,9 @@ def circular_references() -> References: """Reference values for circular geometry.""" # Hard-code the parameters relevant to the tests, so the reference values # will stay valid even if we change the Config constructor defaults - config = config_lib.Config() - config = config_lib.recursive_replace( - config, + runtime_params = general_runtime_params.GeneralRuntimeParams() + runtime_params = config_args.recursive_replace( + runtime_params, **{ 'profile_conditions': { 'Ip': 15, @@ -62,7 +63,7 @@ def circular_references() -> References: }, ) geo = geometry.build_circular_geometry( - config=config, + runtime_params=runtime_params, kappa=1.72, hires_fac=4, Rmaj=6.2, @@ -188,7 +189,7 @@ def circular_references() -> References: 2.3139498913449468, ]) return References( - config=config, + runtime_params=runtime_params, geo=geo, psi=psi, psi_face_grad=psi_face_grad, @@ -199,9 +200,9 @@ def circular_references() -> References: def chease_references_Ip_from_chease() -> References: # pylint: disable=invalid-name """Reference values for CHEASE geometry where the Ip comes from the file.""" - config = config_lib.Config() - config = config_lib.recursive_replace( - config, + runtime_params = general_runtime_params.GeneralRuntimeParams() + runtime_params = config_args.recursive_replace( + runtime_params, **{ 'profile_conditions': { 'Ip': 15, @@ -214,7 +215,7 @@ def chease_references_Ip_from_chease() -> References: # pylint: disable=invalid }, ) geo = geometry.build_chease_geometry( - config=config, + runtime_params=runtime_params, geometry_dir=_GEO_DIRECTORY, geometry_file='ITER_hybrid_citrin_equil_cheasedata.mat2cols', Ip_from_parameters=False, @@ -341,7 +342,7 @@ def chease_references_Ip_from_chease() -> References: # pylint: disable=invalid 1.8139935034456611, ]) return References( - config=config, + runtime_params=runtime_params, geo=geo, psi=psi, psi_face_grad=psi_face_grad, @@ -350,11 +351,11 @@ def chease_references_Ip_from_chease() -> References: # pylint: disable=invalid ) -def chease_references_Ip_from_config() -> References: # pylint: disable=invalid-name +def chease_references_Ip_from_runtime_params() -> References: # pylint: disable=invalid-name """Reference values for CHEASE geometry where the Ip comes from the config.""" - config = config_lib.Config() - config = config_lib.recursive_replace( - config, + runtime_params = general_runtime_params.GeneralRuntimeParams() + runtime_params = config_args.recursive_replace( + runtime_params, **{ 'profile_conditions': { 'Ip': 15, @@ -367,7 +368,7 @@ def chease_references_Ip_from_config() -> References: # pylint: disable=invalid }, ) geo = geometry.build_chease_geometry( - config=config, + runtime_params=runtime_params, geometry_dir=_GEO_DIRECTORY, geometry_file='ITER_hybrid_citrin_equil_cheasedata.mat2cols', Ip_from_parameters=True, @@ -494,7 +495,7 @@ def chease_references_Ip_from_config() -> References: # pylint: disable=invalid 1.813993503445698, ]) return References( - config=config, + runtime_params=runtime_params, geo=geo, psi=psi, psi_face_grad=psi_face_grad, @@ -516,7 +517,7 @@ def setUp(self): chease_references_Ip_from_chease() ) self.chease_references_with_Ip_from_config = ( - chease_references_Ip_from_config() + chease_references_Ip_from_runtime_params() ) # pylint: enable=invalid-name diff --git a/torax/time_step_calculator/array_time_step_calculator.py b/torax/time_step_calculator/array_time_step_calculator.py index ce2cd8d5..0d631511 100644 --- a/torax/time_step_calculator/array_time_step_calculator.py +++ b/torax/time_step_calculator/array_time_step_calculator.py @@ -21,9 +21,9 @@ import jax from jax import numpy as jnp -from torax import config_slice from torax import geometry from torax import state as state_module +from torax.config import runtime_params_slice from torax.time_step_calculator import time_step_calculator State = int @@ -42,24 +42,32 @@ def initial_state(self) -> State: def not_done( self, t: Union[float, jax.Array], - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, state: State, ) -> Union[jax.Array, bool]: """Returns True until the whole array has been visited, then False.""" - del t, dynamic_config_slice # Unused for this type of TimeStepCalculator. + del ( + t, + dynamic_runtime_params_slice, + ) # Unused for this type of TimeStepCalculator. idx = state return idx < self.arr.shape[0] - 1 def next_dt( self, - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, core_profiles: state_module.CoreProfiles, time_step_calculator_state: State, core_transport: state_module.CoreTransport, ) -> tuple[jax.Array, State]: """Returns the next diff between consecutive array entries.""" - del dynamic_config_slice, geo, core_profiles, core_transport # Unused. + del ( + dynamic_runtime_params_slice, + geo, + core_profiles, + core_transport, + ) # Unused. idx = time_step_calculator_state idx += 1 return self.arr[idx] - self.arr[idx - 1], idx diff --git a/torax/time_step_calculator/chi_time_step_calculator.py b/torax/time_step_calculator/chi_time_step_calculator.py index 5b188e7d..40e85e05 100644 --- a/torax/time_step_calculator/chi_time_step_calculator.py +++ b/torax/time_step_calculator/chi_time_step_calculator.py @@ -22,10 +22,10 @@ import jax from jax import numpy as jnp -from torax import config_slice from torax import geometry from torax import jax_utils from torax import state as state_module +from torax.config import runtime_params_slice from torax.time_step_calculator import time_step_calculator # Dummy state and type for compatibility with time_step_calculator base class @@ -48,16 +48,16 @@ def initial_state(self): def not_done( self, t: Union[float, jax.Array], - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, state: State, ) -> Union[bool, jax.Array]: - """Returns True if iteration not done (t < config.numerics.t_final).""" - return t < dynamic_config_slice.numerics.t_final + """Returns True if iteration not done (t < runtime_params.numerics.t_final).""" + return t < dynamic_runtime_params_slice.numerics.t_final @functools.partial(jax_utils.jit, static_argnames=['self']) def next_dt( self, - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, core_profiles: state_module.CoreProfiles, time_step_calculator_state: State, @@ -69,8 +69,8 @@ def next_dt( size for the explicit method, and is therefore a function of chi_max. Args: - dynamic_config_slice: Input config parameters that can change without - triggering a JAX recompilation. + dynamic_runtime_params_slice: Input runtime parameters that can change + without triggering a JAX recompilation. geo: Geometry for the tokamak being simulated. core_profiles: Current core plasma profiles. time_step_calculator_state: None, for compatibility with @@ -86,8 +86,8 @@ def next_dt( basic_dt = (3.0 / 4.0) * (geo.dr_norm**2) / chi_max * geo.rmax**2 dt = jnp.minimum( - dynamic_config_slice.numerics.dtmult * basic_dt, - dynamic_config_slice.numerics.maxdt, + dynamic_runtime_params_slice.numerics.dtmult * basic_dt, + dynamic_runtime_params_slice.numerics.maxdt, ) return dt, STATE diff --git a/torax/time_step_calculator/fixed_time_step_calculator.py b/torax/time_step_calculator/fixed_time_step_calculator.py index ea0f0d91..eb0cc649 100644 --- a/torax/time_step_calculator/fixed_time_step_calculator.py +++ b/torax/time_step_calculator/fixed_time_step_calculator.py @@ -21,9 +21,9 @@ import jax from jax import numpy as jnp -from torax import config_slice from torax import geometry from torax import state as state_module +from torax.config import runtime_params_slice from torax.time_step_calculator import time_step_calculator # Dummy state and type for compatibility with time_step_calculator base class @@ -44,15 +44,15 @@ def initial_state(self): def not_done( self, t: Union[float, jax.Array], - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, state: State, ) -> Union[bool, jax.Array]: - """Returns True if iteration not done (t < config.numerics.t_final).""" - return t < dynamic_config_slice.numerics.t_final + """Returns True if iteration not done (t < runtime_params.numerics.t_final).""" + return t < dynamic_runtime_params_slice.numerics.t_final def next_dt( self, - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, core_profiles: state_module.CoreProfiles, time_step_calculator_state: State, @@ -61,8 +61,8 @@ def next_dt( """Calculates the next time step duration. Args: - dynamic_config_slice: Input config parameters that can change without - triggering a JAX recompilation. + dynamic_runtime_params_slice: Input runtime parameters that can change + without triggering a JAX recompilation. geo: Geometry for the tokamak being simulated. core_profiles: Current core plasma profiles. time_step_calculator_state: None, for compatibility with @@ -73,6 +73,6 @@ def next_dt( dt: Scalar time step duration. """ - dt = jnp.array(dynamic_config_slice.numerics.fixed_dt) + dt = jnp.array(dynamic_runtime_params_slice.numerics.fixed_dt) return dt, STATE diff --git a/torax/time_step_calculator/time_step_calculator.py b/torax/time_step_calculator/time_step_calculator.py index 38e27694..588d40de 100644 --- a/torax/time_step_calculator/time_step_calculator.py +++ b/torax/time_step_calculator/time_step_calculator.py @@ -22,9 +22,9 @@ import jax from jax import numpy as jnp -from torax import config_slice from torax import geometry from torax import state as state_module +from torax.config import runtime_params_slice # Subclasses override with their own state type State = TypeVar('State') @@ -52,7 +52,7 @@ def initial_state(self) -> State: def not_done( self, t: Union[float, jax.Array], - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, state: State, ) -> Union[bool, jax.Array]: """If True, next_dt may be called again.""" @@ -60,7 +60,7 @@ def not_done( @abc.abstractmethod def next_dt( self, - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, core_profiles: state_module.CoreProfiles, time_step_calculator_state: State, @@ -69,8 +69,8 @@ def next_dt( """Returns the next time step duration and internal time stepper state. Args: - dynamic_config_slice: Input config parameters that can change without - triggering a JAX recompilation. + dynamic_runtime_params_slice: Input runtime parameters that can change + without triggering a JAX recompilation. geo: Geometry for the Tokamak. core_profiles: Core plasma profiles in the tokamak. time_step_calculator_state: Internal state of the time stepper. diff --git a/torax/transport_model/constant.py b/torax/transport_model/constant.py index f2351f5b..4a279b05 100644 --- a/torax/transport_model/constant.py +++ b/torax/transport_model/constant.py @@ -21,11 +21,11 @@ import chex from jax import numpy as jnp -from torax import config_slice from torax import geometry from torax import jax_utils from torax import state -from torax.runtime_params import config_slice_args +from torax.config import config_args +from torax.config import runtime_params_slice from torax.transport_model import runtime_params as runtime_params_lib from torax.transport_model import transport_model @@ -49,7 +49,7 @@ class RuntimeParams(runtime_params_lib.RuntimeParams): def build_dynamic_params(self, t: chex.Numeric) -> DynamicRuntimeParams: return DynamicRuntimeParams( - **config_slice_args.get_init_kwargs( + **config_args.get_init_kwargs( input_config=self, output_type=DynamicRuntimeParams, t=t, @@ -90,7 +90,7 @@ def sanity_check(self): def build_dynamic_params(self, t: chex.Numeric) -> DynamicRuntimeParams: return DynamicRuntimeParams( - **config_slice_args.get_init_kwargs( + **config_args.get_init_kwargs( input_config=self, output_type=DynamicRuntimeParams, t=t, @@ -117,30 +117,35 @@ def runtime_params(self, runtime_params: RuntimeParams) -> None: def _call_implementation( self, - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, core_profiles: state.CoreProfiles, ) -> state.CoreTransport: del core_profiles # Not needed for this transport model - assert isinstance(dynamic_config_slice.transport, DynamicRuntimeParams) + assert isinstance( + dynamic_runtime_params_slice.transport, DynamicRuntimeParams + ) - chi_face_ion = dynamic_config_slice.transport.chii_const * jnp.ones_like( - geo.r_face + chi_face_ion = ( + dynamic_runtime_params_slice.transport.chii_const + * jnp.ones_like(geo.r_face) ) - chi_face_el = dynamic_config_slice.transport.chie_const * jnp.ones_like( - geo.r_face + chi_face_el = ( + dynamic_runtime_params_slice.transport.chie_const + * jnp.ones_like(geo.r_face) ) - d_face_el = dynamic_config_slice.transport.De_const * jnp.ones_like( + d_face_el = dynamic_runtime_params_slice.transport.De_const * jnp.ones_like( geo.r_face ) v_face_el = jnp.where( jnp.logical_and( - dynamic_config_slice.profile_conditions.set_pedestal, - geo.r_face_norm > dynamic_config_slice.profile_conditions.Ped_top, + dynamic_runtime_params_slice.profile_conditions.set_pedestal, + geo.r_face_norm + > dynamic_runtime_params_slice.profile_conditions.Ped_top, ), 0, - dynamic_config_slice.transport.Ve_const, + dynamic_runtime_params_slice.transport.Ve_const, ) return state.CoreTransport( diff --git a/torax/transport_model/critical_gradient.py b/torax/transport_model/critical_gradient.py index 35cec25a..066b6143 100644 --- a/torax/transport_model/critical_gradient.py +++ b/torax/transport_model/critical_gradient.py @@ -18,12 +18,12 @@ import chex from jax import numpy as jnp -from torax import config_slice from torax import constants as constants_module from torax import geometry from torax import jax_utils from torax import state -from torax.runtime_params import config_slice_args +from torax.config import config_args +from torax.config import runtime_params_slice from torax.transport_model import runtime_params as runtime_params_lib from torax.transport_model import transport_model @@ -46,7 +46,7 @@ class RuntimeParams(runtime_params_lib.RuntimeParams): def build_dynamic_params(self, t: chex.Numeric) -> DynamicRuntimeParams: return DynamicRuntimeParams( - **config_slice_args.get_init_kwargs( + **config_args.get_init_kwargs( input_config=self, output_type=DynamicRuntimeParams, t=t, @@ -96,15 +96,15 @@ def runtime_params(self, runtime_params: RuntimeParams) -> None: def _call_implementation( self, - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, core_profiles: state.CoreProfiles, ) -> state.CoreTransport: """Calculates transport coefficients using the Critical Gradient Model. Args: - dynamic_config_slice: Input config parameters that can change without - triggering a JAX recompilation. + dynamic_runtime_params_slice: Input runtime parameters that can change + without triggering a JAX recompilation. geo: Geometry of the torus. core_profiles: Core plasma profiles. @@ -121,7 +121,9 @@ def _call_implementation( # R/LTi_crit)*(R/LTi - R/LTi_crit)^alpha constants = constants_module.CONSTANTS - assert isinstance(dynamic_config_slice.transport, DynamicRuntimeParams) + assert isinstance( + dynamic_runtime_params_slice.transport, DynamicRuntimeParams + ) # set typical values for now. Will include user-defined q and s later s = core_profiles.s_face @@ -149,7 +151,8 @@ def _call_implementation( # gyrobohm diffusivity chiGB = ( - (dynamic_config_slice.plasma_composition.Ai * constants.mp) ** 0.5 + (dynamic_runtime_params_slice.plasma_composition.Ai * constants.mp) + ** 0.5 / (constants.qe * geo.B0) ** 2 * (temp_ion_face * constants.keV2J) ** 1.5 / geo.Rmaj @@ -159,7 +162,7 @@ def _call_implementation( rlti = -geo.Rmaj * temp_ion_face_grad / temp_ion_face # set minimum chi for PDE stability - chi_ion = dynamic_config_slice.transport.chimin * jnp.ones_like( + chi_ion = dynamic_runtime_params_slice.transport.chimin * jnp.ones_like( geo.mesh.face_centers ) @@ -167,16 +170,16 @@ def _call_implementation( chi_ion = jnp.where( rlti >= rlti_crit, chiGB - * dynamic_config_slice.transport.CGMchistiff - * (rlti - rlti_crit) ** dynamic_config_slice.transport.CGMalpha, + * dynamic_runtime_params_slice.transport.CGMchistiff + * (rlti - rlti_crit) ** dynamic_runtime_params_slice.transport.CGMalpha, chi_ion, ) # set (high) ceiling to CGM flux for PDE stability # (might not be necessary with Perezerev) chi_ion = jnp.where( - chi_ion > dynamic_config_slice.transport.chimax, - dynamic_config_slice.transport.chimax, + chi_ion > dynamic_runtime_params_slice.transport.chimax, + dynamic_runtime_params_slice.transport.chimax, chi_ion, ) @@ -184,18 +187,23 @@ def _call_implementation( # (more consistency between desired profile and transport coefficients) chi_face_ion = jnp.where( jnp.logical_and( - dynamic_config_slice.profile_conditions.set_pedestal, - geo.r_face_norm >= dynamic_config_slice.profile_conditions.Ped_top, + dynamic_runtime_params_slice.profile_conditions.set_pedestal, + geo.r_face_norm + >= dynamic_runtime_params_slice.profile_conditions.Ped_top, ), - dynamic_config_slice.transport.chimin, + dynamic_runtime_params_slice.transport.chimin, chi_ion, ) # set electron heat transport coefficient to user-defined ratio of ion heat # transport coefficient - chi_face_el = chi_face_ion / dynamic_config_slice.transport.CGMchiei_ratio + chi_face_el = ( + chi_face_ion / dynamic_runtime_params_slice.transport.CGMchiei_ratio + ) - d_face_el = chi_face_ion / dynamic_config_slice.transport.CGM_D_ratio + d_face_el = ( + chi_face_ion / dynamic_runtime_params_slice.transport.CGM_D_ratio + ) # No convection in this critical gradient model. # (Not a realistic model for particle transport anyway). diff --git a/torax/transport_model/qlknn_wrapper.py b/torax/transport_model/qlknn_wrapper.py index 331a950d..5d0aa783 100644 --- a/torax/transport_model/qlknn_wrapper.py +++ b/torax/transport_model/qlknn_wrapper.py @@ -29,13 +29,13 @@ import chex import jax from jax import numpy as jnp -from torax import config_slice from torax import constants as constants_module from torax import geometry from torax import jax_utils from torax import physics from torax import state -from torax.runtime_params import config_slice_args +from torax.config import config_args +from torax.config import runtime_params_slice from torax.transport_model import base_qlknn_model from torax.transport_model import qlknn_10d from torax.transport_model import runtime_params as runtime_params_lib @@ -74,7 +74,7 @@ class RuntimeParams(runtime_params_lib.RuntimeParams): def build_dynamic_params(self, t: chex.Numeric) -> DynamicRuntimeParams: return DynamicRuntimeParams( - **config_slice_args.get_init_kwargs( + **config_args.get_init_kwargs( input_config=self, output_type=DynamicRuntimeParams, t=t, @@ -150,9 +150,9 @@ def _get_model() -> base_qlknn_model.BaseQLKNNModel: class _QLKNNRuntimeConfigInputs: """Runtime config inputs for QLKNN. - The runtime DynamicConfigSlice contains global config parameters, not all of - which are cacheable. This set of inputs IS cacheable, and using this added - layer allows the global config to change without affecting how + The runtime DynamicRuntimeParamsSlice contains global runtime parameters, not + all of which are cacheable. This set of inputs IS cacheable, and using this + added layer allows the global config to change without affecting how QLKNNTransportModel works. """ @@ -167,18 +167,20 @@ class _QLKNNRuntimeConfigInputs: # pylint: enable=invalid-name @staticmethod - def from_config_slice( - dynamic_config_slice: config_slice.DynamicConfigSlice, + def from_runtime_params_slice( + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, ) -> '_QLKNNRuntimeConfigInputs': - assert isinstance(dynamic_config_slice.transport, DynamicRuntimeParams) + assert isinstance( + dynamic_runtime_params_slice.transport, DynamicRuntimeParams + ) return _QLKNNRuntimeConfigInputs( - nref=dynamic_config_slice.numerics.nref, - Ai=dynamic_config_slice.plasma_composition.Ai, - Zeff=dynamic_config_slice.plasma_composition.Zeff, - transport=dynamic_config_slice.transport, - Ped_top=dynamic_config_slice.profile_conditions.Ped_top, - set_pedestal=dynamic_config_slice.profile_conditions.set_pedestal, - q_correction_factor=dynamic_config_slice.numerics.q_correction_factor, + nref=dynamic_runtime_params_slice.numerics.nref, + Ai=dynamic_runtime_params_slice.plasma_composition.Ai, + Zeff=dynamic_runtime_params_slice.plasma_composition.Zeff, + transport=dynamic_runtime_params_slice.transport, + Ped_top=dynamic_runtime_params_slice.profile_conditions.Ped_top, + set_pedestal=dynamic_runtime_params_slice.profile_conditions.set_pedestal, + q_correction_factor=dynamic_runtime_params_slice.numerics.q_correction_factor, ) @@ -230,15 +232,15 @@ def runtime_params(self, runtime_params: RuntimeParams) -> None: def _call_implementation( self, - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, core_profiles: state.CoreProfiles, ) -> state.CoreTransport: """Calculates several transport coefficients simultaneously. Args: - dynamic_config_slice: Input config parameters that can change without - triggering a JAX recompilation. + dynamic_runtime_params_slice: Input runtime parameters that can change + without triggering a JAX recompilation. geo: Geometry of the torus. core_profiles: Core plasma profiles. @@ -258,8 +260,8 @@ def _call_implementation( # hashable. We assume that either we're running a whole sim in uncompiled # mode and everything is concrete or we're running a whole sim in compiled # mode and everything is a tracer, so we can just test one value. - runtime_config_inputs = _QLKNNRuntimeConfigInputs.from_config_slice( - dynamic_config_slice + runtime_config_inputs = _QLKNNRuntimeConfigInputs.from_runtime_params_slice( + dynamic_runtime_params_slice ) try: return self._cached_combined(runtime_config_inputs, geo, core_profiles) @@ -286,7 +288,7 @@ def _combined( `__call__` itself is just a cache dispatch wrapper. Args: - runtime_config_inputs: Input config parameters that can change without + runtime_config_inputs: Input runtime parameters that can change without triggering a JAX recompilation. geo: Geometry of the torus. core_profiles: Core plasma profiles. @@ -557,7 +559,7 @@ def Dscaled_approach() -> tuple[jnp.ndarray, jnp.ndarray]: # set low transport in pedestal region to facilitate PDE solver # (more consistency between desired profile and transport coefficients) - # if config.profile_conditions.set_pedestal: + # if runtime_params.profile_conditions.set_pedestal: mask = geo.r_face_norm >= runtime_config_inputs.Ped_top chi_face_ion = jnp.where( jnp.logical_and(runtime_config_inputs.set_pedestal, mask), diff --git a/torax/transport_model/runtime_params.py b/torax/transport_model/runtime_params.py index e2bc80f2..dc10eaae 100644 --- a/torax/transport_model/runtime_params.py +++ b/torax/transport_model/runtime_params.py @@ -23,7 +23,7 @@ import chex from torax import interpolated_param from torax import jax_utils -from torax.runtime_params import config_slice_args +from torax.config import config_args # Type-alias for clarity. While the InterpolatedParams can vary across any @@ -73,7 +73,7 @@ class RuntimeParams: def build_dynamic_params(self, t: chex.Numeric) -> DynamicRuntimeParams: return DynamicRuntimeParams( - **config_slice_args.get_init_kwargs( + **config_args.get_init_kwargs( input_config=self, output_type=DynamicRuntimeParams, t=t, diff --git a/torax/transport_model/tests/qlknn_wrapper.py b/torax/transport_model/tests/qlknn_wrapper.py index 45a51a81..8c716d35 100644 --- a/torax/transport_model/tests/qlknn_wrapper.py +++ b/torax/transport_model/tests/qlknn_wrapper.py @@ -17,10 +17,10 @@ from absl.testing import absltest from absl.testing import parameterized import jax -from torax import config as config_lib -from torax import config_slice from torax import core_profile_setters from torax import geometry +from torax.config import runtime_params as general_runtime_params +from torax.config import runtime_params_slice from torax.sources import source_models as source_models_lib from torax.transport_model import qlknn_wrapper @@ -35,22 +35,26 @@ def test_qlknn_wrapper_cache_works(self): qlknn = qlknn_wrapper.QLKNNTransportModel() # Caching only works when compiled. qlknn_jitted = jax.jit(qlknn) - config = config_lib.Config() - geo = geometry.build_circular_geometry(config) + runtime_params = general_runtime_params.GeneralRuntimeParams() + geo = geometry.build_circular_geometry(runtime_params) source_models = source_models_lib.SourceModels() - dynamic_config_slice = config_slice.build_dynamic_config_slice( - config=config, - transport=qlknn.runtime_params, - sources=source_models.runtime_params, + dynamic_runtime_params_slice = ( + runtime_params_slice.build_dynamic_runtime_params_slice( + runtime_params=runtime_params, + transport=qlknn.runtime_params, + sources=source_models.runtime_params, + ) + ) + static_runtime_params_slice = ( + runtime_params_slice.build_static_runtime_params_slice(runtime_params) ) - static_config_slice = config_slice.build_static_config_slice(config) core_profiles = core_profile_setters.initial_core_profiles( - static_config_slice=static_config_slice, - dynamic_config_slice=dynamic_config_slice, + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, geo=geo, source_models=source_models, ) - qlknn_jitted(dynamic_config_slice, geo, core_profiles) + qlknn_jitted(dynamic_runtime_params_slice, geo, core_profiles) # The call should be cached. If there was an error, the cache size would be # 0. self.assertGreaterEqual( diff --git a/torax/transport_model/tests/transport_model.py b/torax/transport_model/tests/transport_model.py index 7ad8d140..f78609e7 100644 --- a/torax/transport_model/tests/transport_model.py +++ b/torax/transport_model/tests/transport_model.py @@ -17,11 +17,11 @@ from absl.testing import absltest from absl.testing import parameterized import numpy as np -from torax import config as config_lib -from torax import config_slice from torax import geometry from torax import sim as sim_lib from torax import state +from torax.config import runtime_params as general_runtime_params +from torax.config import runtime_params_slice from torax.sources import source_models as source_models_lib from torax.time_step_calculator import fixed_time_step_calculator from torax.transport_model import runtime_params as runtime_params_lib @@ -34,10 +34,12 @@ class TransportSmoothingTest(parameterized.TestCase): def test_smoothing(self): """Tests that smoothing works as expected.""" # Set up default config and geo - config = config_lib.Config( - profile_conditions=config_lib.ProfileConditions(set_pedestal=False), + runtime_params = general_runtime_params.GeneralRuntimeParams( + profile_conditions=general_runtime_params.ProfileConditions( + set_pedestal=False + ), ) - geo = geometry.build_circular_geometry(config) + geo = geometry.build_circular_geometry(runtime_params) source_models = source_models_lib.SourceModels() transport_model = FakeTransportModel( runtime_params=runtime_params_lib.RuntimeParams( @@ -48,32 +50,36 @@ def test_smoothing(self): smoothing_sigma=0.05, ) ) - dynamic_config_slice = config_slice.build_dynamic_config_slice( - config, - transport=transport_model.runtime_params, - sources=source_models.runtime_params, + dynamic_runtime_params_slice = ( + runtime_params_slice.build_dynamic_runtime_params_slice( + runtime_params, + transport=transport_model.runtime_params, + sources=source_models.runtime_params, + ) + ) + static_runtime_params_slice = ( + runtime_params_slice.build_static_runtime_params_slice(runtime_params) ) - static_config_slice = config_slice.build_static_config_slice(config) time_calculator = fixed_time_step_calculator.FixedTimeStepCalculator() input_state = sim_lib.get_initial_state( - dynamic_config_slice=dynamic_config_slice, - static_config_slice=static_config_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + static_runtime_params_slice=static_runtime_params_slice, geo=geo, time_step_calculator=time_calculator, source_models=source_models, ) transport_coeffs = transport_model( - dynamic_config_slice, geo, input_state.core_profiles + dynamic_runtime_params_slice, geo, input_state.core_profiles ) chi_face_ion_orig = np.linspace(0.5, 2, geo.r_face_norm.shape[0]) chi_face_el_orig = np.linspace(0.25, 1, geo.r_face_norm.shape[0]) d_face_el_orig = np.linspace(2, 3, geo.r_face_norm.shape[0]) v_face_el_orig = np.linspace(-0.2, -2, geo.r_face_norm.shape[0]) inner_patch_idx = np.searchsorted( - geo.r_face_norm, dynamic_config_slice.transport.rho_inner + geo.r_face_norm, dynamic_runtime_params_slice.transport.rho_inner ) outer_patch_idx = np.searchsorted( - geo.r_face_norm, dynamic_config_slice.transport.rho_outer + geo.r_face_norm, dynamic_runtime_params_slice.transport.rho_outer ) # assert that the smoothing did not impact the zones inside/outside the @@ -120,7 +126,7 @@ def test_smoothing(self): smoothing_array = np.exp( -np.log(2) * (r_reduced - test_r) ** 2 - / (dynamic_config_slice.transport.smoothing_sigma**2 + eps) + / (dynamic_runtime_params_slice.transport.smoothing_sigma**2 + eps) ) smoothing_array /= np.sum(smoothing_array) smoothing_array = np.where( @@ -180,11 +186,11 @@ def runtime_params( def _call_implementation( self, - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, core_profiles: state.CoreProfiles, ) -> state.CoreTransport: - del dynamic_config_slice, core_profiles # these are unused + del dynamic_runtime_params_slice, core_profiles # these are unused chi_face_ion = np.linspace(0.5, 2, geo.r_face_norm.shape[0]) chi_face_el = np.linspace(0.25, 1, geo.r_face_norm.shape[0]) d_face_el = np.linspace(2, 3, geo.r_face_norm.shape[0]) diff --git a/torax/transport_model/transport_model.py b/torax/transport_model/transport_model.py index bb4d4364..0aa3bf18 100644 --- a/torax/transport_model/transport_model.py +++ b/torax/transport_model/transport_model.py @@ -20,10 +20,10 @@ import abc import jax from jax import numpy as jnp -from torax import config_slice from torax import constants from torax import geometry from torax import state +from torax.config import runtime_params_slice from torax.transport_model import runtime_params as runtime_params_lib @@ -37,14 +37,16 @@ class TransportModel(abc.ABC): def __call__( self, - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, core_profiles: state.CoreProfiles, ) -> state.CoreTransport: return self.smooth_coeffs( geo, - dynamic_config_slice, - self._call_implementation(dynamic_config_slice, geo, core_profiles), + dynamic_runtime_params_slice, + self._call_implementation( + dynamic_runtime_params_slice, geo, core_profiles + ), ) @property @@ -63,7 +65,7 @@ def runtime_params( @abc.abstractmethod def _call_implementation( self, - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, core_profiles: state.CoreProfiles, ) -> state.CoreTransport: @@ -72,11 +74,11 @@ def _call_implementation( def smooth_coeffs( self, geo: geometry.Geometry, - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, transport_coeffs: state.CoreTransport, ) -> state.CoreTransport: """Gaussian smoothing of transport coefficients.""" - smoothing_matrix = build_smoothing_matrix(geo, dynamic_config_slice) + smoothing_matrix = build_smoothing_matrix(geo, dynamic_runtime_params_slice) smoothed_coeffs = {} for coeff in transport_coeffs: smoothed_coeff = jnp.dot(smoothing_matrix, transport_coeffs[coeff]) @@ -86,7 +88,7 @@ def smooth_coeffs( def build_smoothing_matrix( geo: geometry.Geometry, - dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, ) -> jax.Array: """Builds a smoothing matrix for the transport model. @@ -94,8 +96,8 @@ def build_smoothing_matrix( Args: geo: Geometry of the torus. - dynamic_config_slice: Input config parameters that can change without - triggering a JAX recompilation. + dynamic_runtime_params_slice: Input runtime parameters that can change + without triggering a JAX recompilation. Returns: kernel: A smoothing matrix for convolution with the transport outputs. @@ -112,30 +114,33 @@ def build_smoothing_matrix( kernel = jnp.exp( -jnp.log(2) * (geo.r_face_norm[:, jnp.newaxis] - geo.r_face_norm) ** 2 - / (dynamic_config_slice.transport.smoothing_sigma**2 + consts.eps) + / (dynamic_runtime_params_slice.transport.smoothing_sigma**2 + consts.eps) ) # 2. Masking: we do not want transport coefficients calculated in pedestal # region or in inner and outer transport patch regions, to impact # transport_model calculated coefficients mask_outer_edge = jax.lax.cond( - dynamic_config_slice.profile_conditions.set_pedestal, - lambda: dynamic_config_slice.profile_conditions.Ped_top - consts.eps, + dynamic_runtime_params_slice.profile_conditions.set_pedestal, + lambda: dynamic_runtime_params_slice.profile_conditions.Ped_top + - consts.eps, lambda: 1.0, ) mask_outer_edge = jax.lax.cond( jnp.logical_and( - jnp.logical_not(dynamic_config_slice.profile_conditions.set_pedestal), - dynamic_config_slice.transport.apply_outer_patch, + jnp.logical_not( + dynamic_runtime_params_slice.profile_conditions.set_pedestal + ), + dynamic_runtime_params_slice.transport.apply_outer_patch, ), - lambda: dynamic_config_slice.transport.rho_outer - consts.eps, + lambda: dynamic_runtime_params_slice.transport.rho_outer - consts.eps, lambda: mask_outer_edge, ) mask_inner_edge = jax.lax.cond( - dynamic_config_slice.transport.apply_inner_patch, - lambda: dynamic_config_slice.transport.rho_inner + consts.eps, + dynamic_runtime_params_slice.transport.apply_inner_patch, + lambda: dynamic_runtime_params_slice.transport.rho_inner + consts.eps, lambda: 0.0, )