From 3b63368e6978cf97e815cada113da3732950f647 Mon Sep 17 00:00:00 2001 From: Akhil Raju Date: Thu, 25 Apr 2024 15:48:06 -0700 Subject: [PATCH] Change the API and runtime configuration for all sources. This moves the source runtime parameter configuration out of `Config` and into the individual sources. This also cleans up some old issues with the Source API around naming. PiperOrigin-RevId: 628214715 --- .github/workflows/pytest.yml | 1 - run_simulation_main.py | 14 +- torax/__init__.py | 3 +- torax/calc_coeffs.py | 44 +- torax/config.py | 71 +-- torax/config_slice.py | 159 +----- torax/core_profile_setters.py | 107 ++-- torax/fvm/newton_raphson_solve_block.py | 1 - torax/fvm/optimizer_solve_block.py | 1 - torax/fvm/tests/fvm.py | 81 +-- torax/runtime_params/config_slice_args.py | 9 +- torax/sim.py | 47 +- torax/simulation_app.py | 42 +- torax/sources/bootstrap_current_source.py | 327 ++++++------- torax/sources/current_density_sources.py | 18 +- torax/sources/default_sources.py | 118 +++++ torax/sources/electron_density_sources.py | 233 ++++++--- torax/sources/external_current_source.py | 165 +++++-- torax/sources/formula_config.py | 69 ++- torax/sources/formulas.py | 38 +- torax/sources/fusion_heat_source.py | 20 +- torax/sources/generic_ion_el_heat_source.py | 68 ++- torax/sources/ion_el_heat_sources.py | 40 +- torax/sources/qei_source.py | 60 ++- .../{source_config.py => runtime_params.py} | 81 ++- torax/sources/source.py | 320 ++++++------ torax/sources/source_models.py | 460 +++++++++--------- .../sources/tests/bootstrap_current_source.py | 28 +- .../sources/tests/current_density_sources.py | 18 +- .../sources/tests/electron_density_sources.py | 18 +- .../sources/tests/external_current_source.py | 31 +- torax/sources/tests/formulas.py | 141 +++--- torax/sources/tests/fusion_heat_source.py | 19 +- .../tests/generic_ion_el_heat_source.py | 6 +- torax/sources/tests/ion_el_heat_sources.py | 42 +- torax/sources/tests/qei_source.py | 55 ++- torax/sources/tests/source.py | 260 +++++----- torax/sources/tests/source_config.py | 54 -- torax/sources/tests/source_models.py | 57 +-- torax/sources/tests/test_lib.py | 136 +++--- torax/spectators/tests/plotting.py | 2 + torax/stepper/linear_theta_method.py | 1 - torax/stepper/stepper.py | 1 - torax/tests/boundary_conditions.py | 7 +- torax/tests/config_slice.py | 189 +++---- torax/tests/physics.py | 12 +- torax/tests/sim.py | 3 + torax/tests/sim_custom_sources.py | 261 ++++++---- torax/tests/sim_output_source_profiles.py | 114 +++-- torax/tests/sim_time_dependence.py | 2 +- torax/tests/state.py | 117 +++-- .../tests/test_data/compilation_benchmark.py | 35 +- torax/tests/test_data/default_config.py | 9 + torax/tests/test_data/test_absolute_jext.py | 34 +- .../test_all_transport_crank_nicolson.py | 44 +- .../test_all_transport_fusion_qlknn.py | 38 +- torax/tests/test_data/test_bootstrap.py | 39 +- torax/tests/test_data/test_cgmheat.py | 28 +- torax/tests/test_data/test_chease.py | 45 +- torax/tests/test_data/test_crank_nicolson.py | 31 +- torax/tests/test_data/test_exact_finaltime.py | 28 +- torax/tests/test_data/test_explicit.py | 46 +- torax/tests/test_data/test_fixed_dt.py | 28 +- .../test_data/test_frozen_newton_raphson.py | 31 +- .../tests/test_data/test_frozen_optimizer.py | 31 +- torax/tests/test_data/test_fusion_power.py | 41 +- torax/tests/test_data/test_implicit.py | 31 +- .../test_implicit_short_optimizer.py | 31 +- .../test_data/test_iterbaseline_mockup.py | 127 +++-- .../tests/test_data/test_iterhybrid_mockup.py | 126 +++-- .../tests/test_data/test_iterhybrid_newton.py | 126 +++-- .../test_iterhybrid_predictor_corrector.py | 126 +++-- .../tests/test_data/test_iterhybrid_rampup.py | 121 +++-- .../test_data/test_ne_qlknn_deff_veff.py | 44 +- .../test_data/test_ne_qlknn_defromchie.py | 44 +- .../test_data/test_newton_raphson_zeroiter.py | 27 +- torax/tests/test_data/test_ohmic_power.py | 36 +- .../test_data/test_optimizer_zeroiter.py | 27 +- .../test_data/test_particle_sources_cgm.py | 39 +- .../test_particle_sources_constant.py | 39 +- torax/tests/test_data/test_pc_method_ne.py | 47 +- torax/tests/test_data/test_pedestal.py | 28 +- .../test_prescribed_timedependent_ne.py | 55 ++- torax/tests/test_data/test_psi_and_heat.py | 28 +- torax/tests/test_data/test_psi_heat_dens.py | 39 +- .../test_data/test_psichease_ip_chease.py | 45 +- .../test_data/test_psichease_ip_parameters.py | 45 +- .../test_psichease_prescribed_johm.py | 45 +- .../test_psichease_prescribed_jtot.py | 45 +- torax/tests/test_data/test_psiequation.py | 31 +- torax/tests/test_data/test_qei.py | 28 +- .../test_data/test_qei_chease_highdens.py | 42 +- torax/tests/test_data/test_qlknnheat.py | 28 +- .../test_data/test_semiimplicit_convection.py | 28 +- torax/tests/test_data/test_timedependence.py | 49 +- torax/tests/test_lib/explicit_stepper.py | 6 +- torax/tests/test_lib/sim_test_case.py | 1 + torax/tests/test_lib/torax_refs.py | 9 - torax/transport_model/qlknn_wrapper.py | 2 +- torax/transport_model/tests/qlknn_wrapper.py | 8 +- .../transport_model/tests/transport_model.py | 6 +- 101 files changed, 3620 insertions(+), 2618 deletions(-) create mode 100644 torax/sources/default_sources.py rename torax/sources/{source_config.py => runtime_params.py} (53%) delete mode 100644 torax/sources/tests/source_config.py diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index a378354a..bf93bc99 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -77,7 +77,6 @@ jobs: torax/sources/tests/generic_ion_el_heat_source.py \ torax/sources/tests/ion_el_heat_sources.py \ torax/sources/tests/qei_source.py \ - torax/sources/tests/source_config.py \ torax/sources/tests/source_models.py \ torax/sources/tests/source.py \ torax/spectators/tests/plotting.py \ diff --git a/run_simulation_main.py b/run_simulation_main.py index cf578b10..d6157ba9 100644 --- a/run_simulation_main.py +++ b/run_simulation_main.py @@ -201,6 +201,11 @@ def change_config( new_config = config_module.get_config() new_geo = config_module.get_geometry(new_config) new_transport_model = config_module.get_transport_model() + source_models = config_module.get_sources() + new_source_params = { + name: source.runtime_params + for name, source in source_models.sources.items() + } # Make sure the transport model has not changed. # TODO(b/330172917): Improve the check for updated configs. if not isinstance(new_transport_model, type(sim.transport_model)): @@ -210,10 +215,11 @@ def change_config( ' this option, you cannot change the transport model.' ) sim = simulation_app.update_sim( - sim, - new_config, - new_geo, - new_transport_model.runtime_params, + sim=sim, + config=new_config, + geo=new_geo, + transport_runtime_params=new_transport_model.runtime_params, + source_runtime_params=new_source_params, ) return sim, new_config, config_module_str diff --git a/torax/__init__.py b/torax/__init__.py index f6eb85f9..d8280491 100644 --- a/torax/__init__.py +++ b/torax/__init__.py @@ -74,6 +74,7 @@ 'dynamic_config_slice_t', 'dynamic_config_slice_t_plus_dt', 'unused_config', + 'dynamic_source_runtime_params', 'geo', 'x_old', 'state', @@ -81,11 +82,11 @@ 'core_profiles', 'psi', 'transport_model', - 'time_step_calculator', 'source_profiles', 'source_profile', 'explicit_source_profiles', 'source_models', + 'time_step_calculator', 'coeffs_callback', 'evolving_names', 'spectator', diff --git a/torax/calc_coeffs.py b/torax/calc_coeffs.py index 8b3127d2..89c35881 100644 --- a/torax/calc_coeffs.py +++ b/torax/calc_coeffs.py @@ -41,8 +41,12 @@ def calculate_pereverzev_flux( """Adds Pereverzev-Corrigan flux to diffusion terms.""" consts = constants.CONSTANTS - true_ne_face = core_profiles.ne.face_value() * dynamic_config_slice.nref - true_ni_face = core_profiles.ni.face_value() * dynamic_config_slice.nref + true_ne_face = ( + core_profiles.ne.face_value() * dynamic_config_slice.numerics.nref + ) + true_ni_face = ( + core_profiles.ni.face_value() * dynamic_config_slice.numerics.nref + ) geo_factor = jnp.concatenate( [jnp.ones(1), geo.g1_over_vpr_face[1:] / geo.g0_face[1:]] @@ -277,22 +281,22 @@ def _calc_coeffs_full( # Decide which values to use depending on whether the source is explicit or # implicit. sigma = jax_utils.select( - dynamic_config_slice.sources[source_models.j_bootstrap.name].is_explicit, + dynamic_config_slice.sources[source_models.j_bootstrap_name].is_explicit, explicit_source_profiles.j_bootstrap.sigma, implicit_source_profiles.j_bootstrap.sigma, ) j_bootstrap = jax_utils.select( - dynamic_config_slice.sources[source_models.j_bootstrap.name].is_explicit, + dynamic_config_slice.sources[source_models.j_bootstrap_name].is_explicit, explicit_source_profiles.j_bootstrap.j_bootstrap, implicit_source_profiles.j_bootstrap.j_bootstrap, ) j_bootstrap_face = jax_utils.select( - dynamic_config_slice.sources[source_models.j_bootstrap.name].is_explicit, + dynamic_config_slice.sources[source_models.j_bootstrap_name].is_explicit, explicit_source_profiles.j_bootstrap.j_bootstrap_face, implicit_source_profiles.j_bootstrap.j_bootstrap_face, ) I_bootstrap = jax_utils.select( # pylint: disable=invalid-name - dynamic_config_slice.sources[source_models.j_bootstrap.name].is_explicit, + dynamic_config_slice.sources[source_models.j_bootstrap_name].is_explicit, explicit_source_profiles.j_bootstrap.I_bootstrap, implicit_source_profiles.j_bootstrap.I_bootstrap, ) @@ -328,13 +332,21 @@ def _calc_coeffs_full( source_models, ) - true_ne_face = core_profiles.ne.face_value() * dynamic_config_slice.nref - true_ni_face = core_profiles.ni.face_value() * dynamic_config_slice.nref + true_ne_face = ( + core_profiles.ne.face_value() * dynamic_config_slice.numerics.nref + ) + true_ni_face = ( + core_profiles.ni.face_value() * dynamic_config_slice.numerics.nref + ) # Transient term coefficient vector (has radial dependence through r, n) - toc_temp_ion = 1.5 * geo.vpr * consts.keV2J * dynamic_config_slice.nref + toc_temp_ion = ( + 1.5 * geo.vpr * consts.keV2J * dynamic_config_slice.numerics.nref + ) tic_temp_ion = core_profiles.ni.value - toc_temp_el = 1.5 * geo.vpr * consts.keV2J * dynamic_config_slice.nref + toc_temp_el = ( + 1.5 * geo.vpr * consts.keV2J * dynamic_config_slice.numerics.nref + ) tic_temp_el = core_profiles.ne.value toc_psi = ( 1.0 @@ -516,7 +528,7 @@ def _calc_coeffs_full( dynamic_config_slice.profile_conditions.Ip / (jnp.pi * geo.Rmin**2) * 1e20 - / dynamic_config_slice.nref + / dynamic_config_slice.numerics.nref ) # pylint: enable=invalid-name neped_unnorm = jnp.where( @@ -551,7 +563,9 @@ def _calc_coeffs_full( ) = jax.lax.cond( use_pereverzev, lambda: calculate_pereverzev_flux( - dynamic_config_slice, geo, core_profiles, + dynamic_config_slice, + geo, + core_profiles, ), lambda: tuple([jnp.zeros_like(geo.r_face)] * 6), ) @@ -563,9 +577,11 @@ def _calc_coeffs_full( # Ion and electron heat sources. qei = source_models.qei_source.get_qei( - dynamic_config_slice.sources[source_models.qei_source.name].source_type, - dynamic_config_slice=dynamic_config_slice, static_config_slice=static_config_slice, + dynamic_config_slice=dynamic_config_slice, + dynamic_source_runtime_params=dynamic_config_slice.sources[ + source_models.qei_source_name + ], geo=geo, # For Qei, always use the current set of core profiles. # In the linear solver, core_profiles is the set of profiles at time t (at diff --git a/torax/config.py b/torax/config.py index 4851251c..c084ea97 100644 --- a/torax/config.py +++ b/torax/config.py @@ -18,14 +18,13 @@ parameters. """ -from collections.abc import Iterable, Mapping +from collections.abc import Iterable import dataclasses import enum import typing from typing import Any import chex from torax import interpolated_param -from torax.sources import source_config # Type-alias for clarity. While the InterpolatedParams can vary across any @@ -205,10 +204,10 @@ class Numerics: # 1/multiplication factor for sigma (conductivity) to reduce current # diffusion timescale to be closer to heat diffusion timescale resistivity_mult: TimeDependentField = 1.0 - # Multiplication factor for bootstrap current - bootstrap_mult: float = 1.0 - # multiplier for ion-electron heat exchange term for sensitivity testing - Qei_mult: float = 1.0 + + # density profile info + # Reference value for normalization + nref: float = 1e20 # numerical (e.g. no. of grid points, other info needed by solver) # radial grid points (num cells) @@ -234,43 +233,6 @@ class Config: ) numerics: Numerics = dataclasses.field(default_factory=Numerics) - # TODO(b/330172917): Move the source parameters into `sources`. - - # density profile info - # Reference value for normalization - nref: float = 1e20 - - # external heat source parameters - w: TimeDependentField = 0.25 # Gaussian width in normalized radial coordinate - # Source Gaussian central location (in normalized r) - rsource: TimeDependentField = 0.0 - Ptot: TimeDependentField = 120e6 # total heating - el_heat_fraction: TimeDependentField = 0.66666 # electron heating fraction - - # particle source parameters - # Gaussian width of pellet deposition [normalized radial coord], - # (continuous pellet model) - pellet_width: TimeDependentField = 0.1 - # Pellet source Gaussian central location [normalized radial coord] - # (continuous pellet model) - pellet_deposition_location: TimeDependentField = 0.85 - # total pellet particles/s (continuous pellet model) - # TODO(b/326578331): improve numerical strategy, avoid these large numbers - S_pellet_tot: TimeDependentField = 2e22 - - # exponential decay length of gas puff ionization [normalized radial coord] - puff_decay_length: TimeDependentField = 0.05 - # total gas puff particles/s - # TODO(b/326578331): improve numerical strategy, avoid these large numbers - S_puff_tot: TimeDependentField = 1e22 - - # NBI particle source Gaussian width in normalized radial coord - nbi_particle_width: TimeDependentField = 0.25 - # NBI particle source Gaussian central location in normalized radial coord - nbi_deposition_location: TimeDependentField = 0.0 - # NBI total particle source - S_nbi_tot: TimeDependentField = 1e22 - # current profiles (broad "Ohmic" + localized "external" currents) # peaking factor of "Ohmic" current: johm = j0*(1 - r^2/a^2)^nu nu: float = 3.0 @@ -281,16 +243,6 @@ class Config: # or from the psi available in the numerical geometry file. This setting is # ignored for the ad-hoc circular geometry, which has no numerical geometry. initial_psi_from_j: bool = False - # toggles if external current is provided absolutely or as a fraction of Ip - use_absolute_jext: bool = False - # total "external" current in MA. Used if use_absolute_jext=True. - Iext: TimeDependentField = 3.0 - # total "external" current fraction. Used if use_absolute_jext=False. - fext: TimeDependentField = 0.2 - # width of "external" Gaussian current profile - wext: TimeDependentField = 0.05 - # normalized radius of "external" Gaussian current profile - rext: TimeDependentField = 0.4 # solver parameters solver: SolverConfig = dataclasses.field(default_factory=SolverConfig) @@ -301,13 +253,6 @@ class Config: # pylint: enable=invalid-name - # Runtime configs for all source/sink terms. - # Note that the sources field is overridden in the __post_init__. See impl for - # details on how this field is updated. - sources: Mapping[str, source_config.SourceConfig] = dataclasses.field( - default_factory=source_config.get_default_sources_config - ) - def sanity_check(self) -> None: """Checks that various configuration parameters are valid.""" # TODO(b/330172917) do more extensive config parameter sanity checking @@ -319,12 +264,6 @@ def sanity_check(self) -> None: assert isinstance(self.numerics, Numerics) def __post_init__(self): - # The sources config should have the default values from - # source_config.get_default_sources_config. The additional values provided - # via the config constructor should OVERRIDE these defaults. - sources = dict(source_config.get_default_sources_config()) - sources.update(self.sources) # Update with the user inputs. - self.sources = sources self.sanity_check() diff --git a/torax/config_slice.py b/torax/config_slice.py index 1c83e180..fafa5f48 100644 --- a/torax/config_slice.py +++ b/torax/config_slice.py @@ -42,8 +42,8 @@ import chex from torax import config as config_lib -from torax import jax_utils from torax.runtime_params import config_slice_args +from torax.sources import runtime_params as sources_params from torax.transport_model import runtime_params as transport_model_params @@ -76,41 +76,7 @@ class DynamicConfigSlice: plasma_composition: DynamicPlasmaComposition profile_conditions: DynamicProfileConditions numerics: DynamicNumerics - sources: Mapping[str, DynamicSourceConfigSlice] - - # density profile info - # Reference value for normalization - nref: float - - # external heat source parameters - w: float # Gaussian width - rsource: float # Source Gaussian central location - Ptot: float # total heating - el_heat_fraction: float # fraction of heating to electrons (rest are to ions) - - # particle source parameters - # Gaussian width of pellet deposition [normalized radial coord], - # (continuous pellet model) - pellet_width: float - # Pellet source Gaussian central location [normalized radial coord] - # (continuous pellet model) - pellet_deposition_location: float - # total pellet particles/s (continuous pellet model) - # TODO(b/326578331): improve numerical strategy, avoid these large numbers - S_pellet_tot: float - - # exponential decay length of gas puff ionization [normalized radial coord] - puff_decay_length: float - # total gas puff particles/s - # TODO(b/326578331): improve numerical strategy, avoid these large numbers - S_puff_tot: float - - # NBI particle source Gaussian width in normalized radial coord - nbi_particle_width: float - # NBI particle source Gaussian central location in normalized radial coord - nbi_deposition_location: float - # NBI total particle source - S_nbi_tot: float + sources: Mapping[str, sources_params.DynamicRuntimeParams] # current profiles (broad "Ohmic" + localized "external" currents) @@ -124,23 +90,6 @@ class DynamicConfigSlice: # or from the psi available in the numerical geometry file. This setting is # ignored for the ad-hoc circular geometry, which has no numerical geometry. initial_psi_from_j: bool - # toggles if external current is provided absolutely or as a fraction of Ip - use_absolute_jext: bool - # total "external" current in MA. Used if use_absolute_jext=True. - Iext: float - # total "external" current fraction. Used if use_absolute_jext=False. - fext: float - # width of "external" Gaussian current profile - wext: float - # normalized radius of "external" Gaussian current profile - rext: float - - def sanity_check(self): - """Checks that all parameters are valid.""" - jax_utils.error_if_negative(self.wext, 'wext') - - def __post_init__(self): - self.sanity_check() @chex.dataclass(frozen=True) @@ -249,10 +198,10 @@ class DynamicNumerics: # 1/multiplication factor for sigma (conductivity) to reduce current # diffusion timescale to be closer to heat diffusion timescale resistivity_mult: float - # Multiplication factor for bootstrap current - bootstrap_mult: float - # multiplier for ion-electron heat exchange term for sensitivity testing - Qei_mult: float + + # density profile info + # Reference value for normalization + nref: float # numerical (e.g. no. of grid points, other info needed by solver) # effective source to dominate PDE in internal boundary condtion location @@ -269,60 +218,6 @@ class DynamicNumerics: enable_prescribed_profile_evolution: bool -@chex.dataclass(frozen=True) -class DynamicExponentialFormulaConfigSlice: - """Runtime config for an exponential source profile.""" - - # floats to parameterize the formula. - total: float - c1: float - c2: float - # If True, uses r_norm when calculating the source profiles. - use_normalized_r: bool - - -@chex.dataclass(frozen=True) -class DynamicGaussianFormulaConfigSlice: - # floats to parameterize the formula. - total: float - c1: float - c2: float - # If True, uses r_norm when calculating the source profiles. - use_normalized_r: bool - - -@chex.dataclass(frozen=True) -class DynamicFormulaConfigSlice: - """Contains all formula configs.""" - - exponential: DynamicExponentialFormulaConfigSlice - gaussian: DynamicGaussianFormulaConfigSlice - custom_params: dict[str, chex.Numeric] - - -@chex.dataclass(frozen=True) -class DynamicSourceConfigSlice: - """Dynamic params for a single TORAX source. - - These params can be changed without triggering a recompile. TORAX sources are - stateless, so these params are their inputs to determine their output - profiles. - """ - - # Method to get the source profile. See source_config.py for more info on - # possible types. This maps to the enum value for the SourceType enum. The - # enum itself is not JAX-friendly. - source_type: int - # If True, this source depends on the mesh state at the start of the time - # step, or does not depend on the mesh state at all, to compute it's value - # for the time step. If False, then the source will depend on the "live" - # state that is updated within the JointStateStepper call. - is_explicit: bool - # Parameters used only when the source is using a prescribed formula to - # compute its profile. - formula: DynamicFormulaConfigSlice - - @chex.dataclass(frozen=True) class StaticConfigSlice: """Static arguments to JointStateStepper which cannot be changed. @@ -382,10 +277,12 @@ class StaticSolverConfigSlice: def build_dynamic_config_slice( config: config_lib.Config, transport: transport_model_params.RuntimeParams | None = None, + sources: dict[str, sources_params.RuntimeParams] | None = None, t: chex.Numeric | None = None, ) -> DynamicConfigSlice: """Builds a DynamicConfigSlice based on the input config.""" transport = transport or transport_model_params.RuntimeParams() + sources = sources or {} t = config.numerics.t_initial if t is None else t # For each dataclass attribute under DynamicConfigSlice, build those objects # explicitly, and then for all scalar attributes, fetch their values directly @@ -399,7 +296,7 @@ def build_dynamic_config_slice( t=t, ) ), - sources=_build_dynamic_sources(config, t), + sources=_build_dynamic_sources(sources, t), plasma_composition=DynamicPlasmaComposition( **config_slice_args.get_init_kwargs( input_config=config.plasma_composition, @@ -438,37 +335,14 @@ def build_dynamic_config_slice( def _build_dynamic_sources( - config: config_lib.Config, + sources: dict[str, sources_params.RuntimeParams], t: chex.Numeric, -) -> dict[str, DynamicSourceConfigSlice]: +) -> dict[str, sources_params.DynamicRuntimeParams]: """Builds a dict of DynamicSourceConfigSlice based on the input config.""" - source_configs = {} - for source_name, input_source_config in config.sources.items(): - source_configs[source_name] = DynamicSourceConfigSlice( - source_type=input_source_config.source_type.value, - is_explicit=input_source_config.is_explicit, - formula=DynamicFormulaConfigSlice( - exponential=DynamicExponentialFormulaConfigSlice( - **config_slice_args.get_init_kwargs( - input_config=input_source_config.formula.exponential, - output_type=DynamicExponentialFormulaConfigSlice, - t=t, - ) - ), - gaussian=DynamicGaussianFormulaConfigSlice( - **config_slice_args.get_init_kwargs( - input_config=input_source_config.formula.gaussian, - output_type=DynamicGaussianFormulaConfigSlice, - t=t, - ) - ), - custom_params={ - key: config_slice_args.interpolate_param(value, t) - for key, value in input_source_config.formula.custom_params.items() - }, - ), - ) - return source_configs + return { + source_name: input_source_config.build_dynamic_params(t) + for source_name, input_source_config in sources.items() + } def build_static_config_slice(config: config_lib.Config) -> StaticConfigSlice: @@ -504,9 +378,11 @@ def __init__( self, config: config_lib.Config, transport_getter: Callable[[], transport_model_params.RuntimeParams], + sources_getter: Callable[[], dict[str, sources_params.RuntimeParams]], ): self._input_config = config self._transport_runtime_params_getter = transport_getter + self._sources_getter = sources_getter if ( not self._input_config.profile_conditions.set_pedestal @@ -526,5 +402,6 @@ def __call__( return build_dynamic_config_slice( config=self._input_config, transport=self._transport_runtime_params_getter(), + sources=self._sources_getter(), t=t, ) diff --git a/torax/core_profile_setters.py b/torax/core_profile_setters.py index ab218692..5545bcce 100644 --- a/torax/core_profile_setters.py +++ b/torax/core_profile_setters.py @@ -111,7 +111,7 @@ def _updated_dens( dynamic_config_slice.profile_conditions.Ip / (jnp.pi * geo.Rmin**2) * 1e20 - / dynamic_config_slice.nref + / dynamic_config_slice.numerics.nref ) nbar_unnorm = jnp.where( dynamic_config_slice.profile_conditions.nbar_is_fGW, @@ -189,10 +189,11 @@ def _prescribe_currents_no_bootstrap( # Calculate splitting of currents depending on config Ip = dynamic_config_slice.profile_conditions.Ip - if dynamic_config_slice.use_absolute_jext: - Iext = dynamic_config_slice.Iext + dynamic_jext_params = _get_jext_params(dynamic_config_slice, source_models) + if dynamic_jext_params.use_absolute_jext: + Iext = dynamic_jext_params.Iext else: - Iext = Ip * dynamic_config_slice.fext + Iext = Ip * dynamic_jext_params.fext # Total Ohmic current Iohm = Ip - Iext @@ -203,10 +204,9 @@ def _prescribe_currents_no_bootstrap( # calculate "External" current profile (e.g. ECCD) # form of external current on face grid - jext_source = source_models.jext - jext_face, jext = jext_source.get_value( - source_type=dynamic_config_slice.sources[jext_source.name].source_type, + jext_face, jext = source_models.jext.get_value( dynamic_config_slice=dynamic_config_slice, + dynamic_source_runtime_params=dynamic_jext_params, geo=geo, ) @@ -227,7 +227,12 @@ def _prescribe_currents_no_bootstrap( johm = geometry.face_to_cell(johm_face) jtot_hires = _get_jtot_hires( - dynamic_config_slice, geo, bootstrap_profile, Iohm, jext_source + dynamic_config_slice, + dynamic_jext_params, + geo, + bootstrap_profile, + Iohm, + source_models.jext, ) currents = state.Currents( @@ -283,6 +288,9 @@ def _prescribe_currents_with_bootstrap( bootstrap_profile = source_models.j_bootstrap.get_value( dynamic_config_slice=dynamic_config_slice, + dynamic_source_runtime_params=dynamic_config_slice.sources[ + source_models.j_bootstrap_name + ], geo=geo, temp_ion=temp_ion, temp_el=temp_el, @@ -294,19 +302,19 @@ def _prescribe_currents_with_bootstrap( f_bootstrap = bootstrap_profile.I_bootstrap / (Ip * 1e6) # Calculate splitting of currents depending on config - if dynamic_config_slice.use_absolute_jext: - Iext = dynamic_config_slice.Iext + dynamic_jext_params = _get_jext_params(dynamic_config_slice, source_models) + if dynamic_jext_params.use_absolute_jext: + Iext = dynamic_jext_params.Iext else: - Iext = Ip * dynamic_config_slice.fext + Iext = Ip * dynamic_jext_params.fext Iohm = Ip - Iext - f_bootstrap * Ip # calculate "External" current profile (e.g. ECCD) # form of external current on face grid - jext_source = source_models.jext - jext_face, jext = jext_source.get_value( + jext_face, jext = source_models.jext.get_value( dynamic_config_slice=dynamic_config_slice, + dynamic_source_runtime_params=dynamic_jext_params, geo=geo, - source_type=dynamic_config_slice.sources[jext_source.name].source_type, ) # construct prescribed current formula on grid. @@ -326,7 +334,12 @@ def _prescribe_currents_with_bootstrap( johm = geometry.face_to_cell(johm_face) jtot_hires = _get_jtot_hires( - dynamic_config_slice, geo, bootstrap_profile, Iohm, jext_source + dynamic_config_slice, + dynamic_jext_params, + geo, + bootstrap_profile, + Iohm, + source_models.jext, ) currents = state.Currents( @@ -385,6 +398,9 @@ def _calculate_currents_from_psi( bootstrap_profile = source_models.j_bootstrap.get_value( dynamic_config_slice=dynamic_config_slice, + dynamic_source_runtime_params=dynamic_config_slice.sources[ + source_models.j_bootstrap_name + ], geo=geo, temp_ion=temp_ion, temp_el=temp_el, @@ -395,27 +411,36 @@ def _calculate_currents_from_psi( ) # Calculate splitting of currents depending on config - if dynamic_config_slice.use_absolute_jext: - Iext = dynamic_config_slice.Iext + dynamic_jext_params = _get_jext_params(dynamic_config_slice, source_models) + if dynamic_jext_params.use_absolute_jext: + Iext = dynamic_jext_params.Iext else: - Iext = Ip * dynamic_config_slice.fext + Iext = Ip * dynamic_jext_params.fext Iohm = Ip - Iext - bootstrap_profile.I_bootstrap # calculate "External" current profile (e.g. ECCD) # form of external current on face grid - jext_source = source_models.jext - jext_face, jext = jext_source.get_value( + jext_face, jext = source_models.jext.get_value( dynamic_config_slice=dynamic_config_slice, + dynamic_source_runtime_params=dynamic_jext_params, geo=geo, - source_type=dynamic_config_slice.sources[jext_source.name].source_type, ) johm = jtot - jext - bootstrap_profile.j_bootstrap johm_face = jtot_face - jext_face - bootstrap_profile.j_bootstrap_face + # TODO(b/336995925): TORAX currently only uses the external current source, + # jext, when computing the jtot initial currents from psi. Really, though, we + # should be summing over all sources that can contribute current i.e. ECCD, + # ICRH, NBI, LHCD. jtot_hires = _get_jtot_hires( - dynamic_config_slice, geo, bootstrap_profile, Iohm, jext_source + dynamic_config_slice, + dynamic_jext_params, + geo, + bootstrap_profile, + Iohm, + source_models.jext, ) currents = state.Currents( @@ -498,7 +523,7 @@ def initial_core_profiles( static_config_slice: config_slice.StaticConfigSlice, dynamic_config_slice: config_slice.DynamicConfigSlice, geo: Geometry, - source_models: source_models_lib.SourceModels | None = None, + source_models: source_models_lib.SourceModels, ) -> state.CoreProfiles: """Calculates the initial core profiles. @@ -506,20 +531,13 @@ def initial_core_profiles( static_config_slice: Static simulation configuration parameters. dynamic_config_slice: Dynamic configuration parameters at t=t_initial. geo: Torus geometry. - source_models: All models for TORAX sources/sinks. If not provided, uses the - default source_models. + source_models: All models for TORAX sources/sinks. Returns: Initial core profiles. """ # pylint: disable=invalid-name - source_models = ( - source_models_lib.SourceModels() - if source_models is None - else source_models - ) - # To set initial values and compute the boundary conditions, we need to handle # potentially time-varying inputs from the users. # The default time in build_dynamic_config_slice is t_initial @@ -636,7 +654,10 @@ def initial_core_profiles( psidot = dataclasses.replace( psidot, value=source_models_lib.calc_psidot( - dynamic_config_slice, geo, core_profiles, source_models, + dynamic_config_slice, + geo, + core_profiles, + source_models, ), ) @@ -776,7 +797,7 @@ def compute_boundary_conditions( dynamic_config_slice.profile_conditions.Ip / (jnp.pi * geo.Rmin**2) * 1e20 - / dynamic_config_slice.nref + / dynamic_config_slice.numerics.nref ) # pylint: enable=invalid-name ne_bound_right = jnp.where( @@ -828,6 +849,7 @@ def compute_boundary_conditions( # pylint: disable=invalid-name def _get_jtot_hires( dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_jext_params: external_current_source.DynamicRuntimeParams, geo: Geometry, bootstrap_profile: source_profiles_lib.BootstrapCurrentProfile, Iohm: jax.Array | float, @@ -840,8 +862,8 @@ def _get_jtot_hires( # calculate hi-res "External" current profile (e.g. ECCD) on cell grid. jext_hires = jext_source.jext_hires( - source_type=dynamic_config_slice.sources[jext_source.name].source_type, dynamic_config_slice=dynamic_config_slice, + dynamic_source_runtime_params=dynamic_jext_params, geo=geo, ) @@ -858,4 +880,21 @@ def _get_jtot_hires( return jtot_hires +def _get_jext_params( + dynamic_config_slice: config_slice.DynamicConfigSlice, + source_models: source_models_lib.SourceModels, +) -> external_current_source.DynamicRuntimeParams: + """Returns dynamic runtime params for the external current source.""" + assert source_models.jext_name in dynamic_config_slice.sources, ( + f'{source_models.jext_name} not found in dynamic_config_slice.sources.' + ' Check to make sure the DynamicConfigSlice was built with `sources`' + ' that include the external current source.' + ) + dynamic_jext_params = dynamic_config_slice.sources[source_models.jext_name] + assert isinstance( + dynamic_jext_params, external_current_source.DynamicRuntimeParams + ) + return dynamic_jext_params + + # pylint: enable=invalid-name diff --git a/torax/fvm/newton_raphson_solve_block.py b/torax/fvm/newton_raphson_solve_block.py index 6e5375c5..35bdf0bf 100644 --- a/torax/fvm/newton_raphson_solve_block.py +++ b/torax/fvm/newton_raphson_solve_block.py @@ -210,7 +210,6 @@ def newton_raphson_solve_block( # this is jitted. ( source_models_lib.build_all_zero_profiles( - dynamic_config_slice_t, geo, source_models, ), diff --git a/torax/fvm/optimizer_solve_block.py b/torax/fvm/optimizer_solve_block.py index f5a349f8..2af1d876 100644 --- a/torax/fvm/optimizer_solve_block.py +++ b/torax/fvm/optimizer_solve_block.py @@ -142,7 +142,6 @@ def optimizer_solve_block( # this is jitted. ( source_models_lib.build_all_zero_profiles( - dynamic_config_slice_t, geo, source_models, ), diff --git a/torax/fvm/tests/fvm.py b/torax/fvm/tests/fvm.py index a6ac18ba..f0b2a0c0 100644 --- a/torax/fvm/tests/fvm.py +++ b/torax/fvm/tests/fvm.py @@ -29,7 +29,8 @@ from torax import geometry from torax.fvm import implicit_solve_block from torax.fvm import residual_and_loss -from torax.sources import source_config +from torax.sources import default_sources +from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib from torax.tests.test_lib import torax_refs from torax.transport_model import constant as constant_transport_model @@ -354,22 +355,12 @@ def test_nonlinear_solve_block_loss_minimum( ), numerics=config_lib.Numerics( nr=num_cells, - Qei_mult=0, el_heat_eq=False, ), - Ptot=0, solver=config_lib.SolverConfig( predictor_corrector=False, theta_imp=theta_imp, ), - sources=dict( - fusion_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), ) geo = geometry.build_circular_geometry(config) transport_model = constant_transport_model.ConstantTransportModel( @@ -378,18 +369,29 @@ def test_nonlinear_solve_block_loss_minimum( chii_const=1, ), ) + source_models = default_sources.get_default_sources() + source_models.qei_source.runtime_params.Qei_mult = 0.0 + source_models.sources['generic_ion_el_heat_source'].runtime_params.Ptot = ( + 0.0 + ) + source_models.sources['fusion_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + source_models.sources['ohmic_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) dynamic_config_slice = config_slice.build_dynamic_config_slice( config, transport=transport_model.runtime_params, + sources=source_models.runtime_params, ) static_config_slice = config_slice.build_static_config_slice(config) - source_models = source_models_lib.SourceModels() core_profiles = core_profile_setters.initial_core_profiles( static_config_slice, dynamic_config_slice, geo, source_models ) evolving_names = tuple(['temp_ion']) explicit_source_profiles = source_models_lib.build_source_profiles( - source_models=source_models_lib.SourceModels(), + source_models=source_models, dynamic_config_slice=dynamic_config_slice, geo=geo, core_profiles=core_profiles, @@ -428,7 +430,7 @@ def test_nonlinear_solve_block_loss_minimum( # core_profiles_t_plus_dt is not updated since coeffs stay constant here loss, _ = residual_and_loss.theta_method_block_loss( dt=dt, - static_config_slice=config_slice.build_static_config_slice(config), + static_config_slice=static_config_slice, dynamic_config_slice_t_plus_dt=dynamic_config_slice, geo=geo, x_old=x_old, @@ -443,7 +445,7 @@ def test_nonlinear_solve_block_loss_minimum( residual, _ = residual_and_loss.theta_method_block_residual( dt=dt, - static_config_slice=config_slice.build_static_config_slice(config), + static_config_slice=static_config_slice, dynamic_config_slice_t_plus_dt=dynamic_config_slice, geo=geo, x_new_guess_vec=jnp.concatenate([var.value for var in x_new]), @@ -470,22 +472,12 @@ def test_implicit_solve_block_uses_updated_boundary_conditions(self): ), numerics=config_lib.Numerics( nr=num_cells, - Qei_mult=0, el_heat_eq=False, ), - Ptot=0, solver=config_lib.SolverConfig( predictor_corrector=False, theta_imp=1.0, ), - sources=dict( - fusion_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), ) transport_model = constant_transport_model.ConstantTransportModel( runtime_params=constant_transport_model.RuntimeParams( @@ -493,9 +485,21 @@ def test_implicit_solve_block_uses_updated_boundary_conditions(self): chii_const=1, ), ) + source_models = default_sources.get_default_sources() + source_models.qei_source.runtime_params.Qei_mult = 0.0 + source_models.sources['generic_ion_el_heat_source'].runtime_params.Ptot = ( + 0.0 + ) + source_models.sources['fusion_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + source_models.sources['ohmic_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) dynamic_config_slice = config_slice.build_dynamic_config_slice( config, transport=transport_model.runtime_params, + sources=source_models.runtime_params, ) static_config_slice = config_slice.build_static_config_slice(config) geo = geometry.build_circular_geometry(config) @@ -588,22 +592,12 @@ def test_theta_residual_uses_updated_boundary_conditions(self): ), numerics=config_lib.Numerics( nr=num_cells, - Qei_mult=0, el_heat_eq=False, ), - Ptot=0, solver=config_lib.SolverConfig( predictor_corrector=False, theta_imp=0.0, ), - sources=dict( - fusion_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), ) geo = geometry.build_circular_geometry(config) transport_model = constant_transport_model.ConstantTransportModel( @@ -612,9 +606,21 @@ def test_theta_residual_uses_updated_boundary_conditions(self): chii_const=1, ), ) + source_models = default_sources.get_default_sources() + source_models.qei_source.runtime_params.Qei_mult = 0.0 + source_models.sources['generic_ion_el_heat_source'].runtime_params.Ptot = ( + 0.0 + ) + source_models.sources['fusion_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + source_models.sources['ohmic_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) dynamic_config_slice = config_slice.build_dynamic_config_slice( config, transport=transport_model.runtime_params, + sources=source_models.runtime_params, ) static_config_slice_theta0 = config_slice.build_static_config_slice(config) static_config_slice_theta05 = dataclasses.replace( @@ -659,7 +665,10 @@ def test_theta_residual_uses_updated_boundary_conditions(self): right_face_constraint=initial_right_boundary, ) core_profiles_t_plus_dt = core_profile_setters.initial_core_profiles( - static_config_slice_theta0, dynamic_config_slice, geo + static_config_slice=static_config_slice_theta0, + dynamic_config_slice=dynamic_config_slice, + geo=geo, + source_models=source_models, ) core_profiles_t_plus_dt = dataclasses.replace( core_profiles_t_plus_dt, diff --git a/torax/runtime_params/config_slice_args.py b/torax/runtime_params/config_slice_args.py index 15c66017..59c5aa5a 100644 --- a/torax/runtime_params/config_slice_args.py +++ b/torax/runtime_params/config_slice_args.py @@ -17,6 +17,7 @@ from __future__ import annotations import dataclasses +import enum import types import typing from typing import Any @@ -56,8 +57,8 @@ def _check(ft): # below won't work, so we check for the full name here. ft == 'InterpParamOrInterpParamInput' or - # Common alias for InterpParamOrInterpParamInput. - ft == 'TimeDependentField' + # Common alias for InterpParamOrInterpParamInput in a few files. + (isinstance(ft, str) and 'TimeDependentField' in ft) or # Otherwise, only check if it is actually the InterpolatedParam. ft == 'interpolated_param.InterpolatedParam' @@ -118,5 +119,9 @@ def get_init_kwargs( config_val = interpolate_param(config_val, t) elif input_is_a_float_field(field.name, input_config_fields_to_types): config_val = float(config_val) + elif isinstance(config_val, enum.Enum): + config_val = config_val.value + elif hasattr(config_val, 'build_dynamic_params'): + config_val = config_val.build_dynamic_params(t) kwargs[field.name] = config_val return kwargs diff --git a/torax/sim.py b/torax/sim.py index 50a0b791..2d4d1c16 100644 --- a/torax/sim.py +++ b/torax/sim.py @@ -22,7 +22,6 @@ jax compilation off and on. Compilation is on by default. Turning compilation off can sometimes help with debugging (e.g. by making it easier to print error messages in context). - """ from __future__ import annotations @@ -430,8 +429,8 @@ def get_initial_state( static_config_slice: config_slice.StaticConfigSlice, dynamic_config_slice: config_slice.DynamicConfigSlice, geo: geometry.Geometry, - time_step_calculator: ts.TimeStepCalculator, source_models: source_models_lib.SourceModels, + time_step_calculator: ts.TimeStepCalculator, ) -> state.ToraxSimState: """Returns the initial state to be used by run_simulation().""" initial_core_profiles = core_profile_setters.initial_core_profiles( @@ -664,8 +663,8 @@ def build_sim_from_config( geo: geometry.Geometry, stepper_builder: stepper_lib.StepperBuilder, transport_model: transport_model_lib.TransportModel, + source_models: source_models_lib.SourceModels, time_step_calculator: Optional[ts.TimeStepCalculator] = None, - source_models: source_models_lib.SourceModels | None = None, ) -> Sim: """Builds a Sim object from a Config file. @@ -680,35 +679,20 @@ def build_sim_from_config( stepper_builder: A callable to build the stepper. The stepper has already been factored out of the config. transport_model: Calculates diffusion and convection coefficients. - time_step_calculator: The time_step_calculator, if built, otherwise a - ChiTimeStepCalculator will be built by default. source_models: All TORAX sources/sink functions which provide profiles used as terms in the equations that evolve the core profiless. + time_step_calculator: The time_step_calculator, if built, otherwise a + ChiTimeStepCalculator will be built by default. Returns: sim: The built Sim instance. """ - source_models = ( - source_models_lib.SourceModels() - if source_models is None - else source_models - ) - - # Make sure the sources and the config (which contains the runtime configs for - # all the sources) have matching keys. - if set(source_models.all_sources.keys()) != set(config.sources.keys()): - raise ValueError( - 'SourceModels and config.sources must have the same keys. Mismatch ' - f'found.\nsource_models: {list(source_models.all_sources.keys())}.\n' - f'config.sources: {config.sources.keys()}' - ) static_config_slice = config_slice.build_static_config_slice(config) - dynamic_config_slice_provider = ( - config_slice.DynamicConfigSliceProvider( - config=config, - transport_getter=lambda: transport_model.runtime_params, - ) + dynamic_config_slice_provider = config_slice.DynamicConfigSliceProvider( + config=config, + transport_getter=lambda: transport_model.runtime_params, + sources_getter=lambda: source_models.runtime_params, ) stepper = stepper_builder(transport_model, source_models) @@ -1035,6 +1019,9 @@ def update_current_distribution( bootstrap_profile = source_models.j_bootstrap.get_value( dynamic_config_slice=dynamic_config_slice, + dynamic_source_runtime_params=dynamic_config_slice.sources[ + source_models.j_bootstrap_name + ], geo=geo, core_profiles=core_profiles, ) @@ -1184,10 +1171,10 @@ def _get_initial_source_profiles( ) qei = source_models.qei_source.get_qei( static_config_slice=static_config_slice, - source_type=dynamic_config_slice.sources[ - source_models.qei_source.name - ].source_type, dynamic_config_slice=dynamic_config_slice, + dynamic_source_runtime_params=dynamic_config_slice.sources[ + source_models.qei_source_name + ], geo=geo, core_profiles=core_profiles, ) @@ -1218,11 +1205,11 @@ def _merge_source_profiles( explicit_source_profiles: Profiles from explicit source models. This SourceProfiles dict will include keys for both the explicit and implicit sources, but only the explicit sources will have non-zero profiles. See - source.py and source_config.py for more info on explicit vs. implicit. + source.py and runtime_params.py for more info on explicit vs. implicit. implicit_source_profiles: Profiles from implicit source models. This SourceProfiles dict will include keys for both the explicit and implicit sources, but only the implicit sources will have non-zero profiles. See - source.py and source_config.py for more info on explicit vs. implicit. + source.py and runtime_params.py for more info on explicit vs. implicit. source_models: Source models used to compute the profiles given. qei_core_profiles: The core profiles used to compute the Qei source. @@ -1246,7 +1233,7 @@ def _merge_source_profiles( ) # For ease of comprehension, we convert the units of the Qei source and add it # to the list of other profiles before returning it. - summed_other_profiles[source_models.qei_source.name] = ( + summed_other_profiles[source_models.qei_source_name] = ( summed_qei_info.qei_coef * (qei_core_profiles.temp_el.value - qei_core_profiles.temp_ion.value) ) diff --git a/torax/simulation_app.py b/torax/simulation_app.py index ff9e389d..54a04230 100644 --- a/torax/simulation_app.py +++ b/torax/simulation_app.py @@ -56,6 +56,7 @@ def get_sim(): from torax import geometry from torax import sim as sim_lib from torax import state as state_lib +from torax.sources import runtime_params as source_runtime_params_lib from torax.spectators import plotting from torax.transport_model import runtime_params as transport_runtime_params_lib import xarray as xr @@ -205,41 +206,60 @@ def update_sim( config: torax.Config, geo: geometry.Geometry, transport_runtime_params: transport_runtime_params_lib.RuntimeParams, + source_runtime_params: dict[str, source_runtime_params_lib.RuntimeParams], ) -> sim_lib.Sim: """Updates the sim with a new config and geometry.""" - # NOTE: This function will NOT update any of the following in the config: + # NOTE: This function will NOT update any of the following: # - stepper (for the mesh state) - # - transport model + # - transport model object (runtime params are updated) # - spectator # - time step calculator + # - source objects (runtime params are updated) # TODO(b/335596447): Add checks to ensure that SimulationStepFn can be reused # correctly given the new config. If any of the attributes above change, then # ether raise an error or build a new SimulationStepFn (and notify the user). # TODO(b/335596447): If the static slice is updated, add checks or logs # notifying the user that using this new config will result in recompiling # the SimulationStepFn. + sim.transport_model.runtime_params = transport_runtime_params + _update_source_params(sim, source_runtime_params) static_config_slice = config_slice.build_static_config_slice(config) + dynamic_config_slice_provider = config_slice.DynamicConfigSliceProvider( + config=config, + transport_getter=lambda: sim.transport_model.runtime_params, + sources_getter=lambda: sim.source_models.runtime_params, + ) initial_state = sim_lib.get_initial_state( - dynamic_config_slice=config_slice.build_dynamic_config_slice(config), + dynamic_config_slice=dynamic_config_slice_provider( + t=config.numerics.t_initial + ), static_config_slice=static_config_slice, geo=geo, time_step_calculator=sim.time_step_calculator, source_models=sim.source_models, ) - sim.transport_model.runtime_params = transport_runtime_params return sim_lib.Sim( time_step_calculator=sim.time_step_calculator, initial_state=initial_state, geometry_provider=sim_lib.ConstantGeometryProvider(geo), - dynamic_config_slice_provider=config_slice.DynamicConfigSliceProvider( - config=config, - transport_getter=lambda: sim.transport_model.runtime_params, - ), + dynamic_config_slice_provider=dynamic_config_slice_provider, static_config_slice=static_config_slice, step_fn=sim.step_fn, ) +def _update_source_params( + sim: sim_lib.Sim, + source_runtime_params: dict[str, source_runtime_params_lib.RuntimeParams], +) -> None: + for source_name, source_runtime_params in source_runtime_params.items(): + if source_name not in sim.source_models.sources: + raise ValueError(f'Source {source_name} not found in sim.') + sim.source_models.sources[source_name].runtime_params = ( + source_runtime_params + ) + + def can_plot() -> bool: # TODO(b/335596567): Find way to detect displays that works on all OS's. return True @@ -311,12 +331,10 @@ def main( ds = simulation_output_to_xr(torax_outputs, geo) write_simulation_output_to_file(output_dir, ds) - # TODO(b/335596701): Add back functionality to write configs to file after - # running to help with keeping track of simulation runs. if log_sim_output: - core_profile_history, _, _ = ( - state_lib.build_history_from_states(torax_outputs) + core_profile_history, _, _ = state_lib.build_history_from_states( + torax_outputs ) t = state_lib.build_time_history_from_states(torax_outputs) log_simulation_output_to_stdout(core_profile_history, geo, t) diff --git a/torax/sources/bootstrap_current_source.py b/torax/sources/bootstrap_current_source.py index 43f6c302..c4917e68 100644 --- a/torax/sources/bootstrap_current_source.py +++ b/torax/sources/bootstrap_current_source.py @@ -28,14 +28,140 @@ from torax import physics from torax import state from torax.fvm import cell_variable +from torax.runtime_params import config_slice_args +from torax.sources import runtime_params as runtime_params_lib from torax.sources import source -from torax.sources import source_config from torax.sources import source_profiles +@dataclasses.dataclass(kw_only=True) +class RuntimeParams(runtime_params_lib.RuntimeParams): + # Multiplication factor for bootstrap current + bootstrap_mult: float = 1.0 + + def build_dynamic_params(self, t: chex.Numeric) -> DynamicRuntimeParams: + return DynamicRuntimeParams( + **config_slice_args.get_init_kwargs( + input_config=self, + output_type=DynamicRuntimeParams, + t=t, + ) + ) + + +@chex.dataclass(frozen=True) +class DynamicRuntimeParams(runtime_params_lib.DynamicRuntimeParams): + bootstrap_mult: float + + +def _default_output_shapes(geo) -> tuple[int, int, int, int]: + return ( + source.ProfileType.CELL.get_profile_shape(geo) # sigmaneo + + source.ProfileType.CELL.get_profile_shape(geo) # bootstrap + + source.ProfileType.FACE.get_profile_shape(geo) # bootstrap face + + (1,) # I_bootstrap + ) + + +@dataclasses.dataclass(kw_only=True) +class BootstrapCurrentSource(source.Source): + """Bootstrap current density source profile. + + Unlike other sources within TORAX, the BootstrapCurrentSource provides more + than just the bootstrap current. It uses neoclassical physics to determine + the following: + - sigmaneo + - bootstrap current (on cell and face grids) + - integrated bootstrap current + """ + + runtime_params: RuntimeParams = dataclasses.field( + default_factory=RuntimeParams, + ) + output_shape_getter: source.SourceOutputShapeFunction = _default_output_shapes + supported_modes: tuple[runtime_params_lib.Mode, ...] = ( + runtime_params_lib.Mode.ZERO, + runtime_params_lib.Mode.MODEL_BASED, + ) + + # Don't include affected_core_profiles in the __init__ arguments. + # Freeze this param. + affected_core_profiles: tuple[source.AffectedCoreProfile, ...] = ( + dataclasses.field( + init=False, + default_factory=lambda: (source.AffectedCoreProfile.PSI,), + ) + ) + + def get_value( + self, + dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, + geo: geometry.Geometry, + core_profiles: state.CoreProfiles | None = None, + temp_ion: cell_variable.CellVariable | None = None, + temp_el: cell_variable.CellVariable | None = None, + ne: cell_variable.CellVariable | None = None, + ni: cell_variable.CellVariable | None = None, + jtot_face: jnp.ndarray | None = None, + psi: cell_variable.CellVariable | None = None, + ) -> source_profiles.BootstrapCurrentProfile: + # Make sure the input mode requested is supported. + self.check_mode(dynamic_source_runtime_params.mode) + # Make sure the input params are the correct type. + assert isinstance(dynamic_source_runtime_params, DynamicRuntimeParams) + # Make sure the appropriate input args have been populated. + if not core_profiles and any([ + not temp_ion, + not temp_el, + ne is None, + ni is None, + jtot_face is None, + not psi, + ]): + raise ValueError( + 'If you cannot provide "core_profiles", then provide all of ' + 'temp_ion, temp_el, ne, ni, jtot_face, and psi.' + ) + # pytype: disable=attribute-error + temp_ion = temp_ion or core_profiles.temp_ion + temp_el = temp_el or core_profiles.temp_el + ne = ne if ne is not None else core_profiles.ne + ni = ni if ni is not None else core_profiles.ni + jtot_face = ( + jtot_face if jtot_face is not None else core_profiles.currents.jtot_face + ) + psi = psi or core_profiles.psi + # pytype: enable=attribute-error + return calc_neoclassical( + dynamic_config_slice=dynamic_config_slice, + dynamic_source_runtime_params=dynamic_source_runtime_params, + geo=geo, + temp_ion=temp_ion, + temp_el=temp_el, + ne=ne, + ni=ni, + jtot_face=jtot_face, + psi=psi, + ) + + def get_source_profile_for_affected_core_profile( + self, + profile: chex.ArrayTree, + affected_core_profile: int, + geo: geometry.Geometry, + ) -> jnp.ndarray: + return jnp.where( + affected_core_profile in self.affected_core_profiles_ints, + profile['j_bootstrap'], + jnp.zeros_like(geo.r), + ) + + @jax_utils.jit def calc_neoclassical( dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_source_runtime_params: DynamicRuntimeParams, geo: geometry.Geometry, temp_ion: cell_variable.CellVariable, temp_el: cell_variable.CellVariable, @@ -48,6 +174,7 @@ def calc_neoclassical( Args: dynamic_config_slice: General configuration parameters. + dynamic_source_runtime_params: Source-specific runtime parameters. geo: Torus geometry. temp_ion: Ion temperature. We don't pass in a full `State` here because this function is used to create the `Currents` in the initial `State`. @@ -68,8 +195,9 @@ def calc_neoclassical( # Formulas from Sauter PoP 1999. Future work can include Redl PoP 2021 # corrections. - true_ne_face = ne.face_value() * dynamic_config_slice.nref - true_ni_face = ni.face_value() * dynamic_config_slice.nref + true_ne_face = ne.face_value() * dynamic_config_slice.numerics.nref + true_ni_face = ni.face_value() * dynamic_config_slice.numerics.nref + Zeff = dynamic_config_slice.plasma_composition.Zeff # # local r/R0 on face grid epsilon = (geo.Rout_face - geo.Rin_face) / (geo.Rout_face + geo.Rin_face) @@ -80,7 +208,7 @@ def calc_neoclassical( ftrap = 1.0 - jnp.sqrt(aa) * (1.0 - epseff) / (1.0 + 2.0 * jnp.sqrt(epseff)) # Spitzer conductivity - NZ = 0.58 + 0.74 / (0.76 + dynamic_config_slice.plasma_composition.Zeff) + NZ = 0.58 + 0.74 / (0.76 + Zeff) # TODO(b/335599537): expand the log to get rid of the exponentiation, # sqrt, etc. lnLame = 31.3 - jnp.log(jnp.sqrt(true_ne_face) / (temp_el.face_value() * 1e3)) @@ -91,13 +219,7 @@ def calc_neoclassical( / ((temp_ion.face_value() * 1e3) ** 1.5) ) - sigsptz = ( - 1.9012e04 - * (temp_el.face_value() * 1e3) ** 1.5 - / dynamic_config_slice.plasma_composition.Zeff - / NZ - / lnLame - ) + sigsptz = 1.9012e04 * (temp_el.face_value() * 1e3) ** 1.5 / Zeff / NZ / lnLame # We don't store q_cell in the evolving core profiles, so we need to # recalculate it. @@ -112,7 +234,7 @@ def calc_neoclassical( * q_face * geo.Rmaj * true_ne_face - * dynamic_config_slice.plasma_composition.Zeff + * Zeff * lnLame / ( ((temp_el.face_value() * 1e3) ** 2) @@ -124,7 +246,7 @@ def calc_neoclassical( * q_face * geo.Rmaj * true_ni_face - * dynamic_config_slice.plasma_composition.Zeff**4 + * Zeff**4 * lnLami / ( ((temp_ion.face_value() * 1e3) ** 2) @@ -136,19 +258,10 @@ def calc_neoclassical( ft33 = ftrap / ( 1.0 + (0.55 - 0.1 * ftrap) * jnp.sqrt(nuestar) - + 0.45 - * (1.0 - ftrap) - * nuestar - / (dynamic_config_slice.plasma_composition.Zeff**1.5) + + 0.45 * (1.0 - ftrap) * nuestar / (Zeff**1.5) ) signeo_face = 1.0 - ft33 * ( - 1.0 - + 0.36 / dynamic_config_slice.plasma_composition.Zeff - - ft33 - * ( - 0.59 / dynamic_config_slice.plasma_composition.Zeff - - 0.23 / dynamic_config_slice.plasma_composition.Zeff * ft33 - ) + 1.0 + 0.36 / Zeff - ft33 * (0.59 / Zeff - 0.23 / Zeff * ft33) ) sigmaneo = sigsptz * signeo_face @@ -156,80 +269,54 @@ def calc_neoclassical( denom = ( 1.0 + (1 - 0.1 * ftrap) * jnp.sqrt(nuestar) - + 0.5 - * (1.0 - ftrap) - * nuestar - / dynamic_config_slice.plasma_composition.Zeff + + 0.5 * (1.0 - ftrap) * nuestar / Zeff ) ft31 = ftrap / denom ft32ee = ftrap / ( 1 + 0.26 * (1 - ftrap) * jnp.sqrt(nuestar) - + 0.18 - * (1 - 0.37 * ftrap) - * nuestar - / jnp.sqrt(dynamic_config_slice.plasma_composition.Zeff) + + 0.18 * (1 - 0.37 * ftrap) * nuestar / jnp.sqrt(Zeff) ) ft32ei = ftrap / ( 1 + (1 + 0.6 * ftrap) * jnp.sqrt(nuestar) - + 0.85 - * (1 - 0.37 * ftrap) - * nuestar - * (1 + dynamic_config_slice.plasma_composition.Zeff) + + 0.85 * (1 - 0.37 * ftrap) * nuestar * (1 + Zeff) ) ft34 = ftrap / ( 1.0 + (1 - 0.1 * ftrap) * jnp.sqrt(nuestar) - + 0.5 - * (1.0 - 0.5 * ftrap) - * nuestar - / dynamic_config_slice.plasma_composition.Zeff + + 0.5 * (1.0 - 0.5 * ftrap) * nuestar / Zeff ) F32ee = ( - (0.05 + 0.62 * dynamic_config_slice.plasma_composition.Zeff) - / ( - dynamic_config_slice.plasma_composition.Zeff - * (1 + 0.44 * dynamic_config_slice.plasma_composition.Zeff) - ) - * (ft32ee - ft32ee**4) + (0.05 + 0.62 * Zeff) / (Zeff * (1 + 0.44 * Zeff)) * (ft32ee - ft32ee**4) + 1 - / (1 + 0.22 * dynamic_config_slice.plasma_composition.Zeff) + / (1 + 0.22 * Zeff) * (ft32ee**2 - ft32ee**4 - 1.2 * (ft32ee**3 - ft32ee**4)) - + 1.2 - / (1 + 0.5 * dynamic_config_slice.plasma_composition.Zeff) - * ft32ee**4 + + 1.2 / (1 + 0.5 * Zeff) * ft32ee**4 ) F32ei = ( - -(0.56 + 1.93 * dynamic_config_slice.plasma_composition.Zeff) - / ( - dynamic_config_slice.plasma_composition.Zeff - * (1 + 0.44 * dynamic_config_slice.plasma_composition.Zeff) - ) - * (ft32ei - ft32ei**4) + -(0.56 + 1.93 * Zeff) / (Zeff * (1 + 0.44 * Zeff)) * (ft32ei - ft32ei**4) + 4.95 - / (1 + 2.48 * dynamic_config_slice.plasma_composition.Zeff) + / (1 + 2.48 * Zeff) * (ft32ei**2 - ft32ei**4 - 0.55 * (ft32ei**3 - ft32ei**4)) - - 1.2 - / (1 + 0.5 * dynamic_config_slice.plasma_composition.Zeff) - * ft32ei**4 + - 1.2 / (1 + 0.5 * Zeff) * ft32ei**4 ) - term_0 = (1 + 1.4 / (dynamic_config_slice.plasma_composition.Zeff + 1)) * ft31 - term_1 = -1.9 / (dynamic_config_slice.plasma_composition.Zeff + 1) * ft31**2 - term_2 = 0.3 / (dynamic_config_slice.plasma_composition.Zeff + 1) * ft31**3 - term_3 = 0.2 / (dynamic_config_slice.plasma_composition.Zeff + 1) * ft31**4 + term_0 = (1 + 1.4 / (Zeff + 1)) * ft31 + term_1 = -1.9 / (Zeff + 1) * ft31**2 + term_2 = 0.3 / (Zeff + 1) * ft31**3 + term_3 = 0.2 / (Zeff + 1) * ft31**4 L31 = term_0 + term_1 + term_2 + term_3 L32 = F32ee + F32ei L34 = ( - (1 + 1.4 / (dynamic_config_slice.plasma_composition.Zeff + 1)) * ft34 - - 1.9 / (dynamic_config_slice.plasma_composition.Zeff + 1) * ft34**2 - + 0.3 / (dynamic_config_slice.plasma_composition.Zeff + 1) * ft34**3 - + 0.2 / (dynamic_config_slice.plasma_composition.Zeff + 1) * ft34**4 + (1 + 1.4 / (Zeff + 1)) * ft34 + - 1.9 / (Zeff + 1) * ft34**2 + + 0.3 / (Zeff + 1) * ft34**3 + + 0.2 / (Zeff + 1) * ft34**4 ) alpha0 = -1.17 * (1 - ftrap) / (1 - 0.22 * ftrap - 0.19 * ftrap**2) @@ -245,7 +332,7 @@ def calc_neoclassical( # calculate bootstrap current prefactor = ( -geo.F_face - * dynamic_config_slice.numerics.bootstrap_mult + * dynamic_source_runtime_params.bootstrap_mult * 2 * jnp.pi / geo.B0 @@ -291,105 +378,3 @@ def calc_neoclassical( j_bootstrap_face=j_bootstrap_face, I_bootstrap=I_bootstrap, ) - - -# TODO(b/314308399): Remove this as a source and create a -# Neoclassical class which will compute sigmaneo, j_bootstrap, etc. - - -def _default_output_shapes( - unused_config, geo, unused_core_profiles -) -> tuple[int, int, int, int]: - # TODO(b/314308399): When refactoring neoclassical "sources", - # revisit what we actually need to return here. - return ( - source.ProfileType.CELL.get_profile_shape(geo) # sigmaneo - + source.ProfileType.CELL.get_profile_shape(geo) # bootstrap - + source.ProfileType.FACE.get_profile_shape(geo) # bootstrap face - + (1,) # I_bootstrap - ) - - -@dataclasses.dataclass(frozen=True, kw_only=True) -class BootstrapCurrentSource(source.Source): - """Bootstrap current density source profile. - - Unlike other sources within TORAX, the BootstrapCurrentSource provides more - than just the bootstrap current. It uses neoclassical physics to determine - the following: - - sigmaneo - - bootstrap current (on cell and face grids) - - integrated bootstrap current - """ - - name: str = 'j_bootstrap' - output_shape_getter: source.SourceOutputShapeFunction = _default_output_shapes - supported_types: tuple[source_config.SourceType, ...] = ( - source_config.SourceType.MODEL_BASED, - ) - - # Don't include affected_core_profiles in the __init__ arguments. - # Freeze this param. - affected_core_profiles: tuple[source.AffectedCoreProfile, ...] = ( - dataclasses.field( - init=False, - default_factory=lambda: (source.AffectedCoreProfile.PSI,), - ) - ) - - def get_value( - self, - dynamic_config_slice: config_slice.DynamicConfigSlice, - geo: geometry.Geometry, - core_profiles: state.CoreProfiles | None = None, - temp_ion: cell_variable.CellVariable | None = None, - temp_el: cell_variable.CellVariable | None = None, - ne: cell_variable.CellVariable | None = None, - ni: cell_variable.CellVariable | None = None, - jtot_face: jnp.ndarray | None = None, - psi: cell_variable.CellVariable | None = None, - ) -> source_profiles.BootstrapCurrentProfile: - if not core_profiles and any([ - not temp_ion, - not temp_el, - ne is None, - ni is None, - jtot_face is None, - not psi, - ]): - raise ValueError( - 'If you cannot provide "core_profiles", then provide all of ' - 'temp_ion, temp_el, ne, ni, jtot_face, and psi.' - ) - # pytype: disable=attribute-error - temp_ion = temp_ion or core_profiles.temp_ion - temp_el = temp_el or core_profiles.temp_el - ne = ne if ne is not None else core_profiles.ne - ni = ni if ni is not None else core_profiles.ni - jtot_face = ( - jtot_face if jtot_face is not None else core_profiles.currents.jtot_face - ) - psi = psi or core_profiles.psi - # pytype: enable=attribute-error - return calc_neoclassical( - dynamic_config_slice, - geo, - temp_ion=temp_ion, - temp_el=temp_el, - ne=ne, - ni=ni, - jtot_face=jtot_face, - psi=psi, - ) - - def get_source_profile_for_affected_core_profile( - self, - profile: chex.ArrayTree, - affected_core_profile: int, - geo: geometry.Geometry, - ) -> jnp.ndarray: - return jnp.where( - affected_core_profile in self.affected_core_profiles_ints, - profile['j_bootstrap'], - jnp.zeros_like(geo.r), - ) diff --git a/torax/sources/current_density_sources.py b/torax/sources/current_density_sources.py index 6d49ab3a..4b387df2 100644 --- a/torax/sources/current_density_sources.py +++ b/torax/sources/current_density_sources.py @@ -24,35 +24,27 @@ # The current sources below don't have any source-specific implementations, so -# their bodies are mostly empty. You can refer to their base class to see the +# their bodies are empty. You can refer to their base class to see the # implementation. We define new classes here to: # a) support any future source-specific implementation. # b) better readability and human-friendly error messages when debugging. -@dataclasses.dataclass(frozen=True, kw_only=True) +@dataclasses.dataclass(kw_only=True) class ECRHCurrentSource(source.SingleProfilePsiSource): """ECRH current density source for the psi equation.""" - name: str = 'ecrh_current_source' - -@dataclasses.dataclass(frozen=True, kw_only=True) +@dataclasses.dataclass(kw_only=True) class ICRHCurrentSource(source.SingleProfilePsiSource): """ICRH current density source for the psi equation.""" - name: str = 'icrh_current_source' - -@dataclasses.dataclass(frozen=True, kw_only=True) +@dataclasses.dataclass(kw_only=True) class LHCurrentSource(source.SingleProfilePsiSource): """LH current density source for the psi equation.""" - name: str = 'lh_current_source' - -@dataclasses.dataclass(frozen=True, kw_only=True) +@dataclasses.dataclass(kw_only=True) class NBICurrentSource(source.SingleProfilePsiSource): """NBI current density source for the psi equation.""" - - name: str = 'nbi_current_source' diff --git a/torax/sources/default_sources.py b/torax/sources/default_sources.py new file mode 100644 index 00000000..744cdce5 --- /dev/null +++ b/torax/sources/default_sources.py @@ -0,0 +1,118 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Default collection of sources and their runtime params. + +These are mainly useful for tests, but also they can serve as starting points +when defining new TORAX configs in general. + +Many TORAX test files and example configurations use these defaults with minor +tweaks added on top. +""" + +from torax.sources import bootstrap_current_source +from torax.sources import electron_density_sources +from torax.sources import external_current_source +from torax.sources import fusion_heat_source +from torax.sources import generic_ion_el_heat_source as ion_el_heat +from torax.sources import qei_source +from torax.sources import runtime_params as runtime_params_lib +from torax.sources import source_models as source_models_lib + + +def get_default_sources() -> source_models_lib.SourceModels: + """Returns a SourceModels containing default sources and runtime parameters. + + This set of sources and params are used by most of the TORAX test + configurations, including ITER-inpired configs, with additional changes to + their runtime configurations on top. + + If you plan to use them, please remember to update the default runtime + parameters as needed. Here is an example of how to do so: + + ```python + default_sources: SourceModels = get_default_sources() + # Turn off bootstrap current. + default_sources.j_bootstrap.runtime_params.mode = runtime_params.Mode.ZERO + # Change the Qei ion-electron heat exchange term. + default_sources.qei_source.runtime_params.Qei_mult = 2.0 + # Turn off fusion power. + default_sources.sources['fusion_heat_source'].runtime_params.mode = ( + runtime_params.Mode.ZERO + ) + ``` + + More examples are located in the test config files under + `torax/tests/test_data`. + """ + source_models = source_models_lib.SourceModels( + sources={ + # Current sources (for psi equation) + 'j_bootstrap': bootstrap_current_source.BootstrapCurrentSource( + runtime_params=bootstrap_current_source.RuntimeParams( + mode=runtime_params_lib.Mode.MODEL_BASED, + ), + ), + 'jext': external_current_source.ExternalCurrentSource( + runtime_params=external_current_source.RuntimeParams( + mode=runtime_params_lib.Mode.FORMULA_BASED, + ), + ), + # Electron density sources/sink (for the ne equation). + 'nbi_particle_source': electron_density_sources.NBIParticleSource( + runtime_params=electron_density_sources.NBIParticleRuntimeParams( + mode=runtime_params_lib.Mode.FORMULA_BASED, + ), + ), + 'gas_puff_source': electron_density_sources.GasPuffSource( + runtime_params=electron_density_sources.GasPuffRuntimeParams( + mode=runtime_params_lib.Mode.FORMULA_BASED, + ), + ), + 'pellet_source': electron_density_sources.PelletSource( + runtime_params=electron_density_sources.PelletRuntimeParams( + mode=runtime_params_lib.Mode.FORMULA_BASED, + ), + ), + # Ion and electron heat sources (for the temp-ion and temp-el eqs). + 'generic_ion_el_heat_source': ( + ion_el_heat.GenericIonElectronHeatSource( + runtime_params=ion_el_heat.RuntimeParams( + mode=runtime_params_lib.Mode.FORMULA_BASED, + ), + ) + ), + 'fusion_heat_source': fusion_heat_source.FusionHeatSource( + runtime_params=runtime_params_lib.RuntimeParams( + mode=runtime_params_lib.Mode.MODEL_BASED, + ), + ), + 'qei_source': qei_source.QeiSource( + runtime_params=qei_source.RuntimeParams( + mode=runtime_params_lib.Mode.MODEL_BASED, + ), + ), + } + ) + # Add OhmicHeatSource after because it requires a pointer to the SourceModels. + source_models.add_source( + source_name='ohmic_heat_source', + source=source_models_lib.OhmicHeatSource( + source_models=source_models, + runtime_params=runtime_params_lib.RuntimeParams( + mode=runtime_params_lib.Mode.MODEL_BASED, + ), + ), + ) + return source_models diff --git a/torax/sources/electron_density_sources.py b/torax/sources/electron_density_sources.py index fb14ad5c..9b00c40d 100644 --- a/torax/sources/electron_density_sources.py +++ b/torax/sources/electron_density_sources.py @@ -18,112 +18,207 @@ import dataclasses +import chex from jax import numpy as jnp +from torax import config_slice from torax import geometry +from torax import state +from torax.runtime_params import config_slice_args from torax.sources import formulas +from torax.sources import runtime_params as runtime_params_lib from torax.sources import source -from torax.sources import source_config -def calc_puff_source( +# pylint: disable=invalid-name + + +@dataclasses.dataclass(kw_only=True) +class GasPuffRuntimeParams(runtime_params_lib.RuntimeParams): + # exponential decay length of gas puff ionization [normalized radial coord] + puff_decay_length: runtime_params_lib.TimeDependentField = 0.05 + # total gas puff particles/s + S_puff_tot: runtime_params_lib.TimeDependentField = 1e22 + + def build_dynamic_params( + self, + t: chex.Numeric, + ) -> DynamicGasPuffRuntimeParams: + return DynamicGasPuffRuntimeParams( + **config_slice_args.get_init_kwargs( + input_config=self, + output_type=DynamicGasPuffRuntimeParams, + t=t, + ) + ) + + +@chex.dataclass(frozen=True) +class DynamicGasPuffRuntimeParams(runtime_params_lib.DynamicRuntimeParams): + puff_decay_length: float + S_puff_tot: float + + +# Default formula: exponential with nref normalization. +def _calc_puff_source( + dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, - puff_decay_length: float, - S_puff_tot: float, # pylint: disable=invalid-name - nref: float, + unused_state: state.CoreProfiles | None = None, ) -> jnp.ndarray: """Calculates external source term for n from puffs.""" + assert isinstance(dynamic_source_runtime_params, DynamicGasPuffRuntimeParams) return formulas.exponential_profile( c1=1.0, - c2=puff_decay_length, - total=S_puff_tot / nref, + c2=dynamic_source_runtime_params.puff_decay_length, + total=( + dynamic_source_runtime_params.S_puff_tot + / dynamic_config_slice.numerics.nref + ), use_normalized_r=True, geo=geo, ) -def calc_nbi_source( - geo: geometry.Geometry, - nbi_deposition_location: float, - nbi_particle_width: float, - S_nbi_tot: float, # pylint: disable=invalid-name - nref: float, -) -> jnp.ndarray: - """Calculates external source term for n from SBI.""" - return formulas.gaussian_profile( - c1=nbi_deposition_location, - c2=nbi_particle_width, - total=S_nbi_tot / nref, - use_normalized_r=True, - geo=geo, +@dataclasses.dataclass(kw_only=True) +class GasPuffSource(source.SingleProfileNeSource): + """Gas puff source for the ne equation.""" + + runtime_params: GasPuffRuntimeParams = dataclasses.field( + default_factory=GasPuffRuntimeParams ) + formula: source.SourceProfileFunction = _calc_puff_source + + +@dataclasses.dataclass(kw_only=True) +class NBIParticleRuntimeParams(runtime_params_lib.RuntimeParams): + """Runtime parameters for NBI particle source.""" + + # NBI particle source Gaussian width in normalized radial coord + nbi_particle_width: runtime_params_lib.TimeDependentField = 0.25 + # NBI particle source Gaussian central location in normalized radial coord + nbi_deposition_location: runtime_params_lib.TimeDependentField = 0.0 + # NBI total particle source + S_nbi_tot: runtime_params_lib.TimeDependentField = 1e22 + + def build_dynamic_params( + self, + t: chex.Numeric, + ) -> DynamicNBIParticleRuntimeParams: + return DynamicNBIParticleRuntimeParams( + **config_slice_args.get_init_kwargs( + input_config=self, + output_type=DynamicNBIParticleRuntimeParams, + t=t, + ) + ) + + +@chex.dataclass(frozen=True) +class DynamicNBIParticleRuntimeParams(runtime_params_lib.DynamicRuntimeParams): + nbi_particle_width: float + nbi_deposition_location: float + S_nbi_tot: float -def calc_pellet_source( + +def _calc_nbi_source( + dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, - pellet_deposition_location: float, - pellet_width: float, - S_pellet_tot: float, # pylint: disable=invalid-name - nref: float, + unused_state: state.CoreProfiles | None = None, ) -> jnp.ndarray: - """Calculates external source term for n from pellets.""" + """Calculates external source term for n from SBI.""" + assert isinstance( + dynamic_source_runtime_params, DynamicNBIParticleRuntimeParams + ) return formulas.gaussian_profile( - c1=pellet_deposition_location, - c2=pellet_width, - total=S_pellet_tot / nref, + c1=dynamic_source_runtime_params.nbi_deposition_location, + c2=dynamic_source_runtime_params.nbi_particle_width, + total=( + dynamic_source_runtime_params.S_nbi_tot + / dynamic_config_slice.numerics.nref + ), use_normalized_r=True, geo=geo, ) -@dataclasses.dataclass(frozen=True, kw_only=True) -class GasPuffSource(source.SingleProfileNeSource): - """Gas puff source for the ne equation.""" - - name: str = 'gas_puff_source' - - formula: source_config.SourceProfileFunction = ( - lambda dcs, geo, unused_state: calc_puff_source( - geo, - puff_decay_length=dcs.puff_decay_length, - S_puff_tot=dcs.S_puff_tot, - nref=dcs.nref, - ) - ) - - -@dataclasses.dataclass(frozen=True, kw_only=True) +@dataclasses.dataclass(kw_only=True) class NBIParticleSource(source.SingleProfileNeSource): """Neutral-beam injection source for the ne equation.""" - name: str = 'nbi_particle_source' + runtime_params: NBIParticleRuntimeParams = dataclasses.field( + default_factory=NBIParticleRuntimeParams + ) - formula: source_config.SourceProfileFunction = ( - lambda dcs, geo, unused_state: calc_nbi_source( - geo, - nbi_deposition_location=dcs.nbi_deposition_location, - nbi_particle_width=dcs.nbi_particle_width, - S_nbi_tot=dcs.S_nbi_tot, - nref=dcs.nref, - ) + formula: source.SourceProfileFunction = _calc_nbi_source + + +@dataclasses.dataclass(kw_only=True) +class PelletRuntimeParams(runtime_params_lib.RuntimeParams): + """Runtime parameters for PelletSource.""" + + # Gaussian width of pellet deposition [normalized radial coord], + # (continuous pellet model) + pellet_width: runtime_params_lib.TimeDependentField = 0.1 + # Pellet source Gaussian central location [normalized radial coord] + # (continuous pellet model) + pellet_deposition_location: runtime_params_lib.TimeDependentField = 0.85 + # total pellet particles/s (continuous pellet model) + S_pellet_tot: runtime_params_lib.TimeDependentField = 2e22 + + def build_dynamic_params( + self, + t: chex.Numeric, + ) -> DynamicPelletRuntimeParams: + return DynamicPelletRuntimeParams( + **config_slice_args.get_init_kwargs( + input_config=self, + output_type=DynamicPelletRuntimeParams, + t=t, + ) + ) + + +@chex.dataclass(frozen=True) +class DynamicPelletRuntimeParams(runtime_params_lib.DynamicRuntimeParams): + pellet_width: float + pellet_deposition_location: float + S_pellet_tot: float + + +def _calc_pellet_source( + dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, + geo: geometry.Geometry, + unused_state: state.CoreProfiles | None = None, +) -> jnp.ndarray: + """Calculates external source term for n from pellets.""" + assert isinstance(dynamic_source_runtime_params, DynamicPelletRuntimeParams) + return formulas.gaussian_profile( + c1=dynamic_source_runtime_params.pellet_deposition_location, + c2=dynamic_source_runtime_params.pellet_width, + total=( + dynamic_source_runtime_params.S_pellet_tot + / dynamic_config_slice.numerics.nref + ), + use_normalized_r=True, + geo=geo, ) -@dataclasses.dataclass(frozen=True, kw_only=True) +@dataclasses.dataclass(kw_only=True) class PelletSource(source.SingleProfileNeSource): """Pellet source for the ne equation.""" - name: str = 'pellet_source' - - formula: source_config.SourceProfileFunction = ( - lambda dcs, geo, unused_state: calc_pellet_source( - geo, - pellet_deposition_location=dcs.pellet_deposition_location, - pellet_width=dcs.pellet_width, - S_pellet_tot=dcs.S_pellet_tot, - nref=dcs.nref, - ) + runtime_params: PelletRuntimeParams = dataclasses.field( + default_factory=PelletRuntimeParams ) + formula: source.SourceProfileFunction = _calc_pellet_source + + +# pylint: enable=invalid-name # The sources below don't have any source-specific implementations, so their # bodies are empty. You can refer to their base class to see the implementation. @@ -132,8 +227,6 @@ class PelletSource(source.SingleProfileNeSource): # b) better readability and human-friendly error messages when debugging. -@dataclasses.dataclass(frozen=True, kw_only=True) +@dataclasses.dataclass(kw_only=True) class RecombinationDensitySink(source.SingleProfileNeSource): """Recombination sink for the electron density equation.""" - - name: str = 'recombination_density_sink' diff --git a/torax/sources/external_current_source.py b/torax/sources/external_current_source.py index d3036787..d5ed5a83 100644 --- a/torax/sources/external_current_source.py +++ b/torax/sources/external_current_source.py @@ -23,95 +23,171 @@ from jax.scipy import integrate from torax import config_slice from torax import geometry +from torax import jax_utils from torax import state +from torax.runtime_params import config_slice_args +from torax.sources import runtime_params as runtime_params_lib from torax.sources import source -from torax.sources import source_config -_trapz = integrate.trapezoid +# pylint: disable=invalid-name -def calculate_Iext( # pylint: disable=invalid-name - dynamic_config_slice: config_slice.DynamicConfigSlice, -) -> float: - """Calculates the total value of external current.""" - if dynamic_config_slice.use_absolute_jext: - return dynamic_config_slice.Iext - else: - return ( - dynamic_config_slice.profile_conditions.Ip * dynamic_config_slice.fext +@dataclasses.dataclass(kw_only=True) +class RuntimeParams(runtime_params_lib.RuntimeParams): + """Runtime parameters for the external current source.""" + + # total "external" current in MA. Used if use_absolute_jext=True. + Iext: runtime_params_lib.TimeDependentField = 3.0 + # total "external" current fraction. Used if use_absolute_jext=False. + fext: runtime_params_lib.TimeDependentField = 0.2 + # width of "external" Gaussian current profile + wext: runtime_params_lib.TimeDependentField = 0.05 + # normalized radius of "external" Gaussian current profile + rext: runtime_params_lib.TimeDependentField = 0.4 + + # Toggles if external current is provided absolutely or as a fraction of Ip. + use_absolute_jext: bool = False + + def build_dynamic_params( + self, + t: chex.Numeric, + ) -> DynamicRuntimeParams: + return DynamicRuntimeParams( + **config_slice_args.get_init_kwargs( + input_config=self, + output_type=DynamicRuntimeParams, + t=t, + ) ) -def calculate_jext_face( +@chex.dataclass(frozen=True) +class DynamicRuntimeParams(runtime_params_lib.DynamicRuntimeParams): + """Dynamic runtime parameters for the external current source.""" + + Iext: float + fext: float + wext: float + rext: float + use_absolute_jext: bool + + def sanity_check(self): + """Checks that all parameters are valid.""" + # Using object.__setattr__ to get around the fact this is a frozen dataclass + object.__setattr__( + self, 'wext', jax_utils.error_if_negative(self.wext, 'wext') + ) + + def __post_init__(self): + self.sanity_check() + + +_trapz = integrate.trapezoid + + +def _calculate_jext_face( dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, + unused_state: state.CoreProfiles | None = None, ) -> jnp.ndarray: """Calculates the external current density profiles. Args: dynamic_config_slice: Parameter configuration at present timestep. + dynamic_source_runtime_params: Source-specific parameters at the present + timestep. geo: Tokamak geometry. + unused_state: State argument not used in this function but is present to + adhere to the source API. Returns: External current density profile along the face grid. """ - # pylint: disable=invalid-name - Iext = calculate_Iext(dynamic_config_slice) + assert isinstance(dynamic_source_runtime_params, DynamicRuntimeParams) + Iext = _calculate_Iext( + dynamic_config_slice, + dynamic_source_runtime_params, + ) # form of external current on face grid jextform_face = jnp.exp( - -((geo.r_face_norm - dynamic_config_slice.rext) ** 2) - / (2 * dynamic_config_slice.wext**2) + -((geo.r_face_norm - dynamic_source_runtime_params.rext) ** 2) + / (2 * dynamic_source_runtime_params.wext**2) ) Cext = Iext * 1e6 / _trapz(jextform_face * geo.spr_face, geo.r_face) jext_face = Cext * jextform_face # external current profile - # pylint: enable=invalid-name return jext_face -def calculate_jext_hires( +def _calculate_jext_hires( dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, + unused_state: state.CoreProfiles | None = None, ) -> jnp.ndarray: """Calculates the external current density profile along the hires grid. Args: dynamic_config_slice: Parameter configuration at present timestep. + dynamic_source_runtime_params: Source-specific parameters at the present + timestep. geo: Tokamak geometry. + unused_state: State argument not used in this function but is present to + adhere to the source API. Returns: External current density profile along the hires cell grid. """ - # pylint: disable=invalid-name - Iext = calculate_Iext(dynamic_config_slice) + assert isinstance(dynamic_source_runtime_params, DynamicRuntimeParams) + Iext = _calculate_Iext( + dynamic_config_slice, + dynamic_source_runtime_params, + ) # calculate "External" current profile (e.g. ECCD) # form of external current on cell grid jextform_hires = jnp.exp( - -((geo.r_hires_norm - dynamic_config_slice.rext) ** 2) - / (2 * dynamic_config_slice.wext**2) + -((geo.r_hires_norm - dynamic_source_runtime_params.rext) ** 2) + / (2 * dynamic_source_runtime_params.wext**2) ) Cext_hires = Iext * 1e6 / _trapz(jextform_hires * geo.spr_hires, geo.r_hires) # External current profile on cell grid jext_hires = Cext_hires * jextform_hires - # pylint: enable=invalid-name return jext_hires -@dataclasses.dataclass(frozen=True, kw_only=True) +def _calculate_Iext( + dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_source_runtime_params: DynamicRuntimeParams, +) -> chex.Numeric: + """Calculates the total value of external current.""" + return jnp.where( + dynamic_source_runtime_params.use_absolute_jext, + dynamic_source_runtime_params.Iext, + ( + dynamic_config_slice.profile_conditions.Ip + * dynamic_source_runtime_params.fext + ), + ) + + +@dataclasses.dataclass(kw_only=True) class ExternalCurrentSource(source.Source): """External current density source profile.""" - name: str = 'jext' + runtime_params: RuntimeParams = dataclasses.field( + default_factory=RuntimeParams + ) - supported_types: tuple[source_config.SourceType, ...] = ( - source_config.SourceType.FORMULA_BASED, - source_config.SourceType.ZERO, + supported_types: tuple[runtime_params_lib.Mode, ...] = ( + runtime_params_lib.Mode.ZERO, + runtime_params_lib.Mode.FORMULA_BASED, ) # Don't include affected_core_profiles in the __init__ arguments. - # Freeze this param. + # "Freeze" this param. affected_core_profiles: tuple[source.AffectedCoreProfile, ...] = ( dataclasses.field( init=False, @@ -119,51 +195,50 @@ class ExternalCurrentSource(source.Source): ) ) + formula: source.SourceProfileFunction = _calculate_jext_face + hires_formula: source.SourceProfileFunction = _calculate_jext_hires + def get_value( self, - source_type: int, dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, core_profiles: state.CoreProfiles | None = None, ) -> tuple[jnp.ndarray, jnp.ndarray]: """Return the external current density profile along face and cell grids.""" - source_type = self.check_source_type(source_type) + assert isinstance(dynamic_source_runtime_params, DynamicRuntimeParams) + self.check_mode(dynamic_source_runtime_params.mode) profile = source.get_source_profiles( - source_type=source_type, dynamic_config_slice=dynamic_config_slice, + dynamic_source_runtime_params=dynamic_source_runtime_params, geo=geo, core_profiles=core_profiles, # There is no model implementation. model_func=( - lambda _0, _1, _2: source.ProfileType.FACE.get_zero_profile(geo) - ), - formula=lambda dcs, g, _: calculate_jext_face( - dcs, - g, + lambda _0, _1, _2, _3: source.ProfileType.FACE.get_zero_profile(geo) ), + formula=self.formula, output_shape=source.ProfileType.FACE.get_profile_shape(geo), ) return profile, geometry.face_to_cell(profile) def jext_hires( self, - source_type: int, dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, ) -> jnp.ndarray: """Return the external current density profile along the hires cell grid.""" - source_type = self.check_source_type(source_type) + assert isinstance(dynamic_source_runtime_params, DynamicRuntimeParams) + self.check_mode(dynamic_source_runtime_params.mode) return source.get_source_profiles( - source_type=source_type, dynamic_config_slice=dynamic_config_slice, + dynamic_source_runtime_params=dynamic_source_runtime_params, geo=geo, core_profiles=None, # There is no model for this source. - model_func=(lambda _0, _1, _2: jnp.zeros_like(geo.r_hires_norm)), - formula=lambda dcs, g, _: calculate_jext_hires( - dcs, - g, - ), + model_func=(lambda _0, _1, _2, _3: jnp.zeros_like(geo.r_hires_norm)), + formula=self.hires_formula, output_shape=geo.r_hires_norm.shape, ) diff --git a/torax/sources/formula_config.py b/torax/sources/formula_config.py index c4860497..140a552c 100644 --- a/torax/sources/formula_config.py +++ b/torax/sources/formula_config.py @@ -18,7 +18,9 @@ import dataclasses +import chex from torax import interpolated_param +from torax.runtime_params import config_slice_args # Type-alias for clarity. @@ -26,7 +28,30 @@ @dataclasses.dataclass -class Exponential: +class FormulaConfig: + """Configures a formula. + + This config can include time-varying parameters which are interpolated as + the simulation runs. For new formula implementations, extend this class and + add the formula-specific parameters required. + + The Gaussian and Exponential config classes, and their implementations in + formulas.py, are useful, simple examples for how to do this. + """ + + def build_dynamic_params(self, t: chex.Numeric) -> DynamicFormula: + """Interpolates this config to a dynamic config for time t.""" + del t # Unused because there are no params in the base class. + return DynamicFormula() + + +@chex.dataclass(frozen=True) +class DynamicFormula: + """Base class for dynamic configs.""" + + +@dataclasses.dataclass +class Exponential(FormulaConfig): """Configures an exponential formula. See formulas.Exponential for more information on how this config is used. @@ -40,6 +65,24 @@ class Exponential: # If True, uses r_norm when calculating the source profiles. use_normalized_r: bool = False + def build_dynamic_params(self, t: chex.Numeric) -> DynamicExponential: + return DynamicExponential( + **config_slice_args.get_init_kwargs( + input_config=self, + output_type=DynamicExponential, + t=t, + ) + ) + + +@chex.dataclass(frozen=True) +class DynamicExponential(DynamicFormula): + + total: float + c1: float + c2: float + use_normalized_r: bool + @dataclasses.dataclass class Gaussian: @@ -56,13 +99,21 @@ class Gaussian: # If True, uses r_norm when calculating the source profiles. use_normalized_r: bool = False + def build_dynamic_params(self, t: chex.Numeric) -> DynamicGaussian: + return DynamicGaussian( + **config_slice_args.get_init_kwargs( + input_config=self, + output_type=DynamicGaussian, + t=t, + ) + ) -@dataclasses.dataclass -class FormulaConfig: - """Contains all formula configs.""" - exponential: Exponential = dataclasses.field(default_factory=Exponential) - gaussian: Gaussian = dataclasses.field(default_factory=Gaussian) - custom_params: dict[str, TimeDependentField] = dataclasses.field( - default_factory=lambda: {}, - ) +@chex.dataclass(frozen=True) +class DynamicGaussian(DynamicFormula): + + total: float + c1: float + c2: float + # If True, uses r_norm when calculating the source profiles. + use_normalized_r: bool diff --git a/torax/sources/formulas.py b/torax/sources/formulas.py index 2f89ada5..cc4a0577 100644 --- a/torax/sources/formulas.py +++ b/torax/sources/formulas.py @@ -21,6 +21,8 @@ from torax import geometry from torax import jax_utils from torax import state +from torax.sources import formula_config +from torax.sources import runtime_params # Many variables throughout this function are capitalized based on physics @@ -108,27 +110,17 @@ def gaussian_profile( @dataclasses.dataclass(frozen=True) class Exponential: - """Callable class providing an exponential profile. - - It uses the runtime config config_slice.DynamicConfigSlice to get the correct - parameters and returns an exponential profile on the cell grid. - - Attributes: - source_name: Name of the source this formula is attached to. This helps grab - the relevant SourceConfig from the DynamicConfigSlice. - """ - - source_name: str + """Callable class providing an exponential profile.""" def __call__( self, dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_source_runtime_params: runtime_params.DynamicRuntimeParams, geo: geometry.Geometry, unused_state: state.CoreProfiles | None, ) -> jnp.ndarray: - exp_config = dynamic_config_slice.sources[ - self.source_name - ].formula.exponential + exp_config = dynamic_source_runtime_params.formula + assert isinstance(exp_config, formula_config.DynamicExponential) return exponential_profile( c1=exp_config.c1, c2=exp_config.c2, @@ -140,27 +132,17 @@ def __call__( @dataclasses.dataclass(frozen=True) class Gaussian: - """Callable class providing a gaussian profile. - - It uses the runtime config config_slice.DynamicConfigSlice to get the correct - parameters and returns a gaussian profile on the cell grid. - - Attributes: - source_name: Name of the source this formula is attached to. This helps grab - the relevant SourceConfig from the DynamicConfigSlice. - """ - - source_name: str + """Callable class providing a gaussian profile.""" def __call__( self, dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_source_runtime_params: runtime_params.DynamicRuntimeParams, geo: geometry.Geometry, unused_state: state.CoreProfiles | None, ) -> jnp.ndarray: - gaussian_config = dynamic_config_slice.sources[ - self.source_name - ].formula.gaussian + gaussian_config = dynamic_source_runtime_params.formula + assert isinstance(gaussian_config, formula_config.DynamicGaussian) return gaussian_profile( c1=gaussian_config.c1, c2=gaussian_config.c2, diff --git a/torax/sources/fusion_heat_source.py b/torax/sources/fusion_heat_source.py index 5b1a74a6..717ef2bc 100644 --- a/torax/sources/fusion_heat_source.py +++ b/torax/sources/fusion_heat_source.py @@ -24,8 +24,8 @@ from torax import constants from torax import geometry from torax import state +from torax.sources import runtime_params as runtime_params_lib from torax.sources import source -from torax.sources import source_config def calc_fusion( @@ -122,24 +122,26 @@ def calc_fusion( def fusion_heat_model_func( dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, core_profiles: state.CoreProfiles, ) -> jnp.ndarray: + del dynamic_source_runtime_params # Unused. # pylint: disable=invalid-name - _, Pfus_i, Pfus_e = calc_fusion(geo, core_profiles, dynamic_config_slice.nref) + _, Pfus_i, Pfus_e = calc_fusion( + geo, core_profiles, dynamic_config_slice.numerics.nref + ) return jnp.stack((Pfus_i, Pfus_e)) # pylint: enable=invalid-name -@dataclasses.dataclass(frozen=True, kw_only=True) +@dataclasses.dataclass(kw_only=True) class FusionHeatSource(source.IonElectronSource): """Fusion heat source for both ion and electron heat.""" - name: str = 'fusion_heat_source' - - supported_types: tuple[source_config.SourceType, ...] = ( - source_config.SourceType.ZERO, - source_config.SourceType.MODEL_BASED, + supported_modes: tuple[runtime_params_lib.Mode, ...] = ( + runtime_params_lib.Mode.ZERO, + runtime_params_lib.Mode.MODEL_BASED, ) - model_func: source_config.SourceProfileFunction = fusion_heat_model_func + model_func: source.SourceProfileFunction = fusion_heat_model_func diff --git a/torax/sources/generic_ion_el_heat_source.py b/torax/sources/generic_ion_el_heat_source.py index fda2d688..48799d6b 100644 --- a/torax/sources/generic_ion_el_heat_source.py +++ b/torax/sources/generic_ion_el_heat_source.py @@ -18,13 +18,52 @@ import dataclasses +import chex import jax from jax import numpy as jnp from torax import config_slice from torax import geometry from torax import state +from torax.runtime_params import config_slice_args +from torax.sources import runtime_params as runtime_params_lib from torax.sources import source -from torax.sources import source_config + + +# Many variables throughout this function are capitalized based on physics +# notational conventions rather than on Google Python style +# pylint: disable=invalid-name + + +@dataclasses.dataclass(kw_only=True) +class RuntimeParams(runtime_params_lib.RuntimeParams): + """Runtime parameters for the generic heat source.""" + + # external heat source parameters + # Gaussian width in normalized radial coordinate + w: runtime_params_lib.TimeDependentField = 0.25 + # Source Gaussian central location (in normalized r) + rsource: runtime_params_lib.TimeDependentField = 0.0 + # total heating + Ptot: runtime_params_lib.TimeDependentField = 120e6 + # electron heating fraction + el_heat_fraction: runtime_params_lib.TimeDependentField = 0.66666 + + def build_dynamic_params(self, t: chex.Numeric) -> DynamicRuntimeParams: + return DynamicRuntimeParams( + **config_slice_args.get_init_kwargs( + input_config=self, + output_type=DynamicRuntimeParams, + t=t, + ) + ) + + +@chex.dataclass(frozen=True) +class DynamicRuntimeParams(runtime_params_lib.DynamicRuntimeParams): + w: float + rsource: float + Ptot: float + el_heat_fraction: float def calc_generic_heat_source( @@ -50,10 +89,6 @@ def calc_generic_heat_source( source_el: source term for electrons. """ - # Many variables throughout this function are capitalized based on physics - # notational conventions rather than on Google Python style - # pylint: disable=invalid-name - # calculate heat profile (face grid) Q = jnp.exp(-((geo.r_norm - rsource) ** 2) / (2 * w**2)) Q_face = jnp.exp(-((geo.r_face_norm - rsource) ** 2) / (2 * w**2)) @@ -68,25 +103,32 @@ def calc_generic_heat_source( def _default_formula( dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, core_profiles: state.CoreProfiles, ) -> jnp.ndarray: """Returns the default formula-based ion/electron heat source profile.""" - del core_profiles # Unused. + del dynamic_config_slice, core_profiles # Unused. + assert isinstance(dynamic_source_runtime_params, DynamicRuntimeParams) ion, el = calc_generic_heat_source( geo, - dynamic_config_slice.rsource, - dynamic_config_slice.w, - dynamic_config_slice.Ptot, - dynamic_config_slice.el_heat_fraction, + dynamic_source_runtime_params.rsource, + dynamic_source_runtime_params.w, + dynamic_source_runtime_params.Ptot, + dynamic_source_runtime_params.el_heat_fraction, ) return jnp.stack([ion, el]) -@dataclasses.dataclass(frozen=True, kw_only=True) +# pylint: enable=invalid-name + + +@dataclasses.dataclass(kw_only=True) class GenericIonElectronHeatSource(source.IonElectronSource): """Generic heat source for both ion and electron heat.""" - name: str = 'generic_ion_el_heat_source' + runtime_params: RuntimeParams = dataclasses.field( + default_factory=RuntimeParams + ) - formula: source_config.SourceProfileFunction = _default_formula + formula: source.SourceProfileFunction = _default_formula diff --git a/torax/sources/ion_el_heat_sources.py b/torax/sources/ion_el_heat_sources.py index a0cfa402..13206863 100644 --- a/torax/sources/ion_el_heat_sources.py +++ b/torax/sources/ion_el_heat_sources.py @@ -30,71 +30,51 @@ # b) better readability and human-friendly error messages when debugging. -@dataclasses.dataclass(frozen=True, kw_only=True) +@dataclasses.dataclass(kw_only=True) class BremsstrahlungHeatSink(source.SingleProfileTempElSource): """Bremsstrahlung loss sink for the electron temp equation.""" - name: str = 'bremsstrahlung_heat_sink' - -@dataclasses.dataclass(frozen=True, kw_only=True) +@dataclasses.dataclass(kw_only=True) class ChargeExchangeHeatSink(source.SingleProfileTempIonSource): """Charge exchange loss term for the ion temp equation.""" - name: str = 'charge_exchange_heat_sink' - -@dataclasses.dataclass(frozen=True, kw_only=True) +@dataclasses.dataclass(kw_only=True) class CyclotronRadiationHeatSink(source.SingleProfileTempElSource): """Cyclotron radiation loss term for the electron temp equation.""" - name: str = 'cyclotron_radiation_heat_sink' - -@dataclasses.dataclass(frozen=True, kw_only=True) +@dataclasses.dataclass(kw_only=True) class ECRHHeatSource(source.SingleProfileTempElSource): """ECRH heat source for the electron temp equation.""" - name: str = 'ecrh_heat_source' - -@dataclasses.dataclass(frozen=True, kw_only=True) +@dataclasses.dataclass(kw_only=True) class ICRHHeatSource(source.SingleProfileTempIonSource): """ICRH heat source for the ion temp equation.""" - name: str = 'icrh_heat_source' - -@dataclasses.dataclass(frozen=True, kw_only=True) +@dataclasses.dataclass(kw_only=True) class LHHeatSource(source.SingleProfileTempElSource): """LH heat source for the electron temp equation.""" - name: str = 'lh_heat_source' - -@dataclasses.dataclass(frozen=True, kw_only=True) +@dataclasses.dataclass(kw_only=True) class LineRadiationHeatSink(source.SingleProfileTempElSource): """Line radiation loss sink for the electron temp equation.""" - name: str = 'line_radiation_heat_sink' - -@dataclasses.dataclass(frozen=True, kw_only=True) +@dataclasses.dataclass(kw_only=True) class NBIElectronHeatSource(source.SingleProfileTempElSource): """NBI heat source for the electron temp equation.""" - name: str = 'nbi_electron_heat_source' - -@dataclasses.dataclass(frozen=True, kw_only=True) +@dataclasses.dataclass(kw_only=True) class NBIIonHeatSource(source.SingleProfileTempIonSource): """NBI heat source for the ion temp equation.""" - name: str = 'nbi_ion_heat_source' - -@dataclasses.dataclass(frozen=True, kw_only=True) +@dataclasses.dataclass(kw_only=True) class RecombinationHeatSink(source.SingleProfileTempElSource): """Recombination loss sink for the electron temp equation.""" - - name: str = 'recombination_heat_sink' diff --git a/torax/sources/qei_source.py b/torax/sources/qei_source.py index c0fdabd3..3af886af 100644 --- a/torax/sources/qei_source.py +++ b/torax/sources/qei_source.py @@ -25,12 +25,36 @@ from torax import geometry from torax import physics from torax import state +from torax.runtime_params import config_slice_args +from torax.sources import runtime_params as runtime_params_lib from torax.sources import source -from torax.sources import source_config from torax.sources import source_profiles -@dataclasses.dataclass(frozen=True, kw_only=True) +# pylint: disable=invalid-name + + +@dataclasses.dataclass(kw_only=True) +class RuntimeParams(runtime_params_lib.RuntimeParams): + # multiplier for ion-electron heat exchange term for sensitivity testing + Qei_mult: float = 1.0 + + def build_dynamic_params(self, t: chex.Numeric) -> DynamicRuntimeParams: + return DynamicRuntimeParams( + **config_slice_args.get_init_kwargs( + input_config=self, + output_type=DynamicRuntimeParams, + t=t, + ) + ) + + +@chex.dataclass(frozen=True) +class DynamicRuntimeParams(runtime_params_lib.DynamicRuntimeParams): + Qei_mult: float + + +@dataclasses.dataclass(kw_only=True) class QeiSource(source.Source): """Collisional ion-electron heat source. @@ -38,11 +62,13 @@ class QeiSource(source.Source): explicit terms in our solver. See sim.py for how this is used. """ - name: str = 'qei_source' + runtime_params: RuntimeParams = dataclasses.field( + default_factory=RuntimeParams + ) - supported_types: tuple[source_config.SourceType, ...] = ( - source_config.SourceType.MODEL_BASED, - source_config.SourceType.ZERO, + supported_modes: tuple[runtime_params_lib.Mode, ...] = ( + runtime_params_lib.Mode.MODEL_BASED, + runtime_params_lib.Mode.ZERO, ) # Don't include affected_core_profiles in the __init__ arguments. @@ -60,18 +86,23 @@ class QeiSource(source.Source): def get_qei( self, - source_type: int, static_config_slice: config_slice.StaticConfigSlice, dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, core_profiles: state.CoreProfiles, ) -> source_profiles.QeiInfo: """Computes the value of the source.""" - source_type = self.check_source_type(source_type) + self.check_mode(dynamic_source_runtime_params.mode) return jax.lax.cond( - source_type == source_config.SourceType.MODEL_BASED.value, + dynamic_source_runtime_params.mode + == runtime_params_lib.Mode.MODEL_BASED.value, lambda: _model_based_qei( - static_config_slice, dynamic_config_slice, geo, core_profiles + static_config_slice, + dynamic_config_slice, + dynamic_source_runtime_params, + geo, + core_profiles, ), lambda: source_profiles.QeiInfo.zeros(geo), ) @@ -97,16 +128,18 @@ def get_source_profile_for_affected_core_profile( def _model_based_qei( static_config_slice: config_slice.StaticConfigSlice, dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, core_profiles: state.CoreProfiles, ) -> source_profiles.QeiInfo: """Computes Qei via the coll_exchange model.""" + assert isinstance(dynamic_source_runtime_params, DynamicRuntimeParams) zeros = jnp.zeros_like(geo.r_norm) qei_coef = physics.coll_exchange( core_profiles=core_profiles, - nref=dynamic_config_slice.nref, + nref=dynamic_config_slice.numerics.nref, Ai=dynamic_config_slice.plasma_composition.Ai, - Qei_mult=dynamic_config_slice.numerics.Qei_mult, + Qei_mult=dynamic_source_runtime_params.Qei_mult, ) implicit_ii = -qei_coef implicit_ee = -qei_coef @@ -136,3 +169,6 @@ def _model_based_qei( implicit_ie=implicit_ie, implicit_ei=implicit_ei, ) + + +# pylint: enable=invalid-name diff --git a/torax/sources/source_config.py b/torax/sources/runtime_params.py similarity index 53% rename from torax/sources/source_config.py rename to torax/sources/runtime_params.py index fb859948..a2ee75c2 100644 --- a/torax/sources/source_config.py +++ b/torax/sources/runtime_params.py @@ -14,32 +14,23 @@ """Configuration for all the sources/sinks modelled in Torax.""" -from collections.abc import Callable, Mapping +from __future__ import annotations + import dataclasses import enum -from typing import Any import chex +from torax import interpolated_param +from torax.runtime_params import config_slice_args from torax.sources import formula_config -# Sources implement these functions to be able to provide source profiles. The -# SourceConfig also gives a hook for users to provide a custom function. -# Using `Any` instead of the actual argument types below to avoid circular -# dependencies. -SourceProfileFunction = Callable[ - [ # Arguments - Any, # config.Config - Any, # geometry.Geometry - Any | None, # state.CoreProfiles - ], - # Returns a JAX array, tuple of arrays, or mapping of arrays. - chex.ArrayTree, -] +# Type-alias for clarity. +TimeDependentField = interpolated_param.InterpParamOrInterpParamInput @enum.unique -class SourceType(enum.Enum): +class Mode(enum.Enum): """Defines how to compute the source terms for this source/sink.""" # Source is set to zero always. This is an explicit source by definition. @@ -56,12 +47,12 @@ class SourceType(enum.Enum): @dataclasses.dataclass -class SourceConfig: +class RuntimeParams: """Configures a single source/sink term. This is a RUNTIME config, meaning its values can change from run to run without trigerring a recompile. This config defines the runtime config for the - entire simulation run. The DynamicSourceConfigSlice, which is derived from + entire simulation run. The DynamicRuntimeParams, which is derived from this class, only contains information for a single time step. Any compile-time configurations for the Sources should go into the Source @@ -69,7 +60,7 @@ class SourceConfig: """ # Defines how the source values are computed (from a model, from a file, etc.) - source_type: SourceType = SourceType.ZERO + mode: Mode = Mode.ZERO # Defines whether this is an explicit or implicit source. # Explicit sources are calculated based on the simulation state at the @@ -83,41 +74,33 @@ class SourceConfig: # running the simulation. is_explicit: bool = False + # Parameters used only when the source is using a prescribed formula to + # compute its profile. formula: formula_config.FormulaConfig = dataclasses.field( default_factory=formula_config.FormulaConfig ) - -# Define helper functions to use as factories in configs below. -# pylint: disable=g-long-lambda -get_model_based_source_config = lambda: SourceConfig( - source_type=SourceType.MODEL_BASED, -) -get_formula_based_source_config = lambda: SourceConfig( - source_type=SourceType.FORMULA_BASED, -) -# pylint: enable=g-long-lambda + def build_dynamic_params(self, t: chex.Numeric) -> DynamicRuntimeParams: + return DynamicRuntimeParams( + **config_slice_args.get_init_kwargs( + input_config=self, + output_type=DynamicRuntimeParams, + t=t, + ) + ) -def get_default_sources_config() -> Mapping[str, SourceConfig]: - """Returns a mapping of source names to their default runtime configurations. +@chex.dataclass(frozen=True) +class DynamicRuntimeParams: + """Dynamic params for a single TORAX source. - This makes an assumption about the names of Source objects used in the - simulation run, that they match the keys of the dictionary here. If that's not - the case, callers must modify the dictionary returned here. + These params can be changed without triggering a recompile. TORAX sources are + stateless, so these params are their inputs to determine their output + profiles. """ - return { - # Current sources (for psi equation) - 'j_bootstrap': get_model_based_source_config(), - 'jext': get_formula_based_source_config(), - # Electron density sources/sink (for the ne equation). - 'nbi_particle_source': get_formula_based_source_config(), - 'gas_puff_source': get_formula_based_source_config(), - 'pellet_source': get_formula_based_source_config(), - # Ion and electron heat sources (for the temp-ion and temp-el eqs). - 'generic_ion_el_heat_source': get_formula_based_source_config(), - 'fusion_heat_source': get_model_based_source_config(), - 'ohmic_heat_source': get_model_based_source_config(), - # NOTE: For qei_source, the is_explicit field in the config has no effect. - 'qei_source': get_model_based_source_config(), - } + + # This maps to the enum value for the Mode enum. The enum itself is not + # JAX-friendly. + mode: int + is_explicit: bool + formula: formula_config.DynamicFormula diff --git a/torax/sources/source.py b/torax/sources/source.py index c08a4142..94a00b3a 100644 --- a/torax/sources/source.py +++ b/torax/sources/source.py @@ -33,16 +33,20 @@ from torax import geometry from torax import jax_utils from torax import state -from torax.sources import source_config +from torax.sources import runtime_params as runtime_params_lib -def get_cell_profile_shape( - unused_config: config_slice.DynamicConfigSlice, - geo: geometry.Geometry, - unused_state: state.CoreProfiles | None, -): - """Returns the shape of a source profile on the cell grid.""" - return ProfileType.CELL.get_profile_shape(geo) +# Sources implement these functions to be able to provide source profiles. +SourceProfileFunction = Callable[ + [ # Arguments + config_slice.DynamicConfigSlice, # General config params. + runtime_params_lib.DynamicRuntimeParams, # Source-specific params. + geometry.Geometry, + state.CoreProfiles | None, + ], + # Returns a JAX array, tuple of arrays, or mapping of arrays. + chex.ArrayTree, +] # Any callable which takes the dynamic config, geometry, and optional core @@ -50,15 +54,20 @@ def get_cell_profile_shape( # source. See how these types of functions are used in the Source class below. SourceOutputShapeFunction = Callable[ [ # Arguments - config_slice.DynamicConfigSlice, geometry.Geometry, - state.CoreProfiles | None, ], # Returns shape of the source's output. tuple[int, ...], ] +def get_cell_profile_shape( + geo: geometry.Geometry, +): + """Returns the shape of a source profile on the cell grid.""" + return ProfileType.CELL.get_profile_shape(geo) + + @enum.unique class AffectedCoreProfile(enum.IntEnum): """Defines which part of the core profiles the source helps evolve. @@ -79,26 +88,26 @@ class AffectedCoreProfile(enum.IntEnum): TEMP_EL = 4 -@dataclasses.dataclass(frozen=True, kw_only=True) +@dataclasses.dataclass(kw_only=True) class Source: """Base class for a single source/sink term. Sources are used to compute source profiles (see source_profiles.py), which are in turn used to compute coeffs in sim.py. - NOTE: For most use cases, you should extend or use SingleProfileSource defined - below. + NOTE: For most use cases, you should extend or use SingleProfileSource. Attributes: - name: Name of this source. Used as a key to find this source's configuraiton - in the DynamicConfigSlice. Also used as a key for the output in the - SourceProfiles. + runtime_params: Input dataclass containing all the source-specific runtime + parameters. At runtime, the parameters here are interpolated to a specific + time t and then passed to the model_func or formula, depending on the mode + this source is running in. affected_core_profiles: Core profiles affected by this source's profile(s). This attribute defines which equations the source profiles are terms for. By default, the number of affected core profiles should equal the rank of the output shape returned by output_shape_getter. Subclasses may override this requirement. - supported_types: Defines how the source computes its profile. Can be set to + supported_modes: Defines how the source computes its profile. Can be set to zero, model-based, etc. At runtime, the input runtime config (the Config or the DynamicConfigSlice) will specify which supported type the Source is running with. If the runtime config specifies an unsupported type, an @@ -113,86 +122,82 @@ class Source: affected_core_profiles. Integer values of those enums. """ - name: str + affected_core_profiles: tuple[AffectedCoreProfile, ...] - # Defining a default here for the affected_core_profiles helps allow us to - # freeze the default in subclasses of Source. Without adding a default here, - # it isn't possible to add a default value in a child class AND hide it from - # the arguments of the subclasses's __init__ function. - # Similar logic holds for all the other attributes below. - affected_core_profiles: tuple[AffectedCoreProfile, ...] = ( - AffectedCoreProfile.NONE, + # Implementation detail: the DynamicConfigSliceProvider reads and interpolates + # these params via the SourceModels obj. This note is to help any code tracing + # someone might do if investigating how the parameters here are actually + # interpolated and packaged into the DynamicConfigSlice. + runtime_params: runtime_params_lib.RuntimeParams = dataclasses.field( + default_factory=runtime_params_lib.RuntimeParams ) - supported_types: tuple[source_config.SourceType, ...] = ( - source_config.SourceType.ZERO, - source_config.SourceType.FORMULA_BASED, + supported_modes: tuple[runtime_params_lib.Mode, ...] = ( + runtime_params_lib.Mode.ZERO, + runtime_params_lib.Mode.FORMULA_BASED, ) output_shape_getter: SourceOutputShapeFunction = get_cell_profile_shape - model_func: source_config.SourceProfileFunction | None = None + model_func: SourceProfileFunction | None = None - formula: source_config.SourceProfileFunction | None = None + formula: SourceProfileFunction | None = None @property def affected_core_profiles_ints(self) -> tuple[int, ...]: return tuple([int(cp) for cp in self.affected_core_profiles]) - def check_source_type( + def check_mode( self, - source_type: int | jnp.ndarray, + mode: int | jnp.ndarray, ) -> jnp.ndarray: """Raises an error if the source type is not supported.""" # This function is really just a wrapper around jax_utils.error_if with the # custom error message coming from this class. - source_type = jnp.array(source_type) - source_type = jax_utils.error_if( - source_type, - jnp.logical_not(self._is_type_supported(source_type)), - self._unsupported_type_error_msg(source_type), + mode = jnp.array(mode) + mode = jax_utils.error_if( + mode, + jnp.logical_not(self._is_type_supported(mode)), + self._unsupported_mode_error_msg(mode), ) - return source_type # pytype: disable=bad-return-type + return mode # pytype: disable=bad-return-type def _is_type_supported( self, - source_type: int | jnp.ndarray, + mode: int | jnp.ndarray, ) -> jnp.ndarray: """Returns whether the source type is supported.""" - source_type = jnp.array(source_type) + mode = jnp.array(mode) return jnp.any( jnp.bool_([ - supported_type.value == source_type - for supported_type in self.supported_types + supported_mode.value == mode + for supported_mode in self.supported_modes ]) ) - def _unsupported_type_error_msg( + def _unsupported_mode_error_msg( self, - source_type: source_config.SourceType | int | jnp.ndarray, + mode: runtime_params_lib.Mode | int | jnp.ndarray, ) -> str: return ( - f'{self.name} supports the following types: {self.supported_types}.' - f' Unsupported type provided: {source_type}.' + f'This source supports the following modes: {self.supported_modes}.' + f' Unsupported mode provided: {mode}.' ) def get_value( self, - source_type: int, # value of the source_config.SourceType enum. dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, core_profiles: state.CoreProfiles | None = None, ) -> chex.ArrayTree: """Returns the profile for this source during one time step. Args: - source_type: Method to use calculate the source profile (formula, model, - etc.). This integer should be the enum value of desired SourceType - instead of the actual enum because enums are not JAX-friendly. If the - input source type is not one of the object's supported types, this will - raise an error. dynamic_config_slice: Slice of the general TORAX config that can be used as input for this time step. + dynamic_source_runtime_params: Slice of this source's runtime parameters + at a specific time t. geo: Geometry of the torus. core_profiles: Core plasma profiles. May be the profiles at the start of the time step or a "live" set of core profiles being actively updated @@ -204,25 +209,23 @@ def get_value( Returns: Array, arrays, or nested dataclass/dict of arrays for the source profile. """ - source_type = self.check_source_type(source_type) - output_shape = self.output_shape_getter( - dynamic_config_slice, geo, core_profiles - ) + self.check_mode(dynamic_source_runtime_params.mode) + output_shape = self.output_shape_getter(geo) model_func = ( - (lambda _0, _1, _2: jnp.zeros(output_shape)) + (lambda _0, _1, _2, _3: jnp.zeros(output_shape)) if self.model_func is None else self.model_func ) formula = ( - (lambda _0, _1, _2: jnp.zeros(output_shape)) + (lambda _0, _1, _2, _3: jnp.zeros(output_shape)) if self.formula is None else self.formula ) return get_source_profiles( dynamic_config_slice=dynamic_config_slice, + dynamic_source_runtime_params=dynamic_source_runtime_params, geo=geo, core_profiles=core_profiles, - source_type=source_type, model_func=model_func, formula=formula, output_shape=output_shape, @@ -277,7 +280,7 @@ def get_source_profile_for_affected_core_profile( ) -@dataclasses.dataclass(frozen=True, kw_only=True) +@dataclasses.dataclass(kw_only=True) class SingleProfileSource(Source): """Source providing a single output profile on the cell grid. @@ -288,66 +291,85 @@ class SingleProfileSource(Source): ```python # Define an electron-density source with a Gaussian profile. - my_custom_source_name = 'custom_ne_source' my_custom_source = source.SingleProfileSource( - name=my_custom_source_name, - supported_types=( - source_config.SourceType.ZERO, - source_config.SourceType.FORMULA_BASED, + supported_modes=( + runtime_params_lib.Mode.ZERO, + runtime_params_lib.Mode.FORMULA_BASED, ), affected_core_profiles=[source.AffectedCoreProfile.NE], formula=formulas.Gaussian(my_custom_source_name), ) - all_torax_sources = source_models_lib.SourceModels( - additional_sources=[ - my_custom_source, - ] - ) - ``` - - You must also include a runtime config for the custom source: - - ```python - my_torax_config = config.Config( - sources=dict( - ... # Configs for other sources. - # Set some params for the new source - custom_ne_source=source_config.SourceConfig( - source_type=source_config.SourceType.FORMULA_BASED, - formula=formula_config.FormulaConfig( - gaussian=formula_config.Gaussian( - total=1.0, - c1=2.0, - c2=3.0, - ), - ), - ), + # Define its runtime parameters (this could be done in the constructor as + # well). + my_custom_source.runtime_params = runtime_params_lib.RuntimeParams( + mode=runtime_params_lib.Mode.FORMULA_BASED, + formula=formula_config.Gaussian( + total=1.0, + c1=2.0, + c2=3.0, ), ) + all_torax_sources = source_models_lib.SourceModels( + sources={ + 'my_custom_source': my_custom_source, + } + ) ``` If you want to create a subclass of SingleProfileSource with frozen parameters, you can provide default implementations/attributes. This is an example of a model-based source with a frozen custom model that cannot be - changed by a config: + changed by a config, along with custom runtime parameters specific to this + source: ```python + @dataclasses.dataclass(kw_only=True) + class FooRuntimeParams(runtime_params_lib.RuntimeParams): + foo_param: runtime_params_lib.TimeDependentField + bar_param: float + + def build_dynamic_params(self, t: chex.Numeric) -> DynamicFooRuntimeParams: + return DynamicFooRuntimeParams( + **config_slice_args.get_init_kwargs( + input_config=self, + output_type=DynamicFooRuntimeParams, + t=t, + ) + ) + + @chex.dataclass(frozen=True) + class DynamicFooRuntimeParams(runtime_params_lib.DynamicRuntimeParams): + foo_param: float + bar_param: float - def _my_foo_model(dynamic_config_slice, geo, core_profiles) -> jnp.ndarray: + def _my_foo_model( + dynamic_config_slice, + dynamic_source_runtime_params, + geo, + core_profiles, + ) -> jnp.ndarray: + assert isinstance(dynamic_source_runtime_params, DynamicFooRuntimeParams) # implement your foo model. + @dataclasses.dataclass(kw_only=True) class FooSource(SingleProfileSource): - name: str = 'foo_source' # the default name for this source. + # Provide a default set of params. + runtime_params: FooRuntimeParams = dataclasses.field( + default_factory=lambda: FooRuntimeParams( + foo_param={0.0: 10.0, 1.0: 20.0, 2.0: 35.0}, + bar_param: 1.234, + ) + ) # By default, FooSource's can be model-based or set to 0. - supported_types: tuple[source_config.SourceType, ...] = ( - source_config.SourceType.ZERO, - source_config.SourceType.MODEL_BASED, + supported_modes: tuple[runtime_params_lib.Mode, ...] = ( + runtime_params_lib.Mode.ZERO, + runtime_params_lib.Mode.MODEL_BASED, ) # Don't include model_func in the __init__ arguments and freeze it. - model_func: source_config.SourceProfileFunction = dataclasses.field( + model_func: SourceProfileFunction = dataclasses.field( init=False, default_factory=lambda: _my_foo_model, ) @@ -363,20 +385,18 @@ class FooSource(SingleProfileSource): def get_value( self, - source_type: int, dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, core_profiles: state.CoreProfiles | None = None, ) -> jnp.ndarray: """Returns the profile for this source during one time step.""" - output_shape = self.output_shape_getter( - dynamic_config_slice, geo, core_profiles - ) + output_shape = self.output_shape_getter(geo) profile = super().get_value( dynamic_config_slice=dynamic_config_slice, + dynamic_source_runtime_params=dynamic_source_runtime_params, geo=geo, core_profiles=core_profiles, - source_type=source_type, ) assert isinstance(profile, jnp.ndarray) chex.assert_rank(profile, 1) @@ -419,23 +439,24 @@ def get_zero_profile(self, geo: geometry.Geometry) -> jnp.ndarray: def get_source_profiles( - source_type: int | jnp.ndarray, dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, core_profiles: state.CoreProfiles | None, - model_func: source_config.SourceProfileFunction, - formula: source_config.SourceProfileFunction, + model_func: SourceProfileFunction, + formula: SourceProfileFunction, output_shape: tuple[int, ...], ) -> jnp.ndarray: - """Returns source profiles requested by the source_config. + """Returns source profiles requested by the runtime_params_lib. This function handles MODEL_BASED, FORMULA_BASED, and ZERO sources. All other source types will be ignored. Args: - source_type: Method to use to get the source profile. dynamic_config_slice: Slice of the general TORAX config that can be used as input for this time step. + dynamic_source_runtime_params: Slice of this source's runtime parameters at + a specific time t. geo: Geometry information. Used as input to the source profile functions. core_profiles: Core plasma profiles. Used as input to the source profile functions. @@ -446,16 +467,27 @@ def get_source_profiles( Returns: Output array of a profile or concatenated/stacked profiles. """ + mode = dynamic_source_runtime_params.mode zeros = jnp.zeros(output_shape) output = jnp.zeros(output_shape) output += jnp.where( - source_type == source_config.SourceType.MODEL_BASED.value, - model_func(dynamic_config_slice, geo, core_profiles), + mode == runtime_params_lib.Mode.MODEL_BASED.value, + model_func( + dynamic_config_slice, + dynamic_source_runtime_params, + geo, + core_profiles, + ), zeros, ) output += jnp.where( - source_type == source_config.SourceType.FORMULA_BASED.value, - formula(dynamic_config_slice, geo, core_profiles), + mode == runtime_params_lib.Mode.FORMULA_BASED.value, + formula( + dynamic_config_slice, + dynamic_source_runtime_params, + geo, + core_profiles, + ), zeros, ) return output @@ -465,55 +497,43 @@ def get_source_profiles( # sources defined in the other files in this folder. -@dataclasses.dataclass(frozen=True, kw_only=True) +@dataclasses.dataclass(kw_only=True) class SingleProfilePsiSource(SingleProfileSource): - # Don't include affected_core_profiles in the __init__ arguments. - # Freeze this param. - affected_core_profiles: tuple[AffectedCoreProfile, ...] = dataclasses.field( - init=False, - default=(AffectedCoreProfile.PSI,), + affected_core_profiles: tuple[AffectedCoreProfile, ...] = ( + AffectedCoreProfile.PSI, ) -@dataclasses.dataclass(frozen=True, kw_only=True) +@dataclasses.dataclass(kw_only=True) class SingleProfileNeSource(SingleProfileSource): - # Don't include affected_core_profiles in the __init__ arguments. - # Freeze this param. - affected_core_profiles: tuple[AffectedCoreProfile, ...] = dataclasses.field( - init=False, - default=(AffectedCoreProfile.NE,), + affected_core_profiles: tuple[AffectedCoreProfile, ...] = ( + AffectedCoreProfile.NE, ) -@dataclasses.dataclass(frozen=True, kw_only=True) +@dataclasses.dataclass(kw_only=True) class SingleProfileTempIonSource(SingleProfileSource): - # Don't include affected_core_profiles in the __init__ arguments. - # Freeze this param. - affected_core_profiles: tuple[AffectedCoreProfile, ...] = dataclasses.field( - init=False, - default=(AffectedCoreProfile.TEMP_ION,), + affected_core_profiles: tuple[AffectedCoreProfile, ...] = ( + AffectedCoreProfile.TEMP_ION, ) -@dataclasses.dataclass(frozen=True, kw_only=True) +@dataclasses.dataclass(kw_only=True) class SingleProfileTempElSource(SingleProfileSource): - # Don't include affected_core_profiles in the __init__ arguments. - # Freeze this param. - affected_core_profiles: tuple[AffectedCoreProfile, ...] = dataclasses.field( - init=False, - default=(AffectedCoreProfile.TEMP_EL,), + affected_core_profiles: tuple[AffectedCoreProfile, ...] = ( + AffectedCoreProfile.TEMP_EL, ) -def _get_ion_el_output_shape(unused_config, geo, unused_state): +def _get_ion_el_output_shape(geo): return (2,) + ProfileType.CELL.get_profile_shape(geo) -@dataclasses.dataclass(frozen=True, kw_only=True) +@dataclasses.dataclass(kw_only=True) class IonElectronSource(Source): """Base class for a source/sink that can be used for both ions / electrons. @@ -528,19 +548,16 @@ class IonElectronSource(Source): first being ion profile and the second being the electron profile. """ - supported_types: tuple[source_config.SourceType, ...] = ( - source_config.SourceType.FORMULA_BASED, - source_config.SourceType.ZERO, + supported_modes: tuple[runtime_params_lib.Mode, ...] = ( + runtime_params_lib.Mode.FORMULA_BASED, + runtime_params_lib.Mode.ZERO, ) # Don't include affected_core_profiles in the __init__ arguments. # Freeze this param. - affected_core_profiles: tuple[AffectedCoreProfile, ...] = dataclasses.field( - init=False, - default=( - AffectedCoreProfile.TEMP_ION, - AffectedCoreProfile.TEMP_EL, - ), + affected_core_profiles: tuple[AffectedCoreProfile, ...] = ( + AffectedCoreProfile.TEMP_ION, + AffectedCoreProfile.TEMP_EL, ) # Don't include output_shape_getter in the __init__ arguments. @@ -552,19 +569,18 @@ class IonElectronSource(Source): def get_value( self, - source_type: int, dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, core_profiles: state.CoreProfiles | None = None, ) -> jnp.ndarray: """Computes the ion and electron values of the source. Args: - source_type: Method to use calculate the source profile (formula, model, - etc.). This is the enum value of SourceType instead of the actual enum - instance because enums aren't JAX-friendly. dynamic_config_slice: Input config which can change from time step to time step. + dynamic_source_runtime_params: Slice of this source's runtime parameters + at a specific time t. geo: Geometry of the torus. core_profiles: Core plasma profiles used to compute the source's profiles. @@ -572,12 +588,10 @@ def get_value( 2 stacked arrays, the first for the ion profile and the second for the electron profile. """ - output_shape = self.output_shape_getter( - dynamic_config_slice, geo, core_profiles - ) + output_shape = self.output_shape_getter(geo) profile = super().get_value( - source_type=source_type, dynamic_config_slice=dynamic_config_slice, + dynamic_source_runtime_params=dynamic_source_runtime_params, geo=geo, core_profiles=core_profiles, ) diff --git a/torax/sources/source_models.py b/torax/sources/source_models.py index 2bf5a740..15af48ca 100644 --- a/torax/sources/source_models.py +++ b/torax/sources/source_models.py @@ -28,13 +28,10 @@ from torax import state from torax.fvm import diffusion_terms from torax.sources import bootstrap_current_source -from torax.sources import electron_density_sources from torax.sources import external_current_source -from torax.sources import fusion_heat_source as fusion_heat_source_lib -from torax.sources import generic_ion_el_heat_source as generic_ion_el_heat_source_lib from torax.sources import qei_source as qei_source_lib +from torax.sources import runtime_params as runtime_params_lib from torax.sources import source as source_lib -from torax.sources import source_config from torax.sources import source_profiles @@ -70,14 +67,16 @@ def build_source_profiles( """ # Bootstrap current is a special-case source with multiple outputs, so handle # it here. - # TODO(b/314308399): Add a new neoclassical directory with - # different ways to compute sigma and bootstrap current. + dynamic_bootstrap_runtime_params = dynamic_config_slice.sources[ + source_models.j_bootstrap_name + ] bootstrap_profiles = _build_bootstrap_profiles( - dynamic_config_slice, - geo, - core_profiles, - source_models.j_bootstrap, - explicit, + dynamic_config_slice=dynamic_config_slice, + dynamic_source_runtime_params=dynamic_bootstrap_runtime_params, + geo=geo, + core_profiles=core_profiles, + j_bootstrap_source=source_models.j_bootstrap, + explicit=explicit, ) other_profiles = {} other_profiles.update( @@ -110,6 +109,7 @@ def build_source_profiles( def _build_bootstrap_profiles( dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, core_profiles: state.CoreProfiles, j_bootstrap_source: bootstrap_current_source.BootstrapCurrentSource, @@ -121,6 +121,8 @@ def _build_bootstrap_profiles( Args: dynamic_config_slice: Input config for this time step. Can change from time step to time step. + dynamic_source_runtime_params: Input runtime parameters for this time step, + specific to the bootstrap current source. geo: Geometry of the torus. core_profiles: Core plasma profiles, either at the start of the time step (if explicit) or the live profiles being evolved during the time step (if @@ -138,13 +140,13 @@ def _build_bootstrap_profiles( """ bootstrap_profile = j_bootstrap_source.get_value( dynamic_config_slice=dynamic_config_slice, + dynamic_source_runtime_params=dynamic_source_runtime_params, geo=geo, core_profiles=core_profiles, ) sigma = jax_utils.select( jnp.logical_or( - explicit - == dynamic_config_slice.sources[j_bootstrap_source.name].is_explicit, + explicit == dynamic_source_runtime_params.is_explicit, calculate_anyway, ), bootstrap_profile.sigma, @@ -152,8 +154,7 @@ def _build_bootstrap_profiles( ) j_bootstrap = jax_utils.select( jnp.logical_or( - explicit - == dynamic_config_slice.sources[j_bootstrap_source.name].is_explicit, + explicit == dynamic_source_runtime_params.is_explicit, calculate_anyway, ), bootstrap_profile.j_bootstrap, @@ -161,8 +162,7 @@ def _build_bootstrap_profiles( ) j_bootstrap_face = jax_utils.select( jnp.logical_or( - explicit - == dynamic_config_slice.sources[j_bootstrap_source.name].is_explicit, + explicit == dynamic_source_runtime_params.is_explicit, calculate_anyway, ), bootstrap_profile.j_bootstrap_face, @@ -170,8 +170,7 @@ def _build_bootstrap_profiles( ) I_bootstrap = jax_utils.select( # pylint: disable=invalid-name jnp.logical_or( - explicit - == dynamic_config_slice.sources[j_bootstrap_source.name].is_explicit, + explicit == dynamic_source_runtime_params.is_explicit, calculate_anyway, ), bootstrap_profile.I_bootstrap, @@ -214,11 +213,13 @@ def _build_psi_profiles( dict of psi source profiles. """ psi_profiles = {} - # jext is precomputed in the initial core profiles. - psi_profiles[source_models.jext.name] = jax_utils.select( + # jext is precomputed in the core profiles. + dynamic_jext_runtime_params = dynamic_config_slice.sources[ + source_models.jext_name + ] + psi_profiles[source_models.jext_name] = jax_utils.select( jnp.logical_or( - explicit - == dynamic_config_slice.sources[source_models.jext.name].is_explicit, + explicit == dynamic_jext_runtime_params.is_explicit, calculate_anyway, ), core_profiles.currents.jext, @@ -227,14 +228,15 @@ def _build_psi_profiles( # Iterate through the rest of the sources and compute profiles for the ones # which relate to psi. jext is not part of the "standard sources." for source_name, source in source_models.psi_sources.items(): - dynamic_source_config = dynamic_config_slice.sources[source_name] + dynamic_source_runtime_params = dynamic_config_slice.sources[source_name] psi_profiles[source_name] = jax_utils.select( jnp.logical_or( - explicit == dynamic_source_config.is_explicit, calculate_anyway + explicit == dynamic_source_runtime_params.is_explicit, + calculate_anyway, ), source.get_value( - dynamic_source_config.source_type, dynamic_config_slice, + dynamic_source_runtime_params, geo, core_profiles, ), @@ -273,12 +275,12 @@ def _build_ne_profiles( # Iterate through the sources and compute profiles for the ones which relate # to ne. for source_name, source in source_models.ne_sources.items(): - dynamic_source_config = dynamic_config_slice.sources[source_name] + dynamic_source_runtime_params = dynamic_config_slice.sources[source_name] ne_profiles[source_name] = jax_utils.select( - explicit == dynamic_source_config.is_explicit, + explicit == dynamic_source_runtime_params.is_explicit, source.get_value( - dynamic_source_config.source_type, dynamic_config_slice, + dynamic_source_runtime_params, geo, core_profiles, ), @@ -319,15 +321,13 @@ def _build_temp_ion_el_profiles( source_models.temp_ion_sources | source_models.temp_el_sources ) for source_name, source in temp_ion_el_sources.items(): - zeros = jnp.zeros( - source.output_shape_getter(dynamic_config_slice, geo, core_profiles) - ) - dynamic_source_config = dynamic_config_slice.sources[source_name] + zeros = jnp.zeros(source.output_shape_getter(geo)) + dynamic_source_runtime_params = dynamic_config_slice.sources[source_name] ion_el_profiles[source_name] = jax_utils.select( - explicit == dynamic_source_config.is_explicit, + explicit == dynamic_source_runtime_params.is_explicit, source.get_value( - dynamic_source_config.source_type, dynamic_config_slice, + dynamic_source_runtime_params, geo, core_profiles, ), @@ -344,7 +344,7 @@ def sum_sources_psi( """Computes psi source values for sim.calc_coeffs.""" total = ( source_profile.j_bootstrap.j_bootstrap - + source_profile.profiles[source_models.jext.name] + + source_profile.profiles[source_models.jext_name] ) for source_name, source in source_models.psi_sources.items(): total += source.get_source_profile_for_affected_core_profile( @@ -426,11 +426,15 @@ def calc_and_sum_sources_psi( total = 0 for key in psi_profiles: total += psi_profiles[key] + dynamic_bootstrap_runtime_params = dynamic_config_slice.sources[ + source_models.j_bootstrap_name + ] j_bootstrap_profiles = _build_bootstrap_profiles( - dynamic_config_slice, - geo, - core_profiles, - source_models.j_bootstrap, + dynamic_config_slice=dynamic_config_slice, + dynamic_source_runtime_params=dynamic_bootstrap_runtime_params, + geo=geo, + core_profiles=core_profiles, + j_bootstrap_source=source_models.j_bootstrap, calculate_anyway=True, ) total += j_bootstrap_profiles.j_bootstrap @@ -522,21 +526,36 @@ def _ohmic_heat_model( return pohm -@dataclasses.dataclass(frozen=True, kw_only=True) +@dataclasses.dataclass(kw_only=True) class OhmicHeatSource(source_lib.SingleProfileSource): """Ohmic heat source for electron heat equation. - Pohm = jtor * psidot /(2*pi*Rmaj), related to electric power formula P = IV + Pohm = jtor * psidot /(2*pi*Rmaj), related to electric power formula P = IV. + + Because this source requires access to the rest of the Sources, it must be + added to the SourceModels object after creation: + + ```python + source_models = SourceModels(sources={...}) + # Now add the ohmic heat source and turn it on. + source_models.add_source( + source_name='ohmic_heat_source', + source=OhmicHeatSource( + source_models=source_models, + runtime_params=runtime_params.RuntimeParams( + mode=runtime_params.Mode.MODEL_BASED, # turns the source on. + ), + ), + ) + ``` """ # Users must pass in a pointer to the complete set of sources to this object. source_models: SourceModels - name: str = 'ohmic_heat_source' - - supported_types: tuple[source_config.SourceType, ...] = ( - source_config.SourceType.ZERO, - source_config.SourceType.MODEL_BASED, + supported_modes: tuple[runtime_params_lib.Mode, ...] = ( + runtime_params_lib.Mode.ZERO, + runtime_params_lib.Mode.MODEL_BASED, ) # Freeze these params and do not include them in the __init__. @@ -546,7 +565,7 @@ class OhmicHeatSource(source_lib.SingleProfileSource): default=(source_lib.AffectedCoreProfile.TEMP_EL,), ) ) - model_func: source_config.SourceProfileFunction | None = dataclasses.field( + model_func: source_lib.SourceProfileFunction | None = dataclasses.field( init=False, default_factory=lambda: None, # ignored. ) @@ -555,9 +574,11 @@ def __post_init__(self): # Ignore the model provided above and set it to the function here. def _model_func( dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, core_profiles: state.CoreProfiles, ) -> jnp.ndarray: + del dynamic_source_runtime_params return _ohmic_heat_model( dynamic_config_slice=dynamic_config_slice, geo=geo, @@ -565,13 +586,7 @@ def _model_func( source_models=self.source_models, ) - # Must use object.__setattr__ instead of simply doing - # self.model_func = _model_func - # because this class is a frozen dataclass. Frozen classes cannot set any - # self attributes after init, but this is a workaround. We cannot set the - # model_func in the dataclass field above either because we need access to - # self in the implementation. - object.__setattr__(self, 'model_func', _model_func) + self.model_func = _model_func class SourceModels: @@ -591,228 +606,180 @@ class SourceModels: shows how to define a new custom electron-density source. ```python - # Define an electron-density source with a Gaussian profile. - my_custom_source_name = 'custom_ne_source' + # Define an electron-density source with a time-dependent Gaussian profile. my_custom_source = source.SingleProfileSource( - name=my_custom_source_name, - supported_types=( - source_config.SourceType.ZERO, - source_config.SourceType.FORMULA_BASED, + supported_modes=( + runtime_params_lib.Mode.ZERO, + runtime_params_lib.Mode.FORMULA_BASED, ), affected_core_profiles=source.AffectedCoreProfile.NE, - formula=formulas.Gaussian(my_custom_source_name), - ) - all_torax_sources = source_models_lib.SourceModels( - additional_sources=[ - my_custom_source, - ] - ) - ``` - - You must also include a runtime config for the new custom source: - - ```python - my_torax_config = config.Config( - sources=dict( - ... # Configs for other sources. - # Set some params for the new source - custom_ne_source=source_config.SourceConfig( - source_type=source_config.SourceType.FORMULA_BASED, - formula=formula_config.FormulaConfig( - gaussian=formula_config.Gaussian( - total=1.0, - c1=2.0, - c2=3.0, - ), - ), + formula=formulas.Gaussian(), + # Define (possibly) time-dependent parameters to feed to the formula. + runtime_params=runtime_params_lib.RuntimeParams( + formula=formula_config.Gaussian( + total={0.0: 1.0, 5.0: 2.0, 10.0: 1.0}, # time-dependent. + c1=2.0, + c2=3.0, ), ), ) + # Define the collection of sources here, which in this example only includes + # one source. + all_torax_sources = source_models_lib.SourceModels( + sources={'my_custom_source': my_custom_source} + ) ``` - See source_config.py for more details on how to configure all the source/sink + See runtime_params.py for more details on how to configure all the source/sink terms. """ def __init__( self, - *, - # All arguments must be provided as keyword arguments to ensure that - # everything is set explicitly. Helps avoid unwarranted mistakes. - # The sources below are on by default, which is why they are exposed - # directly in the constructor. - # The sources listed below are the default sources that are turned on as - # well by default. - # Current sources (for psi equation) - j_bootstrap: ( - bootstrap_current_source.BootstrapCurrentSource | None - ) = None, - jext: external_current_source.ExternalCurrentSource | None = None, - # Electron density sources/sink (for the ne equation). - gas_puff_source: electron_density_sources.GasPuffSource | None = None, - nbi_particle_source: ( - electron_density_sources.NBIParticleSource | None - ) = None, - pellet_source: electron_density_sources.PelletSource | None = None, - # Ion and electron heat sources (for the temp-ion and temp-el eqs). - generic_ion_el_heat_source: ( - generic_ion_el_heat_source_lib.GenericIonElectronHeatSource | None - ) = None, - fusion_heat_source: fusion_heat_source_lib.FusionHeatSource | None = None, - ohmic_heat_source: OhmicHeatSource | None = None, - qei_source: qei_source_lib.QeiSource | None = None, - # Any additional sources that the user wants to provide. - additional_sources: list[source_lib.Source] | None = None, + sources: dict[str, source_lib.Source] | None = None, ): """Constructs a collection of sources. This class defines which sources are available in a TORAX simulation run. Users can configure whether each source is actually on and what kind of profile it produces by changing its runtime configuration (see - source_config.py). - - Some TORAX sources are required and on by default. These sources are in the - argument list of this `__init__()` function. While these sources are on by - default, they can be turned off by setting the source to ZERO. - - For example, to turn off the gas-puff source: - - ```python - sources = source_models_lib.SourceModels() - my_torax_config = config.Config( - sources=dict( - gas_puff_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), - ) - ``` + runtime_params_lib.py). Args: - j_bootstrap: Bootstrap current density source for the psi equation. Is a - "neoclassical" source. - jext: External current density source for the psi equation. - gas_puff_source: Gas puff particle source for the electron density ne - equation. - nbi_particle_source: Neutral beam injection particle source for the - electron density ne equation. - pellet_source: Pellet source for the electron density ne equation. - generic_ion_el_heat_source: Generic heat source coupled for both the ion - and electron heat equations. - fusion_heat_source: Alpha heat source for coupled for both the ion and - electron heat equations. - ohmic_heat_source: Ohmic heating for electron temperatures. - qei_source: Collisional ion-electron heat source. Special-case source used - in both the explicit and implicit terms in the TORAX solver. - additional_sources: Optional list of additional sources to include in - TORAX. Remember that all additional sources need their corresponding - runtime config to be included in config.Config(). All these additional - sources are "standard" sources (they are not going to be treated as - special cases like j_bootstrap, jext, and qei_source are). They will be - accessible via the standard_sources property. + sources: Mapping of source model names to the Source objects. The names + (i.e. the keys of this dictionary) also define the keys in the output + SourceProfiles which are computed from this SourceModels object. NOTE - + Some sources are "special-case": bootstrap current, external current, + and Qei. SourceModels will always instantiate default objects for these + types of sources unless they are provided by this `sources` argument. + Also, their default names are reserved, meaning the input dictionary + `sources` should not have the keys 'j_bootstrap', 'jext', or + 'qei_source' unless those sources are one of these "special-case" + sources. + + Raises: + ValueError if there is a naming collision with the reserved names as + described above. """ - self._j_bootstrap = ( - bootstrap_current_source.BootstrapCurrentSource() - if j_bootstrap is None - else j_bootstrap - ) - self._qei_source = ( - qei_source_lib.QeiSource() if qei_source is None else qei_source - ) - - self._jext = ( - external_current_source.ExternalCurrentSource() - if jext is None - else jext - ) - gas_puff_source = ( - electron_density_sources.GasPuffSource() - if gas_puff_source is None - else gas_puff_source - ) - nbi_particle_source = ( - electron_density_sources.NBIParticleSource() - if nbi_particle_source is None - else nbi_particle_source - ) - pellet_source = ( - electron_density_sources.PelletSource() - if pellet_source is None - else pellet_source - ) - generic_ion_el_heat_source = ( - generic_ion_el_heat_source_lib.GenericIonElectronHeatSource() - if generic_ion_el_heat_source is None - else generic_ion_el_heat_source - ) - fusion_heat_source = ( - fusion_heat_source_lib.FusionHeatSource() - if fusion_heat_source is None - else fusion_heat_source - ) - ohmic_heat_source = ( - OhmicHeatSource(source_models=self) - if ohmic_heat_source is None - else ohmic_heat_source - ) - additional_sources = ( - [] if additional_sources is None else additional_sources - ) - - # All sources which are "standard" and can be accessed as - # source_lib.Source objects when computing profiles. - self._standard_sources: dict[str, source_lib.Source] = dict( - gas_puff_source=gas_puff_source, - nbi_particle_source=nbi_particle_source, - pellet_source=pellet_source, - generic_ion_el_heat_source=generic_ion_el_heat_source, - fusion_heat_source=fusion_heat_source, - ohmic_heat_source=ohmic_heat_source, - ) - for additional_source in additional_sources: - self._standard_sources[additional_source.name] = additional_source + sources = sources or {} + # Some sources are accessed for specific use cases, so we extract those + # ones and expose them directly. + self._j_bootstrap = None + self._j_bootstrap_name = 'j_bootstrap' # default, can be overridden below. + self._jext = None + self._jext_name = 'jext' # default, can be overridden below. + self._qei_source = None + self._qei_source_name = 'qei_source' # default, can be overridden below. + # The rest of the sources are "standard". + self._standard_sources = {} + + # Divide up the sources based on which core profiles they affect. self._psi_sources: dict[str, source_lib.Source] = {} self._ne_sources: dict[str, source_lib.Source] = {} self._temp_ion_sources: dict[str, source_lib.Source] = {} self._temp_el_sources: dict[str, source_lib.Source] = {} - for source_name, source in self._standard_sources.items(): - if source_lib.AffectedCoreProfile.PSI in source.affected_core_profiles: - self._psi_sources[source_name] = source - if source_lib.AffectedCoreProfile.NE in source.affected_core_profiles: - self._ne_sources[source_name] = source - if ( - source_lib.AffectedCoreProfile.TEMP_ION - in source.affected_core_profiles - ): - self._temp_ion_sources[source_name] = source - if ( - source_lib.AffectedCoreProfile.TEMP_EL - in source.affected_core_profiles - ): - self._temp_el_sources[source_name] = source - - self._all_sources = self._standard_sources | { - self._j_bootstrap.name: self._j_bootstrap, - self._jext.name: self._jext, - self._qei_source.name: self._qei_source, - } + for source_name, source in sources.items(): + if isinstance(source, bootstrap_current_source.BootstrapCurrentSource): + self._j_bootstrap_name = source_name + self._j_bootstrap = source + elif isinstance(source, external_current_source.ExternalCurrentSource): + self._jext_name = source_name + self._jext = source + elif isinstance(source, qei_source_lib.QeiSource): + self._qei_source_name = source_name + self._qei_source = source + else: + self.add_source(source_name, source) + + # Make sure defaults are set. + if self._j_bootstrap is None: + self._j_bootstrap = bootstrap_current_source.BootstrapCurrentSource() + if self._jext is None: + self._jext = external_current_source.ExternalCurrentSource() + if self._qei_source is None: + self._qei_source = qei_source_lib.QeiSource() + + def add_source( + self, + source_name: str, + source: source_lib.Source, + ) -> None: + """Adds a source to the collection of sources. + + Do NOT directly add new sources to `SourceModels.standard_sources`. Users + should call this function instead. Cannot add additional bootstrap current, + external current, or Qei sources - those must be defined in the __init__. + + Args: + source_name: Name of the new source being added. This will be the key + under which the source's output profile will be found in the output + SourceProfiles object. + source: The new standard source being added. + + Raises: + ValueError if a "special-case" source is provided. + """ + if ( + isinstance(source, bootstrap_current_source.BootstrapCurrentSource) + or isinstance(source, external_current_source.ExternalCurrentSource) + or isinstance(source, qei_source_lib.QeiSource) + ): + raise ValueError( + 'Cannot add a source with the following types: ' + 'bootstrap_current_source.BootstrapCurrentSource,' + ' external_current_source.ExternalCurrentSource, or' + ' qei_source_lib.QeiSource.' + ) + reserved_names = [ + self._j_bootstrap_name, + self._jext_name, + self._qei_source_name, + ] + if source_name in reserved_names: + raise ValueError( + f'Cannot add a source with one of these names: {reserved_names}.' + ) + self._standard_sources[source_name] = source + if source_lib.AffectedCoreProfile.PSI in source.affected_core_profiles: + self._psi_sources[source_name] = source + if source_lib.AffectedCoreProfile.NE in source.affected_core_profiles: + self._ne_sources[source_name] = source + if source_lib.AffectedCoreProfile.TEMP_ION in source.affected_core_profiles: + self._temp_ion_sources[source_name] = source + if source_lib.AffectedCoreProfile.TEMP_EL in source.affected_core_profiles: + self._temp_el_sources[source_name] = source # Some sources require direct access, so this class defines properties for # those sources. @property def j_bootstrap(self) -> bootstrap_current_source.BootstrapCurrentSource: + assert self._j_bootstrap is not None return self._j_bootstrap + @property + def j_bootstrap_name(self) -> str: + return self._j_bootstrap_name + @property def jext(self) -> external_current_source.ExternalCurrentSource: + assert self._jext is not None return self._jext + @property + def jext_name(self) -> str: + return self._jext_name + @property def qei_source(self) -> qei_source_lib.QeiSource: + assert self._qei_source is not None return self._qei_source + @property + def qei_source_name(self) -> str: + return self._qei_source_name + @property def psi_sources(self) -> dict[str, source_lib.Source]: return self._psi_sources @@ -834,7 +801,7 @@ def ion_el_sources(self) -> dict[str, source_lib.Source]: """Returns all source models which output both ion and el temp profiles.""" return { name: source - for name, source in self.standard_sources.items() + for name, source in self._standard_sources.items() if source.affected_core_profiles == ( source_lib.AffectedCoreProfile.TEMP_ION, @@ -852,23 +819,32 @@ def standard_sources(self) -> dict[str, source_lib.Source]: return self._standard_sources @property - def all_sources(self) -> dict[str, source_lib.Source]: - return self._all_sources + def sources(self) -> dict[str, source_lib.Source]: + return self._standard_sources | { + self._j_bootstrap_name: self.j_bootstrap, + self._jext_name: self.jext, + self._qei_source_name: self.qei_source, + } + + @property + def runtime_params(self) -> dict[str, runtime_params_lib.RuntimeParams]: + """Returns all the runtime params for all sources.""" + return { + source_name: source.runtime_params + for source_name, source in self.sources.items() + } def build_all_zero_profiles( - dynamic_config_slice: config_slice.DynamicConfigSlice, geo: geometry.Geometry, source_models: SourceModels, ) -> source_profiles.SourceProfiles: """Returns a SourceProfiles object with all zero profiles.""" profiles = { - source_name: jnp.zeros( - source_model.output_shape_getter(dynamic_config_slice, geo, None) - ) + source_name: jnp.zeros(source_model.output_shape_getter(geo)) for source_name, source_model in source_models.standard_sources.items() } - profiles[source_models.jext.name] = jnp.zeros_like(geo.r) + profiles[source_models.jext_name] = jnp.zeros_like(geo.r) return source_profiles.SourceProfiles( profiles=profiles, j_bootstrap=source_profiles.BootstrapCurrentProfile.zero_profile(geo), diff --git a/torax/sources/tests/bootstrap_current_source.py b/torax/sources/tests/bootstrap_current_source.py index ff0b2ce7..a1fe44c6 100644 --- a/torax/sources/tests/bootstrap_current_source.py +++ b/torax/sources/tests/bootstrap_current_source.py @@ -22,8 +22,8 @@ from torax import core_profile_setters from torax import geometry from torax.sources import bootstrap_current_source +from torax.sources import runtime_params as runtime_params_lib from torax.sources import source as source_lib -from torax.sources import source_config from torax.sources import source_models as source_models_lib from torax.sources import source_profiles from torax.sources.tests import test_lib @@ -36,9 +36,8 @@ class BootstrapCurrentSourceTest(test_lib.SourceTestCase): def setUpClass(cls): super().setUpClass( source_class=bootstrap_current_source.BootstrapCurrentSource, - unsupported_types=[ - source_config.SourceType.FORMULA_BASED, - source_config.SourceType.ZERO, + unsupported_modes=[ + runtime_params_lib.Mode.FORMULA_BASED, ], expected_affected_core_profiles=(source_lib.AffectedCoreProfile.PSI,), ) @@ -47,17 +46,26 @@ def test_source_value(self): source = bootstrap_current_source.BootstrapCurrentSource() config = config_lib.Config() geo = geometry.build_circular_geometry(config) + source_models = source_models_lib.SourceModels( + sources={'j_bootstrap': source} + ) + static_config_slice = config_slice.build_static_config_slice(config) + dynamic_config_slice = config_slice.build_dynamic_config_slice( + config, + sources=source_models.runtime_params, + ) core_profiles = core_profile_setters.initial_core_profiles( - dynamic_config_slice=config_slice.build_dynamic_config_slice(config), - static_config_slice=config_slice.build_static_config_slice(config), + dynamic_config_slice=dynamic_config_slice, + static_config_slice=static_config_slice, geo=geo, - source_models=source_models_lib.SourceModels(j_bootstrap=source), + source_models=source_models, ) self.assertIsNotNone( source.get_value( - dynamic_config_slice=( - config_slice.build_dynamic_config_slice(config) - ), + dynamic_config_slice=dynamic_config_slice, + dynamic_source_runtime_params=dynamic_config_slice.sources[ + source_models.j_bootstrap_name + ], geo=geo, temp_ion=core_profiles.temp_ion, temp_el=core_profiles.temp_el, diff --git a/torax/sources/tests/current_density_sources.py b/torax/sources/tests/current_density_sources.py index 88a2f76a..b4c53605 100644 --- a/torax/sources/tests/current_density_sources.py +++ b/torax/sources/tests/current_density_sources.py @@ -16,8 +16,8 @@ from absl.testing import absltest from torax.sources import current_density_sources as cds +from torax.sources import runtime_params from torax.sources import source -from torax.sources import source_config from torax.sources.tests import test_lib @@ -28,8 +28,8 @@ class ECRHCurrentSourceTest(test_lib.SingleProfileSourceTestCase): def setUpClass(cls): super().setUpClass( source_class=cds.ECRHCurrentSource, - unsupported_types=[ - source_config.SourceType.MODEL_BASED, + unsupported_modes=[ + runtime_params.Mode.MODEL_BASED, ], expected_affected_core_profiles=(source.AffectedCoreProfile.PSI,), ) @@ -42,8 +42,8 @@ class ICRHCurrentSourceTest(test_lib.SingleProfileSourceTestCase): def setUpClass(cls): super().setUpClass( source_class=cds.ICRHCurrentSource, - unsupported_types=[ - source_config.SourceType.MODEL_BASED, + unsupported_modes=[ + runtime_params.Mode.MODEL_BASED, ], expected_affected_core_profiles=(source.AffectedCoreProfile.PSI,), ) @@ -56,8 +56,8 @@ class LHCurrentSourceTest(test_lib.SingleProfileSourceTestCase): def setUpClass(cls): super().setUpClass( source_class=cds.LHCurrentSource, - unsupported_types=[ - source_config.SourceType.MODEL_BASED, + unsupported_modes=[ + runtime_params.Mode.MODEL_BASED, ], expected_affected_core_profiles=(source.AffectedCoreProfile.PSI,), ) @@ -70,8 +70,8 @@ class NBICurrentSourceTest(test_lib.SingleProfileSourceTestCase): def setUpClass(cls): super().setUpClass( source_class=cds.NBICurrentSource, - unsupported_types=[ - source_config.SourceType.MODEL_BASED, + unsupported_modes=[ + runtime_params.Mode.MODEL_BASED, ], expected_affected_core_profiles=(source.AffectedCoreProfile.PSI,), ) diff --git a/torax/sources/tests/electron_density_sources.py b/torax/sources/tests/electron_density_sources.py index 8b06bc2b..fff22eda 100644 --- a/torax/sources/tests/electron_density_sources.py +++ b/torax/sources/tests/electron_density_sources.py @@ -16,8 +16,8 @@ from absl.testing import absltest from torax.sources import electron_density_sources as eds +from torax.sources import runtime_params from torax.sources import source as source_lib -from torax.sources import source_config from torax.sources.tests import test_lib @@ -28,8 +28,8 @@ class GasPuffSourceTest(test_lib.SingleProfileSourceTestCase): def setUpClass(cls): super().setUpClass( source_class=eds.GasPuffSource, - unsupported_types=[ - source_config.SourceType.MODEL_BASED, + unsupported_modes=[ + runtime_params.Mode.MODEL_BASED, ], expected_affected_core_profiles=(source_lib.AffectedCoreProfile.NE,), ) @@ -42,8 +42,8 @@ class PelletSourceTest(test_lib.SingleProfileSourceTestCase): def setUpClass(cls): super().setUpClass( source_class=eds.PelletSource, - unsupported_types=[ - source_config.SourceType.MODEL_BASED, + unsupported_modes=[ + runtime_params.Mode.MODEL_BASED, ], expected_affected_core_profiles=(source_lib.AffectedCoreProfile.NE,), ) @@ -56,8 +56,8 @@ class NBISourceTest(test_lib.SingleProfileSourceTestCase): def setUpClass(cls): super().setUpClass( source_class=eds.NBIParticleSource, - unsupported_types=[ - source_config.SourceType.MODEL_BASED, + unsupported_modes=[ + runtime_params.Mode.MODEL_BASED, ], expected_affected_core_profiles=(source_lib.AffectedCoreProfile.NE,), ) @@ -70,8 +70,8 @@ class RecombinationDensitySinkTest(test_lib.SingleProfileSourceTestCase): def setUpClass(cls): super().setUpClass( source_class=eds.RecombinationDensitySink, - unsupported_types=[ - source_config.SourceType.MODEL_BASED, + unsupported_modes=[ + runtime_params.Mode.MODEL_BASED, ], expected_affected_core_profiles=(source_lib.AffectedCoreProfile.NE,), ) diff --git a/torax/sources/tests/external_current_source.py b/torax/sources/tests/external_current_source.py index 2d5b4788..ce8a02b7 100644 --- a/torax/sources/tests/external_current_source.py +++ b/torax/sources/tests/external_current_source.py @@ -22,8 +22,8 @@ from torax import config_slice from torax import geometry from torax.sources import external_current_source +from torax.sources import runtime_params as runtime_params_lib from torax.sources import source as source_lib -from torax.sources import source_config from torax.sources.tests import test_lib @@ -34,8 +34,8 @@ class ExternalCurrentSourceTest(test_lib.SourceTestCase): def setUpClass(cls): super().setUpClass( source_class=external_current_source.ExternalCurrentSource, - unsupported_types=[ - source_config.SourceType.MODEL_BASED, + unsupported_modes=[ + runtime_params_lib.Mode.MODEL_BASED, ], expected_affected_core_profiles=(source_lib.AffectedCoreProfile.PSI,), ) @@ -44,21 +44,26 @@ def test_source_value(self): """Tests that a formula-based source provides values.""" source = external_current_source.ExternalCurrentSource() config = config_lib.Config() - dynamic_slice = config_slice.build_dynamic_config_slice(config) + dynamic_slice = config_slice.build_dynamic_config_slice( + config, + sources={ + 'jext': source.runtime_params, + }, + ) self.assertIsInstance(source, external_current_source.ExternalCurrentSource) # Must be circular for jext_hires call. geo = geometry.build_circular_geometry(config) self.assertIsNotNone( source.get_value( - source_type=dynamic_slice.sources[source.name].source_type, dynamic_config_slice=dynamic_slice, + dynamic_source_runtime_params=dynamic_slice.sources['jext'], geo=geo, ) ) self.assertIsNotNone( source.jext_hires( - source_type=dynamic_slice.sources[source.name].source_type, dynamic_config_slice=dynamic_slice, + dynamic_source_runtime_params=dynamic_slice.sources['jext'], geo=geo, ) ) @@ -67,13 +72,19 @@ def test_invalid_source_types_raise_errors(self): config = config_lib.Config() geo = geometry.build_circular_geometry(config) source = external_current_source.ExternalCurrentSource() - dynamic_slice = config_slice.build_dynamic_config_slice(config) - for unsupported_type in self._unsupported_types: - with self.subTest(unsupported_type.name): + for unsupported_mode in self._unsupported_modes: + with self.subTest(unsupported_mode.name): with self.assertRaises(jax.interpreters.xla.xe.XlaRuntimeError): + source.runtime_params.mode = unsupported_mode + dynamic_slice = config_slice.build_dynamic_config_slice( + config, + sources={ + 'jext': source.runtime_params, + }, + ) source.get_value( - source_type=unsupported_type.value, dynamic_config_slice=dynamic_slice, + dynamic_source_runtime_params=dynamic_slice.sources['jext'], geo=geo, ) diff --git a/torax/sources/tests/formulas.py b/torax/sources/tests/formulas.py index 520b4226..25c99ac3 100644 --- a/torax/sources/tests/formulas.py +++ b/torax/sources/tests/formulas.py @@ -21,11 +21,11 @@ from torax import geometry from torax import sim as sim_lib from torax import state as state_lib +from torax.sources import default_sources from torax.sources import formula_config from torax.sources import formulas +from torax.sources import runtime_params as runtime_params_lib from torax.sources import source -from torax.sources import source_config -from torax.sources import source_models as source_models_lib from torax.stepper import linear_theta_method from torax.tests.test_lib import sim_test_case from torax.transport_model import constant as constant_transport_model @@ -47,53 +47,65 @@ def test_custom_exponential_source_can_replace_puff_source(self): # For this test, use test_particle_sources_constant with the linear stepper. custom_source_name = 'custom_exponential_source' - source_models = source_models_lib.SourceModels( - additional_sources=[ - source.SingleProfileSource( - name=custom_source_name, - supported_types=( - source_config.SourceType.ZERO, - source_config.SourceType.FORMULA_BASED, - ), - affected_core_profiles=(source.AffectedCoreProfile.NE,), - formula=formulas.Exponential(custom_source_name), - ) - ] - ) - # Copy the test_particle_sources_constant config in here for clarity. - # These are the common kwargs without any of the sources. - test_particle_sources_constant_config_kwargs = dict( + test_particle_sources_constant_config = config_lib.Config( profile_conditions=config_lib.ProfileConditions( set_pedestal=True, nbar=0.85, ), numerics=config_lib.Numerics( - Qei_mult=1, ion_heat_eq=True, el_heat_eq=True, dens_eq=True, # This is important to be True to test ne sources. current_eq=True, resistivity_mult=100, - bootstrap_mult=1, t_final=2, ), nu=0, - S_pellet_tot=2.0e22, - S_puff_tot=1.0e22, - S_nbi_tot=0.0, solver=config_lib.SolverConfig( predictor_corrector=False, ), ) + # Set the sources to match test_particle_sources_constant as well. + source_models = default_sources.get_default_sources() + source_models.sources['pellet_source'].runtime_params.S_pellet_tot = 2.0e22 + S_puff_tot = 1.0e22 # pylint: disable=invalid-name + puff_decay_length = 0.05 + source_models.sources['gas_puff_source'].runtime_params.S_puff_tot = ( + S_puff_tot + ) + source_models.sources[ + 'gas_puff_source' + ].runtime_params.puff_decay_length = puff_decay_length + source_models.sources['nbi_particle_source'].runtime_params.S_nbi_tot = 0.0 # We need to turn off some other sources for test_particle_sources_constant # that are unrelated to our test for the ne custom source. - unrelated_source_configs = dict( - fusion_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, + source_models.sources['fusion_heat_source'].runtime_params.mode = ( + runtime_params_lib.Mode.ZERO + ) + source_models.sources['ohmic_heat_source'].runtime_params.mode = ( + runtime_params_lib.Mode.ZERO + ) + + # Add the custom source to the source_models, but keep it turned off for the + # first run. + source_models.add_source( + custom_source_name, + source.SingleProfileSource( + supported_modes=( + runtime_params_lib.Mode.ZERO, + runtime_params_lib.Mode.FORMULA_BASED, + ), + affected_core_profiles=(source.AffectedCoreProfile.NE,), + formula=formulas.Exponential(), + runtime_params=runtime_params_lib.RuntimeParams( + mode=runtime_params_lib.Mode.ZERO, + # will override these later, but defining here because, due to + # how JAX works, this function is still evaluated even when the + # mode is set to ZERO. So the runtime config needs to be set + # with the correct params. + formula=formula_config.Exponential(), + ), ), ) @@ -102,19 +114,9 @@ def test_custom_exponential_source_can_replace_puff_source(self): 'test_particle_sources_constant.h5', _ALL_PROFILES ) - # Set up the sim with the original config. We set up the sim only once and - # update the config on each run below in a way that does not trigger - # recompiles. This way we only trace the code once. - test_particle_sources_constant_config = config_lib.Config( - **test_particle_sources_constant_config_kwargs, - sources=dict( - **unrelated_source_configs, - # Turn off the custom source - custom_exponential_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), - ) + # We set up the sim only once and update the config on each run below in a + # way that does not trigger recompiles. This way we only trace the code + # once. geo = geometry.build_circular_geometry( test_particle_sources_constant_config ) @@ -148,48 +150,38 @@ def test_custom_exponential_source_can_replace_puff_source(self): ) with self.subTest('without_puff_and_with_custom_source'): - config_with_custom_source = config_lib.Config( - **test_particle_sources_constant_config_kwargs, - sources=dict( - **unrelated_source_configs, - custom_exponential_source=source_config.SourceConfig( - source_type=source_config.SourceType.FORMULA_BASED, - formula=formula_config.FormulaConfig( - exponential=formula_config.Exponential( - total=test_particle_sources_constant_config.S_puff_tot - / test_particle_sources_constant_config.nref, - c1=1.0, - c2=test_particle_sources_constant_config.puff_decay_length, - use_normalized_r=True, - ) - ), - ), - gas_puff_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, + # Now turn on the custom source. + source_models.sources[custom_source_name].runtime_params.mode = ( + runtime_params_lib.Mode.FORMULA_BASED + ) + source_models.sources[custom_source_name].runtime_params.formula = ( + formula_config.Exponential( + total=( + S_puff_tot + / test_particle_sources_constant_config.numerics.nref ), - ), + c1=1.0, + c2=puff_decay_length, + use_normalized_r=True, + ) + ) + # And turn off the gas puff source it is replacing. + source_models.sources['gas_puff_source'].runtime_params.mode = ( + runtime_params_lib.Mode.ZERO ) self._run_sim_and_check( - config_with_custom_source, sim, ref_profiles, ref_time + test_particle_sources_constant_config, sim, ref_profiles, ref_time ) with self.subTest('without_puff_and_without_custom_source'): # Confirm that the custom source actual has an effect. - config_without_ne_sources = config_lib.Config( - **test_particle_sources_constant_config_kwargs, - sources=dict( - **unrelated_source_configs, - custom_exponential_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - gas_puff_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), + # Turn it off as well, and the check shouldn't pass. + source_models.sources[custom_source_name].runtime_params.mode = ( + runtime_params_lib.Mode.ZERO ) with self.assertRaises(AssertionError): self._run_sim_and_check( - config_without_ne_sources, sim, ref_profiles, ref_time + test_particle_sources_constant_config, sim, ref_profiles, ref_time ) def _run_sim_and_check( @@ -205,6 +197,7 @@ def _run_sim_and_check( dynamic_config_slice_provider=config_slice.DynamicConfigSliceProvider( config=config, transport_getter=lambda: sim.transport_model.runtime_params, + sources_getter=lambda: sim.source_models.runtime_params, ), geometry_provider=sim.geometry_provider, initial_state=sim.initial_state, diff --git a/torax/sources/tests/fusion_heat_source.py b/torax/sources/tests/fusion_heat_source.py index b5f09f69..5df9c06b 100644 --- a/torax/sources/tests/fusion_heat_source.py +++ b/torax/sources/tests/fusion_heat_source.py @@ -23,8 +23,8 @@ from torax import constants from torax import core_profile_setters from torax.sources import fusion_heat_source +from torax.sources import runtime_params as runtime_params_lib from torax.sources import source -from torax.sources import source_config from torax.sources import source_models as source_models_lib from torax.sources.tests import test_lib from torax.tests.test_lib import torax_refs @@ -37,8 +37,8 @@ class FusionHeatSourceTest(test_lib.IonElSourceTestCase): def setUpClass(cls): super().setUpClass( source_class=fusion_heat_source.FusionHeatSource, - unsupported_types=[ - source_config.SourceType.FORMULA_BASED, + unsupported_modes=[ + runtime_params_lib.Mode.FORMULA_BASED, ], expected_affected_core_profiles=( source.AffectedCoreProfile.TEMP_ION, @@ -59,13 +59,18 @@ def test_calc_fusion( config = references.config geo = references.geo - nref = config.nref + nref = config.numerics.nref + source_models = source_models_lib.SourceModels() + dynamic_config_slice = config_slice.build_dynamic_config_slice( + config, + sources=source_models.runtime_params, + ) core_profiles = core_profile_setters.initial_core_profiles( static_config_slice=config_slice.build_static_config_slice(config), - dynamic_config_slice=config_slice.build_dynamic_config_slice(config), + dynamic_config_slice=dynamic_config_slice, geo=geo, - source_models=source_models_lib.SourceModels(), + source_models=source_models, ) fusion_jax, _, _ = fusion_heat_source.calc_fusion( @@ -109,7 +114,7 @@ def calculate_fusion(config, geo, core_profiles): Pfus = ( Efus * 0.25 - * (core_profiles.ni.face_value() * config.nref) ** 2 + * (core_profiles.ni.face_value() * config.numerics.nref) ** 2 * sigmav ) # [W/m^3] Ptot = np.trapz(Pfus * geo.vpr_face, geo.r_face) / 1e6 # [MW] diff --git a/torax/sources/tests/generic_ion_el_heat_source.py b/torax/sources/tests/generic_ion_el_heat_source.py index 1aba69c5..12977eac 100644 --- a/torax/sources/tests/generic_ion_el_heat_source.py +++ b/torax/sources/tests/generic_ion_el_heat_source.py @@ -16,8 +16,8 @@ from absl.testing import absltest from torax.sources import generic_ion_el_heat_source +from torax.sources import runtime_params from torax.sources import source -from torax.sources import source_config from torax.sources.tests import test_lib @@ -28,8 +28,8 @@ class GenericIonElectronHeatSourceTest(test_lib.IonElSourceTestCase): def setUpClass(cls): super().setUpClass( source_class=generic_ion_el_heat_source.GenericIonElectronHeatSource, - unsupported_types=[ - source_config.SourceType.MODEL_BASED, + unsupported_modes=[ + runtime_params.Mode.MODEL_BASED, ], expected_affected_core_profiles=( source.AffectedCoreProfile.TEMP_ION, diff --git a/torax/sources/tests/ion_el_heat_sources.py b/torax/sources/tests/ion_el_heat_sources.py index f0a195c1..14280188 100644 --- a/torax/sources/tests/ion_el_heat_sources.py +++ b/torax/sources/tests/ion_el_heat_sources.py @@ -16,8 +16,8 @@ from absl.testing import absltest from torax.sources import ion_el_heat_sources +from torax.sources import runtime_params from torax.sources import source -from torax.sources import source_config from torax.sources.tests import test_lib @@ -28,8 +28,8 @@ class BremsstrahlungHeatSinkTest(test_lib.SingleProfileSourceTestCase): def setUpClass(cls): super().setUpClass( source_class=ion_el_heat_sources.BremsstrahlungHeatSink, - unsupported_types=[ - source_config.SourceType.MODEL_BASED, + unsupported_modes=[ + runtime_params.Mode.MODEL_BASED, ], expected_affected_core_profiles=(source.AffectedCoreProfile.TEMP_EL,), ) @@ -42,8 +42,8 @@ class ChargeExchangeHeatSinkTest(test_lib.SingleProfileSourceTestCase): def setUpClass(cls): super().setUpClass( source_class=ion_el_heat_sources.ChargeExchangeHeatSink, - unsupported_types=[ - source_config.SourceType.MODEL_BASED, + unsupported_modes=[ + runtime_params.Mode.MODEL_BASED, ], expected_affected_core_profiles=(source.AffectedCoreProfile.TEMP_ION,), ) @@ -56,8 +56,8 @@ class CyclotronRadiationHeatSinkTest(test_lib.SingleProfileSourceTestCase): def setUpClass(cls): super().setUpClass( source_class=ion_el_heat_sources.CyclotronRadiationHeatSink, - unsupported_types=[ - source_config.SourceType.MODEL_BASED, + unsupported_modes=[ + runtime_params.Mode.MODEL_BASED, ], expected_affected_core_profiles=(source.AffectedCoreProfile.TEMP_EL,), ) @@ -70,8 +70,8 @@ class ECRHHeatSourceTest(test_lib.SingleProfileSourceTestCase): def setUpClass(cls): super().setUpClass( source_class=ion_el_heat_sources.ECRHHeatSource, - unsupported_types=[ - source_config.SourceType.MODEL_BASED, + unsupported_modes=[ + runtime_params.Mode.MODEL_BASED, ], expected_affected_core_profiles=(source.AffectedCoreProfile.TEMP_EL,), ) @@ -84,8 +84,8 @@ class ICRHHeatSourceTest(test_lib.SingleProfileSourceTestCase): def setUpClass(cls): super().setUpClass( source_class=ion_el_heat_sources.ICRHHeatSource, - unsupported_types=[ - source_config.SourceType.MODEL_BASED, + unsupported_modes=[ + runtime_params.Mode.MODEL_BASED, ], expected_affected_core_profiles=(source.AffectedCoreProfile.TEMP_ION,), ) @@ -98,8 +98,8 @@ class LHHeatSourceTest(test_lib.SingleProfileSourceTestCase): def setUpClass(cls): super().setUpClass( source_class=ion_el_heat_sources.LHHeatSource, - unsupported_types=[ - source_config.SourceType.MODEL_BASED, + unsupported_modes=[ + runtime_params.Mode.MODEL_BASED, ], expected_affected_core_profiles=(source.AffectedCoreProfile.TEMP_EL,), ) @@ -112,8 +112,8 @@ class LineRadiationHeatSinkTest(test_lib.SingleProfileSourceTestCase): def setUpClass(cls): super().setUpClass( source_class=ion_el_heat_sources.LineRadiationHeatSink, - unsupported_types=[ - source_config.SourceType.MODEL_BASED, + unsupported_modes=[ + runtime_params.Mode.MODEL_BASED, ], expected_affected_core_profiles=(source.AffectedCoreProfile.TEMP_EL,), ) @@ -126,8 +126,8 @@ class NBIElectronHeatSourceTest(test_lib.SingleProfileSourceTestCase): def setUpClass(cls): super().setUpClass( source_class=ion_el_heat_sources.NBIElectronHeatSource, - unsupported_types=[ - source_config.SourceType.MODEL_BASED, + unsupported_modes=[ + runtime_params.Mode.MODEL_BASED, ], expected_affected_core_profiles=(source.AffectedCoreProfile.TEMP_EL,), ) @@ -140,8 +140,8 @@ class NBIIonHeatSourceTest(test_lib.SingleProfileSourceTestCase): def setUpClass(cls): super().setUpClass( source_class=ion_el_heat_sources.NBIIonHeatSource, - unsupported_types=[ - source_config.SourceType.MODEL_BASED, + unsupported_modes=[ + runtime_params.Mode.MODEL_BASED, ], expected_affected_core_profiles=(source.AffectedCoreProfile.TEMP_ION,), ) @@ -154,8 +154,8 @@ class RecombinationHeatSinkTest(test_lib.SingleProfileSourceTestCase): def setUpClass(cls): super().setUpClass( source_class=ion_el_heat_sources.RecombinationHeatSink, - unsupported_types=[ - source_config.SourceType.MODEL_BASED, + unsupported_modes=[ + runtime_params.Mode.MODEL_BASED, ], expected_affected_core_profiles=(source.AffectedCoreProfile.TEMP_EL,), ) diff --git a/torax/sources/tests/qei_source.py b/torax/sources/tests/qei_source.py index 6b297234..d6f20a5f 100644 --- a/torax/sources/tests/qei_source.py +++ b/torax/sources/tests/qei_source.py @@ -14,6 +14,7 @@ """Tests for qei_source.""" +import dataclasses from absl.testing import absltest import jax from torax import config as config_lib @@ -21,8 +22,8 @@ from torax import core_profile_setters from torax import geometry from torax.sources import qei_source +from torax.sources import runtime_params as runtime_params_lib from torax.sources import source as source_lib -from torax.sources import source_config from torax.sources import source_models as source_models_lib from torax.sources.tests import test_lib @@ -34,8 +35,8 @@ class QeiSourceTest(test_lib.SourceTestCase): def setUpClass(cls): super().setUpClass( source_class=qei_source.QeiSource, - unsupported_types=[ - source_config.SourceType.FORMULA_BASED, + unsupported_modes=[ + runtime_params_lib.Mode.FORMULA_BASED, ], expected_affected_core_profiles=( source_lib.AffectedCoreProfile.TEMP_ION, @@ -46,21 +47,27 @@ def setUpClass(cls): def test_source_value(self): """Checks that the default implementation from Sources gives values.""" source = qei_source.QeiSource() + source_models = source_models_lib.SourceModels( + sources={'qei_source': source} + ) config = config_lib.Config() geo = geometry.build_circular_geometry(config) + static_slice = config_slice.build_static_config_slice(config) + dynamic_slice = config_slice.build_dynamic_config_slice( + config, + sources=source_models.runtime_params, + ) core_profiles = core_profile_setters.initial_core_profiles( - static_config_slice=config_slice.build_static_config_slice(config), - dynamic_config_slice=config_slice.build_dynamic_config_slice(config), + static_config_slice=static_slice, + dynamic_config_slice=dynamic_slice, geo=geo, - source_models=source_models_lib.SourceModels(qei_source=source), + source_models=source_models, ) assert isinstance(source, qei_source.QeiSource) # required for pytype. - dynamic_slice = config_slice.build_dynamic_config_slice(config) - static_slice = config_slice.build_static_config_slice(config) qei = source.get_qei( - dynamic_slice.sources[source.name].source_type, static_slice, dynamic_slice, + dynamic_slice.sources['qei_source'], geo, core_profiles, ) @@ -68,23 +75,37 @@ def test_source_value(self): def test_invalid_source_types_raise_errors(self): source = qei_source.QeiSource() + source_models = source_models_lib.SourceModels( + sources={'qei_source': source} + ) config = config_lib.Config() geo = geometry.build_circular_geometry(config) + static_slice = config_slice.build_static_config_slice(config) + dynamic_slice = config_slice.build_dynamic_config_slice( + config, + sources=source_models.runtime_params, + ) core_profiles = core_profile_setters.initial_core_profiles( - static_config_slice=config_slice.build_static_config_slice(config), - dynamic_config_slice=config_slice.build_dynamic_config_slice(config), + static_config_slice=static_slice, + dynamic_config_slice=dynamic_slice, geo=geo, - source_models=source_models_lib.SourceModels(qei_source=source), + source_models=source_models, ) - dynamic_slice = config_slice.build_dynamic_config_slice(config) - static_slice = config_slice.build_static_config_slice(config) - for unsupported_type in self._unsupported_types: - with self.subTest(unsupported_type.name): + for unsupported_mode in self._unsupported_modes: + with self.subTest(unsupported_mode.name): with self.assertRaises(jax.interpreters.xla.xe.XlaRuntimeError): + dynamic_slice = config_slice.build_dynamic_config_slice( + config, + sources={ + 'qei_source': dataclasses.replace( + source.runtime_params, mode=unsupported_mode + ) + }, + ) source.get_qei( - unsupported_type.value, static_slice, dynamic_slice, + dynamic_slice.sources['qei_source'], geo, core_profiles, ) diff --git a/torax/sources/tests/source.py b/torax/sources/tests/source.py index 8312b551..4d82752d 100644 --- a/torax/sources/tests/source.py +++ b/torax/sources/tests/source.py @@ -14,6 +14,7 @@ """Tests for source_lib.py.""" +import dataclasses from absl.testing import absltest from absl.testing import parameterized import jax @@ -23,8 +24,8 @@ from torax import config_slice from torax import core_profile_setters from torax import geometry +from torax.sources import runtime_params as runtime_params_lib from torax.sources import source as source_lib -from torax.sources import source_config from torax.sources import source_models as source_models_lib @@ -34,26 +35,28 @@ class SourceTest(parameterized.TestCase): def test_zero_profile_works_by_default(self): """The default source impl should support profiles with all zeros.""" source = source_lib.Source( - name='foo', output_shape_getter=source_lib.get_cell_profile_shape, affected_core_profiles=(source_lib.AffectedCoreProfile.PSI,), ) - config = config_lib.Config( - sources={source.name: source_config.SourceConfig()} + source_models = source_models_lib.SourceModels( + sources={'foo': source}, ) + config = config_lib.Config() geo = geometry.build_circular_geometry(config) + static_config_slice = config_slice.build_static_config_slice(config) + dynamic_config_slice = config_slice.build_dynamic_config_slice( + config, + sources=source_models.runtime_params, + ) core_profiles = core_profile_setters.initial_core_profiles( - static_config_slice=config_slice.build_static_config_slice(config), - dynamic_config_slice=config_slice.build_dynamic_config_slice(config), + static_config_slice=static_config_slice, + dynamic_config_slice=dynamic_config_slice, geo=geo, - source_models=source_models_lib.SourceModels( - additional_sources=[source] - ), + source_models=source_models, ) - source_type = source_config.SourceType.ZERO.value profile = source.get_value( - source_type=source_type, - dynamic_config_slice=(config_slice.build_dynamic_config_slice(config)), + dynamic_config_slice=dynamic_config_slice, + dynamic_source_runtime_params=dynamic_config_slice.sources['foo'], geo=geo, core_profiles=core_profiles, ) @@ -62,37 +65,39 @@ def test_zero_profile_works_by_default(self): source_lib.ProfileType.CELL.get_zero_profile(geo), ) - def test_unsupported_types_raise_errors(self): + def test_unsupported_modes_raise_errors(self): """Calling with an unsupported type should raise an error.""" source = source_lib.Source( - name='foo', - supported_types=( + supported_modes=( # Only support formula-based profiles. - source_config.SourceType.FORMULA_BASED, + runtime_params_lib.Mode.FORMULA_BASED, ), output_shape_getter=source_lib.get_cell_profile_shape, affected_core_profiles=(source_lib.AffectedCoreProfile.NE,), ) - config = config_lib.Config( - sources={source.name: source_config.SourceConfig()} + # But set the runtime params of the source to use ZERO as the mode. + source.runtime_params.mode = runtime_params_lib.Mode.ZERO + source_models = source_models_lib.SourceModels( + sources={'foo': source}, ) + config = config_lib.Config() geo = geometry.build_circular_geometry(config) + static_config_slice = config_slice.build_static_config_slice(config) + dynamic_config_slice = config_slice.build_dynamic_config_slice( + config, + sources=source_models.runtime_params, + ) core_profiles = core_profile_setters.initial_core_profiles( - dynamic_config_slice=config_slice.build_dynamic_config_slice(config), - static_config_slice=config_slice.build_static_config_slice(config), + static_config_slice=static_config_slice, + dynamic_config_slice=dynamic_config_slice, geo=geo, - source_models=source_models_lib.SourceModels( - additional_sources=[source] - ), + source_models=source_models, ) # But calling requesting ZERO shouldn't work. with self.assertRaises(jax.interpreters.xla.xe.XlaRuntimeError): - source_type = source_config.SourceType.ZERO.value source.get_value( - source_type=source_type, - dynamic_config_slice=( - config_slice.build_dynamic_config_slice(config) - ), + dynamic_config_slice=dynamic_config_slice, + dynamic_source_runtime_params=dynamic_config_slice.sources['foo'], geo=geo, core_profiles=core_profiles, ) @@ -100,33 +105,42 @@ def test_unsupported_types_raise_errors(self): def test_defaults_output_zeros(self): """The default model and formula implementations should output zeros.""" source = source_lib.Source( - name='foo', - supported_types=( - source_config.SourceType.MODEL_BASED, - source_config.SourceType.FORMULA_BASED, + supported_modes=( + runtime_params_lib.Mode.MODEL_BASED, + runtime_params_lib.Mode.FORMULA_BASED, ), output_shape_getter=source_lib.get_cell_profile_shape, affected_core_profiles=(source_lib.AffectedCoreProfile.NE,), ) - config = config_lib.Config( - sources={source.name: source_config.SourceConfig()} + source_models = source_models_lib.SourceModels( + sources={'foo': source}, ) + config = config_lib.Config() geo = geometry.build_circular_geometry(config) + static_config_slice = config_slice.build_static_config_slice(config) + dynamic_config_slice = config_slice.build_dynamic_config_slice( + config, + sources=source_models.runtime_params, + ) core_profiles = core_profile_setters.initial_core_profiles( - static_config_slice=config_slice.build_static_config_slice(config), - dynamic_config_slice=config_slice.build_dynamic_config_slice(config), + static_config_slice=static_config_slice, + dynamic_config_slice=dynamic_config_slice, geo=geo, - source_models=source_models_lib.SourceModels( - additional_sources=[source] - ), + source_models=source_models, ) with self.subTest('model_based'): - source_type = source_config.SourceType.MODEL_BASED.value + dynamic_config_slice = config_slice.build_dynamic_config_slice( + config, + sources={ + 'foo': dataclasses.replace( + source.runtime_params, + mode=runtime_params_lib.Mode.MODEL_BASED, + ) + }, + ) profile = source.get_value( - source_type=source_type, - dynamic_config_slice=( - config_slice.build_dynamic_config_slice(config) - ), + dynamic_config_slice=dynamic_config_slice, + dynamic_source_runtime_params=dynamic_config_slice.sources['foo'], geo=geo, core_profiles=core_profiles, ) @@ -135,12 +149,18 @@ def test_defaults_output_zeros(self): source_lib.ProfileType.CELL.get_zero_profile(geo), ) with self.subTest('formula'): - source_type = source_config.SourceType.FORMULA_BASED.value + dynamic_config_slice = config_slice.build_dynamic_config_slice( + config, + sources={ + 'foo': dataclasses.replace( + source.runtime_params, + mode=runtime_params_lib.Mode.FORMULA_BASED, + ) + }, + ) profile = source.get_value( - source_type=source_type, - dynamic_config_slice=( - config_slice.build_dynamic_config_slice(config) - ), + dynamic_config_slice=dynamic_config_slice, + dynamic_source_runtime_params=dynamic_config_slice.sources['foo'], geo=geo, core_profiles=core_profiles, ) @@ -154,30 +174,33 @@ def test_overriding_default_formula(self): output_shape = (2, 4) # Some arbitrary shape. expected_output = jnp.ones(output_shape) source = source_lib.Source( - name='foo', - output_shape_getter=lambda _0, _1, _2: output_shape, - formula=lambda _0, _1, _2: expected_output, + output_shape_getter=lambda _0: output_shape, + formula=lambda _0, _1, _2, _3: expected_output, affected_core_profiles=( source_lib.AffectedCoreProfile.TEMP_ION, source_lib.AffectedCoreProfile.TEMP_EL, ), ) - config = config_lib.Config( - sources={source.name: source_config.SourceConfig()} + source.runtime_params.mode = runtime_params_lib.Mode.FORMULA_BASED + source_models = source_models_lib.SourceModels( + sources={'foo': source}, ) + config = config_lib.Config() geo = geometry.build_circular_geometry(config) + static_config_slice = config_slice.build_static_config_slice(config) + dynamic_config_slice = config_slice.build_dynamic_config_slice( + config, + sources=source_models.runtime_params, + ) core_profiles = core_profile_setters.initial_core_profiles( - static_config_slice=config_slice.build_static_config_slice(config), - dynamic_config_slice=config_slice.build_dynamic_config_slice(config), + static_config_slice=static_config_slice, + dynamic_config_slice=dynamic_config_slice, geo=geo, - source_models=source_models_lib.SourceModels( - additional_sources=[source] - ), + source_models=source_models, ) - source_type = source_config.SourceType.FORMULA_BASED.value profile = source.get_value( - source_type=source_type, - dynamic_config_slice=(config_slice.build_dynamic_config_slice(config)), + dynamic_config_slice=dynamic_config_slice, + dynamic_source_runtime_params=dynamic_config_slice.sources['foo'], geo=geo, core_profiles=core_profiles, ) @@ -188,31 +211,34 @@ def test_overriding_model(self): output_shape = (2, 4) # Some arbitrary shape. expected_output = jnp.ones(output_shape) source = source_lib.Source( - name='foo', - supported_types=(source_config.SourceType.MODEL_BASED,), - output_shape_getter=lambda _0, _1, _2: output_shape, - model_func=lambda _0, _1, _2: expected_output, + supported_modes=(runtime_params_lib.Mode.MODEL_BASED,), + output_shape_getter=lambda _0: output_shape, + model_func=lambda _0, _1, _2, _3: expected_output, affected_core_profiles=( source_lib.AffectedCoreProfile.TEMP_ION, source_lib.AffectedCoreProfile.TEMP_EL, ), ) - config = config_lib.Config( - sources={source.name: source_config.SourceConfig()} + source.runtime_params.mode = runtime_params_lib.Mode.MODEL_BASED + source_models = source_models_lib.SourceModels( + sources={'foo': source}, ) + config = config_lib.Config() geo = geometry.build_circular_geometry(config) + static_config_slice = config_slice.build_static_config_slice(config) + dynamic_config_slice = config_slice.build_dynamic_config_slice( + config, + sources=source_models.runtime_params, + ) core_profiles = core_profile_setters.initial_core_profiles( - static_config_slice=config_slice.build_static_config_slice(config), - dynamic_config_slice=config_slice.build_dynamic_config_slice(config), + static_config_slice=static_config_slice, + dynamic_config_slice=dynamic_config_slice, geo=geo, - source_models=source_models_lib.SourceModels( - additional_sources=[source] - ), + source_models=source_models, ) - source_type = source_config.SourceType.MODEL_BASED.value profile = source.get_value( - source_type=source_type, - dynamic_config_slice=(config_slice.build_dynamic_config_slice(config)), + dynamic_config_slice=dynamic_config_slice, + dynamic_source_runtime_params=dynamic_config_slice.sources['foo'], geo=geo, core_profiles=core_profiles, ) @@ -223,17 +249,15 @@ def test_retrieving_profile_for_affected_state(self): output_shape = (2, 4) # Some arbitrary shape. profile = jnp.asarray([[1, 2, 3, 4], [5, 6, 7, 8]]) # from get_value() source = source_lib.Source( - name='foo', - supported_types=(source_config.SourceType.MODEL_BASED,), - output_shape_getter=lambda _0, _1, _2: output_shape, - model_func=lambda _0, _1, _2: profile, + supported_modes=(runtime_params_lib.Mode.MODEL_BASED,), + output_shape_getter=lambda _0: output_shape, + model_func=lambda _0, _1, _2, _3: profile, affected_core_profiles=( source_lib.AffectedCoreProfile.PSI, source_lib.AffectedCoreProfile.NE, ), ) config = config_lib.Config( - sources={source.name: source_config.SourceConfig()}, numerics=config_lib.Numerics(nr=4), ) geo = geometry.build_circular_geometry(config) @@ -261,27 +285,31 @@ class SingleProfileSourceTest(parameterized.TestCase): def test_custom_formula(self): """The user-specified formula should override the default formula.""" config = config_lib.Config( - sources={'foo': source_config.SourceConfig()}, numerics=config_lib.Numerics(nr=5), ) geo = geometry.build_circular_geometry(config) - core_profiles = core_profile_setters.initial_core_profiles( - static_config_slice=config_slice.build_static_config_slice(config), - dynamic_config_slice=config_slice.build_dynamic_config_slice(config), - geo=geo, - # defaults are enough for this. - source_models=source_models_lib.SourceModels(), - ) expected_output = jnp.ones(5) # 5 matches config.numerics.nr. source = source_lib.SingleProfileSource( - name='foo', - formula=lambda _0, _1, _2: expected_output, + formula=lambda _0, _1, _2, _3: expected_output, affected_core_profiles=(source_lib.AffectedCoreProfile.PSI,), ) - source_type = source_config.SourceType.FORMULA_BASED.value + source.runtime_params.mode = runtime_params_lib.Mode.FORMULA_BASED + source_models = source_models_lib.SourceModels( + sources={'foo': source}, + ) + dynamic_config_slice = config_slice.build_dynamic_config_slice( + config, + sources=source_models.runtime_params, + ) + core_profiles = core_profile_setters.initial_core_profiles( + static_config_slice=config_slice.build_static_config_slice(config), + dynamic_config_slice=dynamic_config_slice, + geo=geo, + source_models=source_models, + ) profile = source.get_value( - source_type=source_type, - dynamic_config_slice=(config_slice.build_dynamic_config_slice(config)), + dynamic_config_slice=dynamic_config_slice, + dynamic_source_runtime_params=dynamic_config_slice.sources['foo'], geo=geo, core_profiles=core_profiles, ) @@ -289,33 +317,37 @@ def test_custom_formula(self): def test_multiple_profiles_raises_error(self): """A formula which outputs the wrong shape will raise an error.""" + source = source_lib.SingleProfileSource( + formula=lambda _0, _1, _2, _3: jnp.ones((2, 5)), + affected_core_profiles=( + source_lib.AffectedCoreProfile.TEMP_ION, + source_lib.AffectedCoreProfile.NE, + ), + ) + source.runtime_params.mode = runtime_params_lib.Mode.FORMULA_BASED + source_models = source_models_lib.SourceModels( + sources={'foo': source}, + ) config = config_lib.Config( - sources={'foo': source_config.SourceConfig()}, numerics=config_lib.Numerics(nr=5), ) geo = geometry.build_circular_geometry(config) + static_config_slice = config_slice.build_static_config_slice(config) + dynamic_config_slice = config_slice.build_dynamic_config_slice( + config, + sources=source_models.runtime_params, + ) core_profiles = core_profile_setters.initial_core_profiles( - static_config_slice=config_slice.build_static_config_slice(config), - dynamic_config_slice=config_slice.build_dynamic_config_slice(config), + static_config_slice=static_config_slice, + dynamic_config_slice=dynamic_config_slice, geo=geo, # defaults are enough for this. - source_models=source_models_lib.SourceModels(), - ) - source = source_lib.SingleProfileSource( - name='foo', - formula=lambda _0, _1, _2: jnp.ones((2, 5)), - affected_core_profiles=( - source_lib.AffectedCoreProfile.PSI, - source_lib.AffectedCoreProfile.NE, - ), + source_models=source_models, ) - source_type = source_config.SourceType.FORMULA_BASED.value with self.assertRaises(AssertionError): source.get_value( - source_type=source_type, - dynamic_config_slice=( - config_slice.build_dynamic_config_slice(config) - ), + dynamic_config_slice=dynamic_config_slice, + dynamic_source_runtime_params=dynamic_config_slice.sources['foo'], geo=geo, core_profiles=core_profiles, ) @@ -324,13 +356,11 @@ def test_retrieving_profile_for_affected_state(self): """Grabbing the correct profile works for all mesh state attributes.""" profile = jnp.asarray([1, 2, 3, 4]) # from get_value() source = source_lib.SingleProfileSource( - name='foo', - supported_types=(source_config.SourceType.MODEL_BASED,), - model_func=lambda _0, _1, _2: profile, + supported_modes=(runtime_params_lib.Mode.MODEL_BASED,), + model_func=lambda _0, _1, _2, _3: profile, affected_core_profiles=(source_lib.AffectedCoreProfile.NE,), ) config = config_lib.Config( - sources={source.name: source_config.SourceConfig()}, numerics=config_lib.Numerics(nr=4), ) geo = geometry.build_circular_geometry(config) diff --git a/torax/sources/tests/source_config.py b/torax/sources/tests/source_config.py deleted file mode 100644 index d2451fd2..00000000 --- a/torax/sources/tests/source_config.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2024 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for source_config.py.""" - -from absl.testing import absltest -from absl.testing import parameterized -from torax import config as config_lib -from torax.sources import source_config -from torax.sources import source_models as source_models_lib - - -class SourceConfigTest(parameterized.TestCase): - """Tests for SourceConfig and related functions.""" - - def test_source_config_keys_match_default_sources(self): - """Makes sure the source configs always have the default sources.""" - config = config_lib.Config() - source_models = source_models_lib.SourceModels() - self.assertSameElements( - config.sources.keys(), source_models.all_sources.keys() - ) - - # Try overriding some elements. - config = config_lib.Config( - sources=dict( - gas_puff_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO - ), - nbi_particle_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO - ), - ) - ) - # Still should have all the same keys because Config should add back the - # defaults. - self.assertSameElements( - config.sources.keys(), source_models.all_sources.keys() - ) - - -if __name__ == '__main__': - absltest.main() diff --git a/torax/sources/tests/source_models.py b/torax/sources/tests/source_models.py index 4736bd08..1c934b0e 100644 --- a/torax/sources/tests/source_models.py +++ b/torax/sources/tests/source_models.py @@ -23,8 +23,9 @@ from torax import config_slice from torax import core_profile_setters from torax import geometry +from torax.sources import default_sources +from torax.sources import runtime_params as runtime_params_lib from torax.sources import source as source_lib -from torax.sources import source_config from torax.sources import source_models as source_models_lib from torax.sources import source_profiles as source_profiles_lib @@ -43,12 +44,15 @@ class SourceProfilesTest(parameterized.TestCase): def test_computing_source_profiles_works_with_all_defaults(self): """Tests that you can compute source profiles with all defaults.""" config = torax.Config() - dynamic_config_slice = config_slice.build_dynamic_config_slice(config) geo = torax.build_circular_geometry(config) source_models = source_models_lib.SourceModels() + dynamic_config_slice = config_slice.build_dynamic_config_slice( + config, + sources=source_models.runtime_params, + ) core_profiles = core_profile_setters.initial_core_profiles( static_config_slice=config_slice.build_static_config_slice(config), - dynamic_config_slice=config_slice.build_dynamic_config_slice(config), + dynamic_config_slice=dynamic_config_slice, geo=geo, source_models=source_models, ) @@ -68,7 +72,7 @@ def test_summed_temp_ion_profiles_dont_change_when_jitting(self): # fusion_heat_source, and ohmic_heat_source are included and produce # profiles for ion and electron heat. # temperature. - source_models = source_models_lib.SourceModels() + source_models = default_sources.get_default_sources() # Make some dummy source profiles that could have come from these sources. ones = jnp.ones(source_lib.ProfileType.CELL.get_profile_shape(geo)) profiles = source_profiles_lib.SourceProfiles( @@ -110,52 +114,45 @@ def test_custom_source_profiles_dont_change_when_jitted(self): """Test that custom source profiles don't change profiles when jitted.""" source_name = 'foo' - def foo_formula(unused_dcs, geo: geometry.Geometry, unused_state): + def foo_formula( + unused_dcs, + unused_sc, + geo: geometry.Geometry, + unused_state, + ): return jnp.stack([ jnp.zeros(source_lib.ProfileType.CELL.get_profile_shape(geo)), jnp.ones(source_lib.ProfileType.CELL.get_profile_shape(geo)), ]) foo_source = source_lib.Source( - name=source_name, # Test a fake source that somehow affects both electron temp and # electron density. affected_core_profiles=( source_lib.AffectedCoreProfile.TEMP_EL, source_lib.AffectedCoreProfile.NE, ), - supported_types=(source_config.SourceType.FORMULA_BASED,), - output_shape_getter=lambda _0, geo, _1: (2,) - + source_lib.ProfileType.CELL.get_profile_shape(geo), + supported_modes=(runtime_params_lib.Mode.FORMULA_BASED,), + output_shape_getter=( + lambda geo: (2,) + + source_lib.ProfileType.CELL.get_profile_shape(geo) + ), formula=foo_formula, ) + # Set the source mode to FORMULA. + foo_source.runtime_params.mode = runtime_params_lib.Mode.FORMULA_BASED source_models = source_models_lib.SourceModels( - additional_sources=[foo_source], + sources={source_name: foo_source}, ) - zero_config = source_config.SourceConfig( - source_type=source_config.SourceType.ZERO - ) - config = torax.Config( - sources=dict( - # Turn off all the other ne sources. - gas_puff_source=zero_config, - nbi_particle_source=zero_config, - pellet_source=zero_config, - # And turn off the temp sources. - generic_ion_el_heat_source=zero_config, - fusion_heat_source=zero_config, - ohmic_heat_source=zero_config, - # But for the custom source, leave that on. - foo=source_config.SourceConfig( - source_type=source_config.SourceType.FORMULA_BASED, - ), - ) + config = torax.Config() + dynamic_config_slice = config_slice.build_dynamic_config_slice( + config, + sources=source_models.runtime_params, ) - dynamic_config_slice = config_slice.build_dynamic_config_slice(config) geo = torax.build_circular_geometry(config) core_profiles = core_profile_setters.initial_core_profiles( - dynamic_config_slice=config_slice.build_dynamic_config_slice(config), static_config_slice=config_slice.build_static_config_slice(config), + dynamic_config_slice=dynamic_config_slice, geo=geo, source_models=source_models, ) diff --git a/torax/sources/tests/test_lib.py b/torax/sources/tests/test_lib.py index e500b48f..3909bf9e 100644 --- a/torax/sources/tests/test_lib.py +++ b/torax/sources/tests/test_lib.py @@ -25,8 +25,8 @@ from torax import config_slice from torax import core_profile_setters from torax import geometry +from torax.sources import runtime_params as runtime_params_lib from torax.sources import source as source_lib -from torax.sources import source_config as source_config_lib from torax.sources import source_models as source_models_lib @@ -42,21 +42,21 @@ class SourceTestCase(parameterized.TestCase): _source_class: Type[source_lib.Source] _config_attr_name: str - _unsupported_types: Sequence[source_config_lib.SourceType] + _unsupported_modes: Sequence[runtime_params_lib.Mode] _expected_affected_core_profiles: tuple[source_lib.AffectedCoreProfile, ...] @classmethod def setUpClass( cls, source_class: Type[source_lib.Source], - unsupported_types: Sequence[source_config_lib.SourceType], + unsupported_modes: Sequence[runtime_params_lib.Mode], expected_affected_core_profiles: tuple[ source_lib.AffectedCoreProfile, ... ], ): super().setUpClass() cls._source_class = source_class - cls._unsupported_types = unsupported_types + cls._unsupported_modes = unsupported_modes cls._expected_affected_core_profiles = expected_affected_core_profiles def test_expected_mesh_states(self): @@ -83,36 +83,25 @@ def test_source_value(self): # pylint: enable=missing-kwoa self.assertIsInstance(source, source_lib.SingleProfileSource) config = config_lib.Config() - # Not all sources are in the default config, so add the source in here if - # it doesn't already exist. - if source.name not in config.sources: - supported_types = set( - [source_type for source_type in source_config_lib.SourceType] - ) - set(self._unsupported_types) - supported_type = supported_types.pop() - config = config_lib.Config( - sources={ - source.name: source_config_lib.SourceConfig( - source_type=supported_type, - ) - } - ) - source_models = source_models_lib.SourceModels( - additional_sources=[source] - ) - else: - source_models = source_models_lib.SourceModels() + source.runtime_params.mode = source.supported_modes[0] + source_models = source_models_lib.SourceModels( + sources={'foo': source}, + ) geo = geometry.build_circular_geometry(config) + static_config_slice = config_slice.build_static_config_slice(config) + dynamic_config_slice = config_slice.build_dynamic_config_slice( + config=config, + sources=source_models.runtime_params, + ) core_profiles = core_profile_setters.initial_core_profiles( - static_config_slice=config_slice.build_static_config_slice(config), - dynamic_config_slice=config_slice.build_dynamic_config_slice(config), + static_config_slice=static_config_slice, + dynamic_config_slice=dynamic_config_slice, geo=geo, source_models=source_models, ) - source_type = config.sources[source.name].source_type.value value = source.get_value( - source_type=source_type, - dynamic_config_slice=(config_slice.build_dynamic_config_slice(config)), + dynamic_config_slice=dynamic_config_slice, + dynamic_source_runtime_params=dynamic_config_slice.sources['foo'], geo=geo, core_profiles=core_profiles, ) @@ -122,25 +111,34 @@ def test_invalid_source_types_raise_errors(self): """Tests that using unsupported types raises an error.""" config = config_lib.Config() geo = geometry.build_circular_geometry(config) - core_profiles = core_profile_setters.initial_core_profiles( - static_config_slice=config_slice.build_static_config_slice(config), - dynamic_config_slice=config_slice.build_dynamic_config_slice(config), - geo=geo, - # only need default sources here. - source_models=source_models_lib.SourceModels(), - ) # pylint: disable=missing-kwoa source = self._source_class() # pytype: disable=missing-parameter # pylint: enable=missing-kwoa self.assertIsInstance(source, source_lib.SingleProfileSource) - for unsupported_type in self._unsupported_types: - with self.subTest(unsupported_type.name): + source_models = source_models_lib.SourceModels( + sources={'foo': source}, + ) + dynamic_config_slice = config_slice.build_dynamic_config_slice( + config=config, + sources=source_models.runtime_params, + ) + core_profiles = core_profile_setters.initial_core_profiles( + static_config_slice=config_slice.build_static_config_slice(config), + dynamic_config_slice=dynamic_config_slice, + geo=geo, + source_models=source_models, + ) + for unsupported_mode in self._unsupported_modes: + source.runtime_params.mode = unsupported_mode + dynamic_config_slice = config_slice.build_dynamic_config_slice( + config=config, + sources=source_models.runtime_params, + ) + with self.subTest(unsupported_mode.name): with self.assertRaises(jax.interpreters.xla.xe.XlaRuntimeError): source.get_value( - source_type=unsupported_type.value, - dynamic_config_slice=( - config_slice.build_dynamic_config_slice(config) - ), + dynamic_config_slice=dynamic_config_slice, + dynamic_source_runtime_params=dynamic_config_slice.sources['foo'], geo=geo, core_profiles=core_profiles, ) @@ -157,17 +155,23 @@ def test_source_value(self): self.assertIsInstance(source, source_lib.IonElectronSource) config = config_lib.Config() geo = geometry.build_circular_geometry(config) + source_models = source_models_lib.SourceModels( + sources={'foo': source}, + ) + static_config_slice = config_slice.build_static_config_slice(config) + dynamic_config_slice = config_slice.build_dynamic_config_slice( + config=config, + sources=source_models.runtime_params, + ) core_profiles = core_profile_setters.initial_core_profiles( - static_config_slice=config_slice.build_static_config_slice(config), - dynamic_config_slice=config_slice.build_dynamic_config_slice(config), + static_config_slice=static_config_slice, + dynamic_config_slice=dynamic_config_slice, geo=geo, - # only need default sources here. - source_models=source_models_lib.SourceModels(), + source_models=source_models, ) - source_type = config.sources[source.name].source_type.value ion_and_el = source.get_value( - source_type=source_type, - dynamic_config_slice=(config_slice.build_dynamic_config_slice(config)), + dynamic_config_slice=dynamic_config_slice, + dynamic_source_runtime_params=dynamic_config_slice.sources['foo'], geo=geo, core_profiles=core_profiles, ) @@ -177,25 +181,35 @@ def test_invalid_source_types_raise_errors(self): """Tests that using unsupported types raises an error.""" config = config_lib.Config() geo = geometry.build_circular_geometry(config) - core_profiles = core_profile_setters.initial_core_profiles( - static_config_slice=config_slice.build_static_config_slice(config), - dynamic_config_slice=config_slice.build_dynamic_config_slice(config), - geo=geo, - # only need default sources here. - source_models=source_models_lib.SourceModels(), - ) # pylint: disable=missing-kwoa source = self._source_class() # pytype: disable=missing-parameter # pylint: enable=missing-kwoa self.assertIsInstance(source, source_lib.IonElectronSource) - for unsupported_type in self._unsupported_types: - with self.subTest(unsupported_type.name): + source_models = source_models_lib.SourceModels( + sources={'foo': source}, + ) + static_config_slice = config_slice.build_static_config_slice(config) + dynamic_config_slice = config_slice.build_dynamic_config_slice( + config=config, + sources=source_models.runtime_params, + ) + core_profiles = core_profile_setters.initial_core_profiles( + static_config_slice=static_config_slice, + dynamic_config_slice=dynamic_config_slice, + geo=geo, + source_models=source_models, + ) + for unsupported_mode in self._unsupported_modes: + source.runtime_params.mode = unsupported_mode + dynamic_config_slice = config_slice.build_dynamic_config_slice( + config=config, + sources=source_models.runtime_params, + ) + with self.subTest(unsupported_mode.name): with self.assertRaises(jax.interpreters.xla.xe.XlaRuntimeError): source.get_value( - source_type=unsupported_type.value, - dynamic_config_slice=( - config_slice.build_dynamic_config_slice(config) - ), + dynamic_config_slice=dynamic_config_slice, + dynamic_source_runtime_params=dynamic_config_slice.sources['foo'], geo=geo, core_profiles=core_profiles, ) diff --git a/torax/spectators/tests/plotting.py b/torax/spectators/tests/plotting.py index 00de05e2..13400989 100644 --- a/torax/spectators/tests/plotting.py +++ b/torax/spectators/tests/plotting.py @@ -19,6 +19,7 @@ import torax # We want this import to make sure jax gets set to float64 from torax import config as config_lib from torax import geometry +from torax.sources import default_sources from torax.spectators import plotting from torax.spectators import spectator from torax.stepper import linear_theta_method @@ -61,6 +62,7 @@ def _run_sim( geo=geo, stepper_builder=linear_theta_method.LinearThetaMethod, transport_model=constant_transport_model.ConstantTransportModel(), + source_models=default_sources.get_default_sources(), time_step_calculator=chi_time_step_calculator.ChiTimeStepCalculator(), ).run( spectator=observer, diff --git a/torax/stepper/linear_theta_method.py b/torax/stepper/linear_theta_method.py index 680942a6..d3c9643a 100644 --- a/torax/stepper/linear_theta_method.py +++ b/torax/stepper/linear_theta_method.py @@ -93,7 +93,6 @@ def _x_new( x_new_init, ( source_models_lib.build_all_zero_profiles( - dynamic_config_slice_t, geo, self.source_models, ), diff --git a/torax/stepper/stepper.py b/torax/stepper/stepper.py index 72f17cb6..53f920a7 100644 --- a/torax/stepper/stepper.py +++ b/torax/stepper/stepper.py @@ -139,7 +139,6 @@ def __call__( x_new = tuple() core_sources = source_models_lib.build_all_zero_profiles( source_models=self.source_models, - dynamic_config_slice=dynamic_config_slice_t, geo=geo, ) core_transport = state.CoreTransport.zeros(geo) diff --git a/torax/tests/boundary_conditions.py b/torax/tests/boundary_conditions.py index a66230c5..a1991c30 100644 --- a/torax/tests/boundary_conditions.py +++ b/torax/tests/boundary_conditions.py @@ -22,6 +22,7 @@ from torax import constants from torax import core_profile_setters from torax import geometry +from torax.sources import source_models as source_models_lib class BoundaryConditionsTest(absltest.TestCase): @@ -44,17 +45,21 @@ def test_setting_boundary_conditions(self): ) geo = geometry.build_circular_geometry(config) + source_models = source_models_lib.SourceModels() static_config_slice = config_slice.build_static_config_slice(config) initial_dynamic_config_slice = config_slice.build_dynamic_config_slice( - config + config, + sources=source_models.runtime_params, ) core_profiles = core_profile_setters.initial_core_profiles( static_config_slice, initial_dynamic_config_slice, geo, + source_models=source_models, ) dynamic_config_slice = config_slice.build_dynamic_config_slice( config, + sources=source_models.runtime_params, t=0.5, ) diff --git a/torax/tests/config_slice.py b/torax/tests/config_slice.py index 86d9316c..a899493f 100644 --- a/torax/tests/config_slice.py +++ b/torax/tests/config_slice.py @@ -20,9 +20,10 @@ import numpy as np from torax import config as config_lib from torax import config_slice as config_slice_lib +from torax.sources import electron_density_sources +from torax.sources import external_current_source from torax.sources import formula_config -from torax.sources import source_config -from torax.sources import source_models as source_models_lib +from torax.sources import runtime_params from torax.transport_model import runtime_params as transport_params_lib @@ -39,14 +40,6 @@ def foo(config_slice: config_slice_lib.DynamicConfigSlice): # Make sure you can call the function with dynamic_slice as an arg. foo_jitted(dynamic_slice) - def test_dynamic_sources_config_contains_all_sources(self): - """Tests that all the Sources attributes are covered by SourcesConfig.""" - source_models = source_models_lib.SourceModels().all_sources.keys() - sources_config_fields = config_slice_lib.build_dynamic_config_slice( - config_lib.Config() - ).sources.keys() - self.assertSameElements(source_models, sources_config_fields) - def test_time_dependent_provider_is_time_dependent(self): """Tests that the config slice provider is time dependent.""" config = config_lib.Config( @@ -57,6 +50,7 @@ def test_time_dependent_provider_is_time_dependent(self): provider = config_slice_lib.DynamicConfigSliceProvider( config=config, transport_getter=transport_params_lib.RuntimeParams, + sources_getter=lambda: {}, ) dynamic_config_slice = provider(t=1.0) np.testing.assert_allclose( @@ -133,115 +127,140 @@ def test_source_formula_config_has_time_dependent_params(self): with self.subTest('default_ne_sources'): # Check that the config params for the default ne sources are # time-dependent. - config = config_lib.Config( - pellet_width={0.0: 0.0, 1.0: 1.0}, - pellet_deposition_location={0.0: 0.0, 1.0: 2.0}, - S_pellet_tot={0.0: 0.0, 1.0: 3.0}, - puff_decay_length={0.0: 0.0, 1.0: 4.0}, - S_puff_tot={0.0: 0.0, 1.0: 5.0}, - nbi_particle_width={0.0: 0.0, 1.0: 6.0}, - nbi_deposition_location={0.0: 0.0, 1.0: 7.0}, - S_nbi_tot={0.0: 0.0, 1.0: 8.0}, - ) - dcs = config_slice_lib.build_dynamic_config_slice(config, t=0.5) - np.testing.assert_allclose(dcs.pellet_width, 0.5) - np.testing.assert_allclose(dcs.pellet_deposition_location, 1.0) - np.testing.assert_allclose(dcs.S_pellet_tot, 1.5) - np.testing.assert_allclose(dcs.puff_decay_length, 2.0) - np.testing.assert_allclose(dcs.S_puff_tot, 2.5) - np.testing.assert_allclose(dcs.nbi_particle_width, 3.0) - np.testing.assert_allclose(dcs.nbi_deposition_location, 3.5) - np.testing.assert_allclose(dcs.S_nbi_tot, 4.0) - - with self.subTest('exponential_formula'): - config = config_lib.Config( + config = config_lib.Config() + dcs = config_slice_lib.build_dynamic_config_slice( + config=config, sources={ - 'gas_puff_source': source_config.SourceConfig( - formula=formula_config.FormulaConfig( - exponential=formula_config.Exponential( - total={0.0: 0.0, 1.0: 1.0}, - c1={0.0: 0.0, 1.0: 2.0}, - c2={0.0: 0.0, 1.0: 3.0}, - ) + 'gas_puff_source': electron_density_sources.GasPuffRuntimeParams( + puff_decay_length={0.0: 0.0, 1.0: 4.0}, + S_puff_tot={0.0: 0.0, 1.0: 5.0}, + ), + 'pellet_source': electron_density_sources.PelletRuntimeParams( + pellet_width={0.0: 0.0, 1.0: 1.0}, + pellet_deposition_location={0.0: 0.0, 1.0: 2.0}, + S_pellet_tot={0.0: 0.0, 1.0: 3.0}, + ), + 'nbi_particle_source': ( + electron_density_sources.NBIParticleRuntimeParams( + nbi_particle_width={0.0: 0.0, 1.0: 6.0}, + nbi_deposition_location={0.0: 0.0, 1.0: 7.0}, + S_nbi_tot={0.0: 0.0, 1.0: 8.0}, ) - ) - } + ), + }, + t=0.5, + ) + assert isinstance( + dcs.sources['pellet_source'], + electron_density_sources.DynamicPelletRuntimeParams, + ) + assert isinstance( + dcs.sources['gas_puff_source'], + electron_density_sources.DynamicGasPuffRuntimeParams, + ) + assert isinstance( + dcs.sources['nbi_particle_source'], + electron_density_sources.DynamicNBIParticleRuntimeParams, ) - dcs = config_slice_lib.build_dynamic_config_slice(config, t=0.25) + np.testing.assert_allclose(dcs.sources['pellet_source'].pellet_width, 0.5) np.testing.assert_allclose( - dcs.sources['gas_puff_source'].formula.exponential.total, 0.25 + dcs.sources['pellet_source'].pellet_deposition_location, 1.0 ) + np.testing.assert_allclose(dcs.sources['pellet_source'].S_pellet_tot, 1.5) np.testing.assert_allclose( - dcs.sources['gas_puff_source'].formula.exponential.c1, 0.5 + dcs.sources['gas_puff_source'].puff_decay_length, 2.0 ) + np.testing.assert_allclose(dcs.sources['gas_puff_source'].S_puff_tot, 2.5) np.testing.assert_allclose( - dcs.sources['gas_puff_source'].formula.exponential.c2, 0.75 + dcs.sources['nbi_particle_source'].nbi_particle_width, 3.0 + ) + np.testing.assert_allclose( + dcs.sources['nbi_particle_source'].nbi_deposition_location, 3.5 + ) + np.testing.assert_allclose( + dcs.sources['nbi_particle_source'].S_nbi_tot, 4.0 ) - with self.subTest('gaussian_formula'): - config = config_lib.Config( + with self.subTest('exponential_formula'): + config = config_lib.Config() + dcs = config_slice_lib.build_dynamic_config_slice( + config=config, sources={ - 'gas_puff_source': source_config.SourceConfig( - formula=formula_config.FormulaConfig( - gaussian=formula_config.Gaussian( - total={0.0: 0.0, 1.0: 1.0}, - c1={0.0: 0.0, 1.0: 2.0}, - c2={0.0: 0.0, 1.0: 3.0}, - ) + 'gas_puff_source': runtime_params.RuntimeParams( + formula=formula_config.Exponential( + total={0.0: 0.0, 1.0: 1.0}, + c1={0.0: 0.0, 1.0: 2.0}, + c2={0.0: 0.0, 1.0: 3.0}, ) - ) - } + ), + }, + t=0.25, ) - dcs = config_slice_lib.build_dynamic_config_slice(config, t=0.25) - np.testing.assert_allclose( - dcs.sources['gas_puff_source'].formula.gaussian.total, 0.25 + assert isinstance( + dcs.sources['gas_puff_source'].formula, + formula_config.DynamicExponential, ) np.testing.assert_allclose( - dcs.sources['gas_puff_source'].formula.gaussian.c1, 0.5 + dcs.sources['gas_puff_source'].formula.total, 0.25 ) + np.testing.assert_allclose(dcs.sources['gas_puff_source'].formula.c1, 0.5) np.testing.assert_allclose( - dcs.sources['gas_puff_source'].formula.gaussian.c2, 0.75 + dcs.sources['gas_puff_source'].formula.c2, 0.75 ) - with self.subTest('custom_formula'): - config = config_lib.Config( + with self.subTest('gaussian_formula'): + config = config_lib.Config() + dcs = config_slice_lib.build_dynamic_config_slice( + config=config, sources={ - 'gas_puff_source': source_config.SourceConfig( - formula=formula_config.FormulaConfig( - custom_params={ - 'foo': 1.0, - 'bar': {0.0: 0.0, 1.0: 1.0}, - 'baz': config_lib.InterpolationParam( - {0.0: 0.0, 1.0: 2.0}, - ), - } + 'gas_puff_source': runtime_params.RuntimeParams( + formula=formula_config.Gaussian( + total={0.0: 0.0, 1.0: 1.0}, + c1={0.0: 0.0, 1.0: 2.0}, + c2={0.0: 0.0, 1.0: 3.0}, ) - ) - } + ), + }, + t=0.25, ) - dcs = config_slice_lib.build_dynamic_config_slice(config, t=1.0) - np.testing.assert_allclose( - dcs.sources['gas_puff_source'].formula.custom_params['foo'], 1.0 + assert isinstance( + dcs.sources['gas_puff_source'].formula, formula_config.DynamicGaussian ) np.testing.assert_allclose( - dcs.sources['gas_puff_source'].formula.custom_params['bar'], 1.0 + dcs.sources['gas_puff_source'].formula.total, 0.25 ) + np.testing.assert_allclose(dcs.sources['gas_puff_source'].formula.c1, 0.5) np.testing.assert_allclose( - dcs.sources['gas_puff_source'].formula.custom_params['baz'], 2.0 + dcs.sources['gas_puff_source'].formula.c2, 0.75 ) def test_wext_in_dynamic_config_cannot_be_negative(self): """Tests that wext cannot be negative.""" - config = config_lib.Config(wext={0.0: 1.0, 1.0: -1.0}) + config = config_lib.Config() + dcs_provider = config_slice_lib.DynamicConfigSliceProvider( + config=config, + transport_getter=transport_params_lib.RuntimeParams, + sources_getter=lambda: { + 'jext': external_current_source.RuntimeParams( + wext={0.0: 1.0, 1.0: -1.0} + ), + }, + ) # While wext is positive, this should be fine. - dcs = config_slice_lib.build_dynamic_config_slice(config, t=0.0) - np.testing.assert_allclose(dcs.wext, 1.0) + dcs = dcs_provider(t=0.0) + assert isinstance( + dcs.sources['jext'], external_current_source.DynamicRuntimeParams + ) + np.testing.assert_allclose(dcs.sources['jext'].wext, 1.0) # Even 0 should be fine. - dcs = config_slice_lib.build_dynamic_config_slice(config, t=0.5) - np.testing.assert_allclose(dcs.wext, 0.0) + dcs = dcs_provider(t=0.5) + assert isinstance( + dcs.sources['jext'], external_current_source.DynamicRuntimeParams + ) + np.testing.assert_allclose(dcs.sources['jext'].wext, 0.0) # But negative values will cause an error. with self.assertRaises(jax.interpreters.xla.xe.XlaRuntimeError): - config_slice_lib.build_dynamic_config_slice(config, t=1.0) + dcs_provider(t=1.0) if __name__ == '__main__': diff --git a/torax/tests/physics.py b/torax/tests/physics.py index e374af44..54235015 100644 --- a/torax/tests/physics.py +++ b/torax/tests/physics.py @@ -24,6 +24,7 @@ from torax import core_profile_setters from torax import geometry from torax import physics +from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib from torax.tests.test_lib import torax_refs @@ -101,7 +102,14 @@ def test_update_psi_from_j( references = references_getter() config = references.config - dynamic_config_slice = config_slice.build_dynamic_config_slice(config) + source_models = source_models_lib.SourceModels() + # Turn on the external current source. + source_models.jext.runtime_params.mode = ( + source_runtime_params.Mode.FORMULA_BASED + ) + dynamic_config_slice = config_slice.build_dynamic_config_slice( + config, sources=source_models.runtime_params + ) geo = references.geo # pylint: disable=protected-access @@ -109,7 +117,7 @@ def test_update_psi_from_j( currents = core_profile_setters._prescribe_currents_no_bootstrap( dynamic_config_slice, geo, - source_models=source_models_lib.SourceModels(), + source_models=source_models, ) psi = core_profile_setters._update_psi_from_j( dynamic_config_slice, geo, currents diff --git a/torax/tests/sim.py b/torax/tests/sim.py index 3a9a599f..59b07937 100644 --- a/torax/tests/sim.py +++ b/torax/tests/sim.py @@ -27,6 +27,7 @@ import torax from torax import sim as sim_lib from torax import state as state_lib +from torax.sources import source_models as source_models_lib from torax.spectators import spectator as spectator_lib from torax.stepper import linear_theta_method from torax.tests.test_lib import explicit_stepper @@ -417,6 +418,7 @@ def test_no_op(self): geo=geo, stepper_builder=linear_theta_method.LinearThetaMethod, transport_model=constant_transport_model.ConstantTransportModel(), + source_models=source_models_lib.SourceModels(), time_step_calculator=time_step_calculator, ) @@ -478,6 +480,7 @@ def test_observers_update_during_runs(self, stepper): geo=geo, stepper_builder=stepper, transport_model=config_module.get_transport_model(), + source_models=config_module.get_sources(), time_step_calculator=time_step_calculator, ) sim.run( diff --git a/torax/tests/sim_custom_sources.py b/torax/tests/sim_custom_sources.py index ac887da5..3a00c2e7 100644 --- a/torax/tests/sim_custom_sources.py +++ b/torax/tests/sim_custom_sources.py @@ -14,6 +14,10 @@ """Tests for using custom, user-defined sources/sinks within TORAX.""" +from __future__ import annotations + +import dataclasses + from absl.testing import absltest import chex from torax import config as config_lib @@ -21,10 +25,11 @@ from torax import geometry from torax import sim as sim_lib from torax import state as state_lib +from torax.runtime_params import config_slice_args +from torax.sources import default_sources from torax.sources import electron_density_sources +from torax.sources import runtime_params as runtime_params_lib from torax.sources import source -from torax.sources import source_config -from torax.sources import source_models as source_models_lib from torax.stepper import linear_theta_method from torax.tests.test_lib import sim_test_case from torax.transport_model import constant as constant_transport_model @@ -43,101 +48,136 @@ def test_custom_ne_source_can_replace_defaults(self): # stepper. custom_source_name = 'custom_ne_source' - def custom_source_formula(dynamic_config, geo, unused_state): - # Combine the outputs of the pellet + def custom_source_formula( + dynamic_config_slice: config_slice.DynamicConfigSlice, + dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, + geo: geometry.Geometry, + unused_state: state_lib.CoreProfiles | None, + ): + # Combine the outputs. + assert isinstance( + dynamic_source_runtime_params, _CustomSourceDynamicRuntimeParams + ) + ignored_default_kwargs = dict( + mode=dynamic_source_runtime_params.mode, + is_explicit=dynamic_source_runtime_params.is_explicit, + formula=dynamic_source_runtime_params.formula, + ) + puff_params = electron_density_sources.DynamicGasPuffRuntimeParams( + puff_decay_length=dynamic_source_runtime_params.puff_decay_length, + S_puff_tot=dynamic_source_runtime_params.S_puff_tot, + **ignored_default_kwargs, + ) + nbi_params = electron_density_sources.DynamicNBIParticleRuntimeParams( + nbi_deposition_location=dynamic_source_runtime_params.nbi_deposition_location, + nbi_particle_width=dynamic_source_runtime_params.nbi_particle_width, + S_nbi_tot=dynamic_source_runtime_params.S_nbi_tot, + **ignored_default_kwargs, + ) + pellet_params = electron_density_sources.DynamicPelletRuntimeParams( + pellet_deposition_location=dynamic_source_runtime_params.pellet_deposition_location, + pellet_width=dynamic_source_runtime_params.pellet_width, + S_pellet_tot=dynamic_source_runtime_params.S_pellet_tot, + **ignored_default_kwargs, + ) + # pylint: disable=protected-access return ( - electron_density_sources.calc_puff_source( - geo, - puff_decay_length=dynamic_config.puff_decay_length, - S_puff_tot=dynamic_config.S_puff_tot, - nref=dynamic_config.nref, + electron_density_sources._calc_puff_source( + dynamic_config_slice=dynamic_config_slice, + dynamic_source_runtime_params=puff_params, + geo=geo, ) - + electron_density_sources.calc_nbi_source( - geo, - nbi_deposition_location=dynamic_config.nbi_deposition_location, - nbi_particle_width=dynamic_config.nbi_particle_width, - S_nbi_tot=dynamic_config.S_nbi_tot, - nref=dynamic_config.nref, + + electron_density_sources._calc_nbi_source( + dynamic_config_slice=dynamic_config_slice, + dynamic_source_runtime_params=nbi_params, + geo=geo, ) - + electron_density_sources.calc_pellet_source( - geo, - pellet_deposition_location=( - dynamic_config.pellet_deposition_location - ), - pellet_width=dynamic_config.pellet_width, - S_pellet_tot=dynamic_config.S_pellet_tot, - nref=dynamic_config.nref, + + electron_density_sources._calc_pellet_source( + dynamic_config_slice=dynamic_config_slice, + dynamic_source_runtime_params=pellet_params, + geo=geo, ) ) + # pylint: enable=protected-access - source_models = source_models_lib.SourceModels( - additional_sources=[ - source.SingleProfileSource( - name=custom_source_name, - supported_types=( - source_config.SourceType.ZERO, - source_config.SourceType.FORMULA_BASED, - ), - affected_core_profiles=(source.AffectedCoreProfile.NE,), - formula=custom_source_formula, - ) - ] + # First instantiate the same default sources that test_particle_sources + # constant starts with. + source_models = default_sources.get_default_sources() + source_models.j_bootstrap.runtime_params.bootstrap_mult = 1 + source_models.qei_source.runtime_params.Qei_mult = 1 + nbi_params = source_models.sources['nbi_particle_source'].runtime_params + assert isinstance( + nbi_params, electron_density_sources.NBIParticleRuntimeParams + ) + nbi_params.S_nbi_tot = 0.0 + pellet_params = source_models.sources['pellet_source'].runtime_params + assert isinstance( + pellet_params, electron_density_sources.PelletRuntimeParams + ) + pellet_params.S_pellet_tot = 2.0e22 + gas_puff_params = source_models.sources['gas_puff_source'].runtime_params + assert isinstance( + gas_puff_params, electron_density_sources.GasPuffRuntimeParams + ) + gas_puff_params.S_puff_tot = 1.0e22 + # Turn off some sources. + source_models.sources['fusion_heat_source'].runtime_params.mode = ( + runtime_params_lib.Mode.ZERO + ) + source_models.sources['ohmic_heat_source'].runtime_params.mode = ( + runtime_params_lib.Mode.ZERO + ) + + # Add the custom source with the correct params, but keep it turned off to + # start. + source_models.add_source( + source_name=custom_source_name, + source=source.SingleProfileSource( + supported_modes=( + runtime_params_lib.Mode.ZERO, + runtime_params_lib.Mode.FORMULA_BASED, + ), + affected_core_profiles=(source.AffectedCoreProfile.NE,), + formula=custom_source_formula, + runtime_params=_CustomSourceRuntimeParams( + mode=runtime_params_lib.Mode.ZERO, + puff_decay_length=gas_puff_params.puff_decay_length, + S_puff_tot=gas_puff_params.S_puff_tot, + nbi_particle_width=nbi_params.nbi_particle_width, + nbi_deposition_location=nbi_params.nbi_deposition_location, + S_nbi_tot=nbi_params.S_nbi_tot, + pellet_width=pellet_params.pellet_width, + pellet_deposition_location=pellet_params.pellet_deposition_location, + S_pellet_tot=pellet_params.S_pellet_tot, + ), + ), ) # Copy the test_particle_sources_constant config in here for clarity. # These are the common kwargs without any of the sources. - test_particle_sources_constant_config_kwargs = dict( + test_particle_sources_constant_config = config_lib.Config( profile_conditions=config_lib.ProfileConditions( set_pedestal=True, nbar=0.85, ), numerics=config_lib.Numerics( - Qei_mult=1, ion_heat_eq=True, el_heat_eq=True, dens_eq=True, # This is important to be True to test ne sources. current_eq=True, resistivity_mult=100, - bootstrap_mult=1, t_final=2, ), nu=0, - S_pellet_tot=2.0e22, - S_puff_tot=1.0e22, - S_nbi_tot=0.0, solver=config_lib.SolverConfig( predictor_corrector=False, ), ) - # We need to turn off some other sources for test_particle_sources_constant - # that are unrelated to our test for the ne custom source. - unrelated_source_configs = dict( - fusion_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ) # Load reference profiles ref_profiles, ref_time = self._get_refs( 'test_particle_sources_constant.h5', _ALL_PROFILES ) - - # Set up the sim with the original config. We set up the sim only once and - # update the config on each run below in a way that does not trigger - # recompiles. This way we only trace the code once. - test_particle_sources_constant_config = config_lib.Config( - **test_particle_sources_constant_config_kwargs, - sources=dict( - **unrelated_source_configs, - # Turn off the custom source - custom_ne_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), - ) geo = geometry.build_circular_geometry( test_particle_sources_constant_config ) @@ -170,51 +210,25 @@ def custom_source_formula(dynamic_config, geo, unused_state): ) with self.subTest('without_defaults_and_with_custom_source'): - config_with_custom_source = config_lib.Config( - **test_particle_sources_constant_config_kwargs, - sources=dict( - **unrelated_source_configs, - custom_ne_source=source_config.SourceConfig( - source_type=source_config.SourceType.FORMULA_BASED, - ), - gas_puff_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - nbi_particle_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - pellet_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), + # Turn off the other sources and turn on the custom one. + nbi_params.mode = runtime_params_lib.Mode.ZERO + pellet_params.mode = runtime_params_lib.Mode.ZERO + gas_puff_params.mode = runtime_params_lib.Mode.ZERO + source_models.sources[custom_source_name].runtime_params.mode = ( + runtime_params_lib.Mode.FORMULA_BASED ) self._run_sim_and_check( - config_with_custom_source, sim, ref_profiles, ref_time + test_particle_sources_constant_config, sim, ref_profiles, ref_time ) with self.subTest('without_defaults_and_without_custom_source'): # Confirm that the custom source actual has an effect. - config_without_ne_sources = config_lib.Config( - **test_particle_sources_constant_config_kwargs, - sources=dict( - **unrelated_source_configs, - custom_ne_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - gas_puff_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - nbi_particle_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - pellet_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), + source_models.sources[custom_source_name].runtime_params.mode = ( + runtime_params_lib.Mode.ZERO ) with self.assertRaises(AssertionError): self._run_sim_and_check( - config_without_ne_sources, sim, ref_profiles, ref_time + test_particle_sources_constant_config, sim, ref_profiles, ref_time ) def _run_sim_and_check( @@ -233,6 +247,7 @@ def _run_sim_and_check( config_slice.DynamicConfigSliceProvider( config=config, transport_getter=lambda: sim.transport_model.runtime_params, + sources_getter=lambda: sim.source_models.runtime_params, ) ), static_config_slice=sim.static_config_slice, @@ -250,5 +265,49 @@ def _run_sim_and_check( ) +# pylint: disable=invalid-name + + +@dataclasses.dataclass(kw_only=True) +class _CustomSourceRuntimeParams(runtime_params_lib.RuntimeParams): + """Runtime params for the custom source defined in the test case above.""" + + puff_decay_length: runtime_params_lib.TimeDependentField + S_puff_tot: runtime_params_lib.TimeDependentField + nbi_particle_width: runtime_params_lib.TimeDependentField + nbi_deposition_location: runtime_params_lib.TimeDependentField + S_nbi_tot: runtime_params_lib.TimeDependentField + pellet_width: runtime_params_lib.TimeDependentField + pellet_deposition_location: runtime_params_lib.TimeDependentField + S_pellet_tot: runtime_params_lib.TimeDependentField + + def build_dynamic_params( + self, t: chex.Numeric + ) -> _CustomSourceDynamicRuntimeParams: + return _CustomSourceDynamicRuntimeParams( + **config_slice_args.get_init_kwargs( + input_config=self, + output_type=_CustomSourceDynamicRuntimeParams, + t=t, + ) + ) + + +@chex.dataclass(frozen=True) +class _CustomSourceDynamicRuntimeParams( + runtime_params_lib.DynamicRuntimeParams +): + puff_decay_length: float + S_puff_tot: float + nbi_particle_width: float + nbi_deposition_location: float + S_nbi_tot: float + pellet_width: float + pellet_deposition_location: float + S_pellet_tot: float + + +# pylint: enable=invalid-name + if __name__ == '__main__': absltest.main() diff --git a/torax/tests/sim_output_source_profiles.py b/torax/tests/sim_output_source_profiles.py index c6ea52c3..e623981e 100644 --- a/torax/tests/sim_output_source_profiles.py +++ b/torax/tests/sim_output_source_profiles.py @@ -17,10 +17,13 @@ This is a separate file to not bloat the main sim.py test file. """ +from __future__ import annotations + import dataclasses from typing import Any from absl.testing import absltest +import chex from jax import numpy as jnp import numpy as np from torax import config as config_lib @@ -30,8 +33,10 @@ from torax import sim as sim_lib from torax import state as state_module from torax.fvm import cell_variable +from torax.runtime_params import config_slice_args +from torax.sources import default_sources +from torax.sources import runtime_params as runtime_params_lib from torax.sources import source -from torax.sources import source_config from torax.sources import source_models as source_models_lib from torax.sources import source_profiles as source_profiles_lib from torax.tests.test_lib import explicit_stepper @@ -50,9 +55,12 @@ def test_merging_source_profiles(self): """Tests that the implicit and explicit source profiles merge correctly.""" config = config_lib.Config() geo = geometry.build_circular_geometry(config) - dynamic_config_slice = config_slice.build_dynamic_config_slice(config) + source_models = default_sources.get_default_sources() + dynamic_config_slice = config_slice.build_dynamic_config_slice( + config, + sources=source_models.runtime_params, + ) static_config_slice = config_slice.build_static_config_slice(config) - source_models = source_models_lib.SourceModels() # Technically, the _merge_source_profiles() function should be called with # source profiles where, for every source, only one of the implicit or # explicit profiles has non-zero values. That is what makes the summing @@ -60,14 +68,12 @@ def test_merging_source_profiles(self): # summed in the first place. # Build a fake set of source profiles which have all 1s in all the profiles. fake_implicit_source_profiles = _build_source_profiles_with_single_value( - dynamic_config_slice=dynamic_config_slice, geo=geo, source_models=source_models, value=1.0, ) # And a fake set of profiles with all 2s. fake_explicit_source_profiles = _build_source_profiles_with_single_value( - dynamic_config_slice=dynamic_config_slice, geo=geo, source_models=source_models, value=2.0, @@ -98,7 +104,7 @@ def test_merging_source_profiles(self): # All the profiles in the merged profiles should be a 1D array with all 3s. # Except the Qei profile, which is a special case. for name, profile in merged_profiles.profiles.items(): - if name != source_models.qei_source.name: + if name != source_models.qei_source_name: np.testing.assert_allclose(profile, 3.0) else: np.testing.assert_allclose(profile, 6.0) @@ -113,56 +119,53 @@ def test_first_and_last_source_profiles(self): # The first time step and last time step's output source profiles are built # in a special way that combines the implicit and explicit profiles. - # Create custom sources which output profiles depending on the pellet_width. - def custom_source_formula(dynamic_config, geo, unused_state): - # Combine the outputs of the pellet - return jnp.ones_like(geo.r) * dynamic_config.pellet_width + # Create custom sources whose output profiles depend on Tiped. + # This is not physically realistic, just for testing purposes. + def custom_source_formula( + unused_dynamic_config, + source_conf, + geo, + unused_state, + ): + return jnp.ones_like(geo.r) * source_conf.foo # Include 2 versions of this source, one implicit and one explicit. source_models = source_models_lib.SourceModels( - additional_sources=[ - source.SingleProfileSource( - name='implicit_ne_source', - supported_types=( - source_config.SourceType.ZERO, - source_config.SourceType.FORMULA_BASED, + sources={ + 'implicit_ne_source': source.SingleProfileSource( + supported_modes=( + runtime_params_lib.Mode.ZERO, + runtime_params_lib.Mode.FORMULA_BASED, ), affected_core_profiles=(source.AffectedCoreProfile.NE,), formula=custom_source_formula, + runtime_params=_FakeSourceRuntimeParams( + mode=runtime_params_lib.Mode.FORMULA_BASED, + foo={0.0: 1.0, 1.0: 2.0, 2.0: 3.0, 3.0: 4.0}, + ), ), - source.SingleProfileSource( - name='explicit_ne_source', - supported_types=( - source_config.SourceType.ZERO, - source_config.SourceType.FORMULA_BASED, + 'explicit_ne_source': source.SingleProfileSource( + supported_modes=( + runtime_params_lib.Mode.ZERO, + runtime_params_lib.Mode.FORMULA_BASED, ), affected_core_profiles=(source.AffectedCoreProfile.NE,), formula=custom_source_formula, + runtime_params=_FakeSourceRuntimeParams( + mode=runtime_params_lib.Mode.FORMULA_BASED, + foo={0.0: 1.0, 1.0: 2.0, 2.0: 3.0, 3.0: 4.0}, + ), ), - ] - ) - # Linearly scale the pellet_width. - config = config_lib.Config( - pellet_width={0.0: 1.0, 1.0: 2.0, 2.0: 3.0, 3.0: 4.0}, - sources={ - 'implicit_ne_source': source_config.SourceConfig( - source_type=source_config.SourceType.FORMULA_BASED, - is_explicit=False, - ), - 'explicit_ne_source': source_config.SourceConfig( - source_type=source_config.SourceType.FORMULA_BASED, - is_explicit=True, - ), - }, + } ) + config = config_lib.Config() geo = geometry.build_circular_geometry(config) time_stepper = _FakeTimeStepCalculator() step_fn = _FakeSimulationStepFn(time_stepper, source_models) - dynamic_config_slice_provider = ( - config_slice.DynamicConfigSliceProvider( - config=config, - transport_getter=constant_transport_model.RuntimeParams, - ) + dynamic_config_slice_provider = config_slice.DynamicConfigSliceProvider( + config=config, + transport_getter=constant_transport_model.RuntimeParams, + sources_getter=lambda: source_models.runtime_params, ) initial_dcs = dynamic_config_slice_provider(0.0) static_config_slice = config_slice.build_static_config_slice(config) @@ -197,7 +200,6 @@ def custom_source_formula(dynamic_config, geo, unused_state): def _build_source_profiles_with_single_value( - dynamic_config_slice: config_slice.DynamicConfigSlice, geo: geometry.Geometry, source_models: source_models_lib.SourceModels, value: float, @@ -206,12 +208,7 @@ def _build_source_profiles_with_single_value( face_1d_arr = jnp.ones_like(geo.r_face) * value return source_profiles_lib.SourceProfiles( profiles={ - name: ( - jnp.ones( - shape=src.output_shape_getter(dynamic_config_slice, geo, None) - ) - * value - ) + name: jnp.ones(shape=src.output_shape_getter(geo)) * value for name, src in source_models.standard_sources.items() }, j_bootstrap=source_profiles_lib.BootstrapCurrentProfile( @@ -257,6 +254,27 @@ def next_dt( return jnp.ones(()), () +@dataclasses.dataclass(kw_only=True) +class _FakeSourceRuntimeParams(runtime_params_lib.RuntimeParams): + foo: runtime_params_lib.TimeDependentField + + def build_dynamic_params( + self, t: chex.Numeric + ) -> _FakeSourceDynamicRuntimeParams: + return _FakeSourceDynamicRuntimeParams( + **config_slice_args.get_init_kwargs( + input_config=self, + output_type=_FakeSourceDynamicRuntimeParams, + t=t, + ) + ) + + +@chex.dataclass(frozen=True) +class _FakeSourceDynamicRuntimeParams(runtime_params_lib.DynamicRuntimeParams): + foo: float + + class _FakeSimulationStepFn(sim_lib.SimulationStepFn): """Fake step function which only calculates new implicit profiles.""" diff --git a/torax/tests/sim_time_dependence.py b/torax/tests/sim_time_dependence.py index 6a3bc191..de7570be 100644 --- a/torax/tests/sim_time_dependence.py +++ b/torax/tests/sim_time_dependence.py @@ -79,6 +79,7 @@ def test_time_dependent_params_update_in_adaptive_dt( dynamic_config_slice_provider = config_slice.DynamicConfigSliceProvider( config=config, transport_getter=lambda: transport.runtime_params, + sources_getter=lambda: source_models.runtime_params, ) initial_dynamic_config_slice = dynamic_config_slice_provider( config.numerics.t_initial @@ -164,7 +165,6 @@ def __call__( ) # Use Qei as a hacky way to extract what the combined value was. core_sources = source_models_lib.build_all_zero_profiles( - dynamic_config_slice=dynamic_config_slice_t, geo=geo, source_models=self.source_models, ) diff --git a/torax/tests/state.py b/torax/tests/state.py index b2d76ad4..7e6280b6 100644 --- a/torax/tests/state.py +++ b/torax/tests/state.py @@ -39,15 +39,19 @@ def setUp(self): # Make a State object in history mode, output by scan self.history_length = 2 + source_models = source_models_lib.SourceModels() def make_hist(config, geo): initial_counter = jnp.array(0) def scan_f(counter: jax.Array, _) -> tuple[jax.Array, state.CoreProfiles]: core_profiles = core_profile_setters.initial_core_profiles( - config_slice.build_static_config_slice(config), - config_slice.build_dynamic_config_slice(config), - geo, + static_config_slice=config_slice.build_static_config_slice(config), + dynamic_config_slice=config_slice.build_dynamic_config_slice( + config, sources=source_models.runtime_params + ), + geo=geo, + source_models=source_models, ) # Make one variable in the history track the value of the counter value = jnp.ones_like(core_profiles.temp_ion.value) * counter @@ -82,10 +86,16 @@ def test_sanity_check( ): """Make sure State.sanity_check can be called.""" references = references_getter() + source_models = source_models_lib.SourceModels() basic_core_profiles = core_profile_setters.initial_core_profiles( - config_slice.build_static_config_slice(references.config), - config_slice.build_dynamic_config_slice(references.config), - references.geo, + static_config_slice=config_slice.build_static_config_slice( + references.config + ), + dynamic_config_slice=config_slice.build_dynamic_config_slice( + references.config, sources=source_models.runtime_params + ), + geo=references.geo, + source_models=source_models, ) basic_core_profiles.sanity_check() @@ -150,10 +160,14 @@ def test_initial_boundary_condition_from_time_dependent_params(self): ), ), ) + source_models = source_models_lib.SourceModels() core_profiles = core_profile_setters.initial_core_profiles( - config_slice.build_static_config_slice(config), - config_slice.build_dynamic_config_slice(config), - geometry.build_circular_geometry(config), + static_config_slice=config_slice.build_static_config_slice(config), + dynamic_config_slice=config_slice.build_dynamic_config_slice( + config, sources=source_models.runtime_params + ), + geo=geometry.build_circular_geometry(config), + source_models=source_models, ) np.testing.assert_allclose( core_profiles.temp_ion.right_face_constraint, 27.7 @@ -175,57 +189,68 @@ def test_initial_psi_from_j( initial_j_is_total_current=True, initial_psi_from_j=True, nu=2, - numerics=config_lib.Numerics( - bootstrap_mult=0, - ), ) config2 = config_lib.Config( initial_j_is_total_current=False, initial_psi_from_j=True, nu=2, - numerics=config_lib.Numerics( - bootstrap_mult=0, - ), ) config3 = config_lib.Config( initial_j_is_total_current=False, initial_psi_from_j=True, nu=2, - fext=0.0, - numerics=config_lib.Numerics( - bootstrap_mult=1, - ), ) # Needed to generate psi for bootstrap calculation config3_helper = config_lib.Config( initial_j_is_total_current=True, initial_psi_from_j=True, nu=2, - fext=0.0, - numerics=config_lib.Numerics( - bootstrap_mult=0, - ), ) geo = geo_builder(config1) + source_models = source_models_lib.SourceModels() + source_models.j_bootstrap.runtime_params.bootstrap_mult = 0.0 + dcs1 = config_slice.build_dynamic_config_slice( + config1, sources=source_models.runtime_params + ) core_profiles1 = core_profile_setters.initial_core_profiles( - config_slice.build_static_config_slice(config1), - config_slice.build_dynamic_config_slice(config1), + static_config_slice=config_slice.build_static_config_slice(config1), + dynamic_config_slice=dcs1, geo=geo, + source_models=source_models, + ) + source_models.j_bootstrap.runtime_params.bootstrap_mult = 0.0 + dcs2 = config_slice.build_dynamic_config_slice( + config2, sources=source_models.runtime_params ) core_profiles2 = core_profile_setters.initial_core_profiles( - config_slice.build_static_config_slice(config2), - config_slice.build_dynamic_config_slice(config2), + static_config_slice=config_slice.build_static_config_slice(config2), + dynamic_config_slice=dcs2, geo=geo, + source_models=source_models, + ) + source_models.j_bootstrap.runtime_params.bootstrap_mult = 1.0 + source_models.jext.runtime_params.fext = 0.0 + dcs3 = config_slice.build_dynamic_config_slice( + config3, sources=source_models.runtime_params ) core_profiles3 = core_profile_setters.initial_core_profiles( - config_slice.build_static_config_slice(config3), - config_slice.build_dynamic_config_slice(config3), + static_config_slice=config_slice.build_static_config_slice(config3), + dynamic_config_slice=dcs3, geo=geo, + source_models=source_models, + ) + source_models.j_bootstrap.runtime_params.bootstrap_mult = 0.0 + source_models.jext.runtime_params.fext = 0.0 + dcs3_helper = config_slice.build_dynamic_config_slice( + config3_helper, sources=source_models.runtime_params ) core_profiles3_helper = core_profile_setters.initial_core_profiles( - config_slice.build_static_config_slice(config3_helper), - config_slice.build_dynamic_config_slice(config3_helper), + static_config_slice=config_slice.build_static_config_slice( + config3_helper + ), + dynamic_config_slice=dcs3_helper, geo=geo, + source_models=source_models, ) # calculate total and Ohmic current profiles arising from nu=2 @@ -235,12 +260,17 @@ def test_initial_psi_from_j( ) ctot = config1.profile_conditions.Ip * 1e6 / denom jtot_formula_face = jformula_face * ctot - johm_formula_face = jtot_formula_face * (1 - config1.fext) + johm_formula_face = jtot_formula_face * ( + 1 - dcs1.sources[source_models.jext_name].fext # pytype: disable=attribute-error + ) # Calculate bootstrap current for config3 which doesn't zero it out source_models = source_models_lib.SourceModels() bootstrap_profile = source_models.j_bootstrap.get_value( - dynamic_config_slice=config_slice.build_dynamic_config_slice(config3), + dynamic_config_slice=dcs3, + dynamic_source_runtime_params=dcs3.sources[ + source_models.j_bootstrap_name + ], geo=geo, temp_ion=core_profiles3.temp_ion, temp_el=core_profiles3.temp_el, @@ -293,21 +323,30 @@ def test_initial_psi_from_j( def test_initial_psi_from_geo_noop_circular(self): """Tests expected behaviour of initial psi and current options.""" + source_models = source_models_lib.SourceModels() config1 = config_lib.Config( initial_psi_from_j=False, ) + dcs1 = config_slice.build_dynamic_config_slice( + config1, sources=source_models.runtime_params + ) config2 = config_lib.Config( initial_psi_from_j=True, ) + dcs2 = config_slice.build_dynamic_config_slice( + config2, sources=source_models.runtime_params + ) core_profiles1 = core_profile_setters.initial_core_profiles( - config_slice.build_static_config_slice(config1), - config_slice.build_dynamic_config_slice(config1), - geometry.build_circular_geometry(config1), + static_config_slice=config_slice.build_static_config_slice(config1), + dynamic_config_slice=dcs1, + geo=geometry.build_circular_geometry(config1), + source_models=source_models, ) core_profiles2 = core_profile_setters.initial_core_profiles( - config_slice.build_static_config_slice(config2), - config_slice.build_dynamic_config_slice(config2), - geometry.build_circular_geometry(config2), + static_config_slice=config_slice.build_static_config_slice(config2), + dynamic_config_slice=dcs2, + geo=geometry.build_circular_geometry(config2), + source_models=source_models, ) np.testing.assert_allclose( core_profiles1.currents.jtot, core_profiles2.currents.jtot diff --git a/torax/tests/test_data/compilation_benchmark.py b/torax/tests/test_data/compilation_benchmark.py index 37cf0c33..2a0a6426 100644 --- a/torax/tests/test_data/compilation_benchmark.py +++ b/torax/tests/test_data/compilation_benchmark.py @@ -24,6 +24,9 @@ from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.sources import default_sources +from torax.sources import runtime_params as source_runtime_params +from torax.sources import source_models as source_models_lib from torax.stepper import nonlinear_theta_method from torax.transport_model import qlknn_wrapper @@ -39,20 +42,14 @@ def get_config() -> config_lib.Config: neped=1.0, ), numerics=config_lib.Numerics( - Qei_mult=1, ion_heat_eq=True, el_heat_eq=True, dens_eq=True, current_eq=True, resistivity_mult=100, # to shorten current diffusion time - bootstrap_mult=1, # remove bootstrap current t_final=0.0007944 * 2, ), nu=0, - S_pellet_tot=1.0e22, - S_puff_tot=0.5e22, - S_nbi_tot=0.3e22, - Ptot=53.0e6, # total external heating solver=config_lib.SolverConfig( use_pereverzev=False, ), @@ -71,6 +68,31 @@ def get_transport_model() -> qlknn_wrapper.QLKNNTransportModel: ) +def get_sources() -> source_models_lib.SourceModels: + """Returns the source models used in the simulation.""" + source_models = default_sources.get_default_sources() + # multiplier for ion-electron heat exchange term for sensitivity + source_models.qei_source.runtime_params.Qei_mult = 1.0 + source_models.j_bootstrap.runtime_params.bootstrap_mult = 1.0 + # total heating (including accounting for radiation) r + source_models.sources['generic_ion_el_heat_source'].runtime_params.Ptot = ( + 53.0e6 + ) + # total pellet particles/s (continuous pellet model) + source_models.sources['pellet_source'].runtime_params.S_pellet_tot = 1.0e22 + # total pellet particles/s + source_models.sources['gas_puff_source'].runtime_params.S_puff_tot = 0.5e22 + # NBI total particle source + source_models.sources['nbi_particle_source'].runtime_params.S_nbi_tot = 0.3e22 + source_models.sources['fusion_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + source_models.sources['ohmic_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + return source_models + + def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most @@ -80,5 +102,6 @@ def get_sim() -> sim_lib.Sim: config=config, geo=get_geometry(config), stepper_builder=nonlinear_theta_method.NewtonRaphsonThetaMethod, + source_models=get_sources(), transport_model=get_transport_model(), ) diff --git a/torax/tests/test_data/default_config.py b/torax/tests/test_data/default_config.py index 61fd61d1..872915da 100644 --- a/torax/tests/test_data/default_config.py +++ b/torax/tests/test_data/default_config.py @@ -17,6 +17,8 @@ from torax import config as config_lib from torax import geometry from torax import sim as sim_lib +from torax.sources import default_sources +from torax.sources import source_models as source_models_lib from torax.stepper import linear_theta_method from torax.transport_model import constant as constant_transport_model @@ -35,6 +37,12 @@ def get_transport_model() -> constant_transport_model.ConstantTransportModel: return constant_transport_model.ConstantTransportModel() +def get_sources() -> source_models_lib.SourceModels: + """Returns the source models used in the simulation.""" + source_models = default_sources.get_default_sources() + return source_models + + def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most @@ -45,5 +53,6 @@ def get_sim() -> sim_lib.Sim: config=config, geo=geo, stepper_builder=linear_theta_method.LinearThetaMethod, + source_models=get_sources(), transport_model=get_transport_model(), ) diff --git a/torax/tests/test_data/test_absolute_jext.py b/torax/tests/test_data/test_absolute_jext.py index 65a666d2..0978bbde 100644 --- a/torax/tests/test_data/test_absolute_jext.py +++ b/torax/tests/test_data/test_absolute_jext.py @@ -24,7 +24,9 @@ from torax import config as config_lib from torax import geometry from torax import sim as sim_lib -from torax.sources import source_config +from torax.sources import default_sources +from torax.sources import runtime_params as source_runtime_params +from torax.sources import source_models as source_models_lib from torax.stepper import linear_theta_method from torax.transport_model import qlknn_wrapper @@ -39,26 +41,14 @@ def get_config() -> config_lib.Config: current_eq=True, resistivity_mult=100, # to shorten current diffusion time t_final=2, - bootstrap_mult=0, # remove bootstrap current ), # set flat Ohmic current to provide larger range of current evolution for # test nu=0, - use_absolute_jext=True, - fext=0.0, - Iext=3.0, solver=config_lib.SolverConfig( predictor_corrector=False, use_pereverzev=True, ), - sources=dict( - fusion_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), ) @@ -70,6 +60,23 @@ def get_transport_model() -> qlknn_wrapper.QLKNNTransportModel: return qlknn_wrapper.QLKNNTransportModel() +def get_sources() -> source_models_lib.SourceModels: + """Returns the source models used in the simulation.""" + source_models = default_sources.get_default_sources() + # remove bootstrap current + source_models.j_bootstrap.runtime_params.bootstrap_mult = 0.0 + source_models.jext.runtime_params.use_absolute_jext = True + source_models.jext.runtime_params.fext = 0.0 + source_models.jext.runtime_params.Iext = 3.0 + source_models.sources['fusion_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + source_models.sources['ohmic_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + return source_models + + def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most @@ -80,5 +87,6 @@ def get_sim() -> sim_lib.Sim: config=config, geo=geo, stepper_builder=linear_theta_method.LinearThetaMethod, + source_models=get_sources(), transport_model=get_transport_model(), ) diff --git a/torax/tests/test_data/test_all_transport_crank_nicolson.py b/torax/tests/test_data/test_all_transport_crank_nicolson.py index 831f1639..1c31cec0 100644 --- a/torax/tests/test_data/test_all_transport_crank_nicolson.py +++ b/torax/tests/test_data/test_all_transport_crank_nicolson.py @@ -22,7 +22,9 @@ from torax import config as config_lib from torax import geometry from torax import sim as sim_lib -from torax.sources import source_config +from torax.sources import default_sources +from torax.sources import runtime_params as source_runtime_params +from torax.sources import source_models as source_models_lib from torax.stepper import linear_theta_method from torax.transport_model import qlknn_wrapper @@ -37,23 +39,17 @@ def get_config() -> config_lib.Config: neped=1.0, ), numerics=config_lib.Numerics( - Qei_mult=1, ion_heat_eq=True, el_heat_eq=True, dens_eq=True, current_eq=True, resistivity_mult=100, # to shorten current diffusion time - bootstrap_mult=1, # remove bootstrap current t_final=2, largeValue_n=1.0e5, ), # set flat Ohmic current to provide larger range of current evolution for # test nu=0, - w=0.18202270915319393, - S_pellet_tot=1.0e22, - S_puff_tot=0.5e22, - S_nbi_tot=0.3e22, solver=config_lib.SolverConfig( predictor_corrector=False, use_pereverzev=True, @@ -61,14 +57,6 @@ def get_config() -> config_lib.Config: d_per=30.0, theta_imp=0.5, ), - sources=dict( - fusion_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), ) @@ -88,6 +76,31 @@ def get_transport_model() -> qlknn_wrapper.QLKNNTransportModel: ) +def get_sources() -> source_models_lib.SourceModels: + """Returns the source models used in the simulation.""" + source_models = default_sources.get_default_sources() + # multiplier for ion-electron heat exchange term for sensitivity + source_models.qei_source.runtime_params.Qei_mult = 1.0 + source_models.j_bootstrap.runtime_params.bootstrap_mult = 1.0 + # total pellet particles/s (continuous pellet model) + source_models.sources["pellet_source"].runtime_params.S_pellet_tot = 1.0e22 + # Gaussian width in normalized radial coordinate r + source_models.sources["generic_ion_el_heat_source"].runtime_params.w = ( + 0.18202270915319393 + ) + # total pellet particles/s + source_models.sources["gas_puff_source"].runtime_params.S_puff_tot = 0.5e22 + # NBI total particle source + source_models.sources["nbi_particle_source"].runtime_params.S_nbi_tot = 0.3e22 + source_models.sources["fusion_heat_source"].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + source_models.sources["ohmic_heat_source"].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + return source_models + + def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most @@ -98,5 +111,6 @@ def get_sim() -> sim_lib.Sim: config=config, geo=geo, stepper_builder=linear_theta_method.LinearThetaMethod, + source_models=get_sources(), transport_model=get_transport_model(), ) diff --git a/torax/tests/test_data/test_all_transport_fusion_qlknn.py b/torax/tests/test_data/test_all_transport_fusion_qlknn.py index 25332804..08effacd 100644 --- a/torax/tests/test_data/test_all_transport_fusion_qlknn.py +++ b/torax/tests/test_data/test_all_transport_fusion_qlknn.py @@ -21,7 +21,9 @@ from torax import config as config_lib from torax import geometry from torax import sim as sim_lib -from torax.sources import source_config +from torax.sources import default_sources +from torax.sources import runtime_params as source_runtime_params +from torax.sources import source_models as source_models_lib from torax.stepper import linear_theta_method from torax.transport_model import qlknn_wrapper @@ -36,31 +38,20 @@ def get_config() -> config_lib.Config: neped=1.0, ), numerics=config_lib.Numerics( - Qei_mult=1, ion_heat_eq=True, el_heat_eq=True, dens_eq=True, current_eq=True, resistivity_mult=100, # to shorten current diffusion time - bootstrap_mult=1, # remove bootstrap current t_final=2, ), # set flat Ohmic current to provide larger range of current evolution for # test nu=0, - S_pellet_tot=1.0e22, - S_puff_tot=0.5e22, - S_nbi_tot=0.3e22, - Ptot=53.0e6, # total external heating solver=config_lib.SolverConfig( predictor_corrector=False, use_pereverzev=True, ), - sources=dict( - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), ) @@ -76,6 +67,28 @@ def get_transport_model() -> qlknn_wrapper.QLKNNTransportModel: ) +def get_sources() -> source_models_lib.SourceModels: + """Returns the source models used in the simulation.""" + source_models = default_sources.get_default_sources() + # multiplier for ion-electron heat exchange term for sensitivity + source_models.qei_source.runtime_params.Qei_mult = 1.0 + source_models.j_bootstrap.runtime_params.bootstrap_mult = 1.0 + # total pellet particles/s (continuous pellet model) + source_models.sources['pellet_source'].runtime_params.S_pellet_tot = 1.0e22 + # total heating (including accounting for radiation) r + source_models.sources['generic_ion_el_heat_source'].runtime_params.Ptot = ( + 53.0e6 + ) + # total pellet particles/s + source_models.sources['gas_puff_source'].runtime_params.S_puff_tot = 0.5e22 + # NBI total particle source + source_models.sources['nbi_particle_source'].runtime_params.S_nbi_tot = 0.3e22 + source_models.sources['ohmic_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + return source_models + + def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most @@ -86,5 +99,6 @@ def get_sim() -> sim_lib.Sim: config=config, geo=geo, stepper_builder=linear_theta_method.LinearThetaMethod, + source_models=get_sources(), transport_model=get_transport_model(), ) diff --git a/torax/tests/test_data/test_bootstrap.py b/torax/tests/test_data/test_bootstrap.py index de23ac14..c67ea724 100644 --- a/torax/tests/test_data/test_bootstrap.py +++ b/torax/tests/test_data/test_bootstrap.py @@ -20,7 +20,9 @@ from torax import config as config_lib from torax import geometry from torax import sim as sim_lib -from torax.sources import source_config +from torax.sources import default_sources +from torax.sources import runtime_params as source_runtime_params +from torax.sources import source_models as source_models_lib from torax.stepper import linear_theta_method from torax.transport_model import constant as constant_transport_model @@ -32,32 +34,19 @@ def get_config() -> config_lib.Config: nbar=0.85, # initial density (in Greenwald fraction units) ), numerics=config_lib.Numerics( - Qei_mult=1, ion_heat_eq=True, el_heat_eq=True, dens_eq=True, current_eq=True, resistivity_mult=100, # to shorten current diffusion time - bootstrap_mult=1, # remove bootstrap current t_final=1, ), # set flat Ohmic current to provide larger range of current evolution for # test nu=0, - S_pellet_tot=0.0, - S_puff_tot=0.0, - S_nbi_tot=0.0, solver=config_lib.SolverConfig( predictor_corrector=False, ), - sources=dict( - fusion_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), ) @@ -76,6 +65,27 @@ def get_transport_model() -> constant_transport_model.ConstantTransportModel: ) +def get_sources() -> source_models_lib.SourceModels: + """Returns the source models used in the simulation.""" + source_models = default_sources.get_default_sources() + # multiplier for ion-electron heat exchange term for sensitivity + source_models.qei_source.runtime_params.Qei_mult = 1.0 + source_models.j_bootstrap.runtime_params.bootstrap_mult = 1.0 + # total pellet particles/s (continuous pellet model) + source_models.sources['pellet_source'].runtime_params.S_pellet_tot = 0.0 + # total pellet particles/s + source_models.sources['gas_puff_source'].runtime_params.S_puff_tot = 0.0 + # NBI total particle source + source_models.sources['nbi_particle_source'].runtime_params.S_nbi_tot = 0.0 + source_models.sources['fusion_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + source_models.sources['ohmic_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + return source_models + + def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most @@ -86,5 +96,6 @@ def get_sim() -> sim_lib.Sim: config=config, geo=geo, stepper_builder=linear_theta_method.LinearThetaMethod, + source_models=get_sources(), transport_model=get_transport_model(), ) diff --git a/torax/tests/test_data/test_cgmheat.py b/torax/tests/test_data/test_cgmheat.py index 2006a7e4..34a5326b 100644 --- a/torax/tests/test_data/test_cgmheat.py +++ b/torax/tests/test_data/test_cgmheat.py @@ -21,7 +21,9 @@ from torax import config as config_lib from torax import geometry from torax import sim as sim_lib -from torax.sources import source_config +from torax.sources import default_sources +from torax.sources import runtime_params as source_runtime_params +from torax.sources import source_models as source_models_lib from torax.stepper import linear_theta_method from torax.transport_model import critical_gradient as cgm_transport_model @@ -30,20 +32,11 @@ def get_config() -> config_lib.Config: return config_lib.Config( numerics=config_lib.Numerics( t_final=1, - bootstrap_mult=0, # remove bootstrap current ), solver=config_lib.SolverConfig( predictor_corrector=False, use_pereverzev=True, ), - sources=dict( - fusion_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), ) @@ -55,6 +48,20 @@ def get_transport_model() -> cgm_transport_model.CriticalGradientModel: return cgm_transport_model.CriticalGradientModel() +def get_sources() -> source_models_lib.SourceModels: + """Returns the source models used in the simulation.""" + source_models = default_sources.get_default_sources() + # remove bootstrap current + source_models.j_bootstrap.runtime_params.bootstrap_mult = 0.0 + source_models.sources['fusion_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + source_models.sources['ohmic_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + return source_models + + def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most @@ -65,5 +72,6 @@ def get_sim() -> sim_lib.Sim: config=config, geo=geo, stepper_builder=linear_theta_method.LinearThetaMethod, + source_models=get_sources(), transport_model=get_transport_model(), ) diff --git a/torax/tests/test_data/test_chease.py b/torax/tests/test_data/test_chease.py index 174aea3c..3e9abefc 100644 --- a/torax/tests/test_data/test_chease.py +++ b/torax/tests/test_data/test_chease.py @@ -20,7 +20,9 @@ from torax import config as config_lib from torax import geometry from torax import sim as sim_lib -from torax.sources import source_config +from torax.sources import default_sources +from torax.sources import runtime_params as source_runtime_params +from torax.sources import source_models as source_models_lib from torax.stepper import linear_theta_method from torax.transport_model import constant as constant_transport_model @@ -32,25 +34,11 @@ def get_config() -> config_lib.Config: set_pedestal=False, ), numerics=config_lib.Numerics( - Qei_mult=0, t_final=1, - bootstrap_mult=0, # remove bootstrap current ), - w=0.18202270915319393, - S_pellet_tot=0, - S_puff_tot=0, - S_nbi_tot=0, solver=config_lib.SolverConfig( predictor_corrector=False, ), - sources=dict( - fusion_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), ) @@ -66,6 +54,32 @@ def get_transport_model() -> constant_transport_model.ConstantTransportModel: return constant_transport_model.ConstantTransportModel() +def get_sources() -> source_models_lib.SourceModels: + """Returns the source models used in the simulation.""" + source_models = default_sources.get_default_sources() + # multiplier for ion-electron heat exchange term for sensitivity + source_models.qei_source.runtime_params.Qei_mult = 0.0 + # remove bootstrap current + source_models.j_bootstrap.runtime_params.bootstrap_mult = 0.0 + # total pellet particles/s (continuous pellet model) + source_models.sources["pellet_source"].runtime_params.S_pellet_tot = 0.0 + # Gaussian width in normalized radial coordinate r + source_models.sources["generic_ion_el_heat_source"].runtime_params.w = ( + 0.18202270915319393 + ) + # total pellet particles/s + source_models.sources["gas_puff_source"].runtime_params.S_puff_tot = 0 + # NBI total particle source + source_models.sources["nbi_particle_source"].runtime_params.S_nbi_tot = 0.0 + source_models.sources["fusion_heat_source"].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + source_models.sources["ohmic_heat_source"].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + return source_models + + def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most @@ -76,5 +90,6 @@ def get_sim() -> sim_lib.Sim: config=config, geo=geo, stepper_builder=linear_theta_method.LinearThetaMethod, + source_models=get_sources(), transport_model=get_transport_model(), ) diff --git a/torax/tests/test_data/test_crank_nicolson.py b/torax/tests/test_data/test_crank_nicolson.py index dc4d151b..662321d0 100644 --- a/torax/tests/test_data/test_crank_nicolson.py +++ b/torax/tests/test_data/test_crank_nicolson.py @@ -21,7 +21,9 @@ from torax import config as config_lib from torax import geometry from torax import sim as sim_lib -from torax.sources import source_config +from torax.sources import default_sources +from torax.sources import runtime_params as source_runtime_params +from torax.sources import source_models as source_models_lib from torax.stepper import linear_theta_method from torax.transport_model import constant as constant_transport_model @@ -34,22 +36,12 @@ def get_config() -> config_lib.Config: set_pedestal=False, ), numerics=config_lib.Numerics( - Qei_mult=0, t_final=1, - bootstrap_mult=0, # remove bootstrap current ), solver=config_lib.SolverConfig( predictor_corrector=False, theta_imp=0.5, ), - sources=dict( - fusion_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), ) @@ -61,6 +53,22 @@ def get_transport_model() -> constant_transport_model.ConstantTransportModel: return constant_transport_model.ConstantTransportModel() +def get_sources() -> source_models_lib.SourceModels: + """Returns the source models used in the simulation.""" + source_models = default_sources.get_default_sources() + # multiplier for ion-electron heat exchange term for sensitivity + source_models.qei_source.runtime_params.Qei_mult = 0.0 + # remove bootstrap current + source_models.j_bootstrap.runtime_params.bootstrap_mult = 0.0 + source_models.sources['fusion_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + source_models.sources['ohmic_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + return source_models + + def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most @@ -71,5 +79,6 @@ def get_sim() -> sim_lib.Sim: config=config, geo=geo, stepper_builder=linear_theta_method.LinearThetaMethod, + source_models=get_sources(), transport_model=get_transport_model(), ) diff --git a/torax/tests/test_data/test_exact_finaltime.py b/torax/tests/test_data/test_exact_finaltime.py index 18fb39d3..6b563a64 100644 --- a/torax/tests/test_data/test_exact_finaltime.py +++ b/torax/tests/test_data/test_exact_finaltime.py @@ -17,7 +17,9 @@ from torax import config as config_lib from torax import geometry from torax import sim as sim_lib -from torax.sources import source_config +from torax.sources import default_sources +from torax.sources import runtime_params as source_runtime_params +from torax.sources import source_models as source_models_lib from torax.stepper import linear_theta_method from torax.transport_model import qlknn_wrapper @@ -33,7 +35,6 @@ def get_config() -> config_lib.Config: resistivity_mult=100, # to shorten current diffusion time t_final=2, exact_t_final=True, - bootstrap_mult=0, # remove bootstrap current ), # set flat Ohmic current to provide larger range of current evolution for # test @@ -42,14 +43,6 @@ def get_config() -> config_lib.Config: predictor_corrector=False, use_pereverzev=True, ), - sources=dict( - fusion_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), ) @@ -61,6 +54,20 @@ def get_transport_model() -> qlknn_wrapper.QLKNNTransportModel: return qlknn_wrapper.QLKNNTransportModel() +def get_sources() -> source_models_lib.SourceModels: + """Returns the source models used in the simulation.""" + source_models = default_sources.get_default_sources() + # remove bootstrap current + source_models.j_bootstrap.runtime_params.bootstrap_mult = 0.0 + source_models.sources['fusion_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + source_models.sources['ohmic_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + return source_models + + def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most @@ -71,5 +78,6 @@ def get_sim() -> sim_lib.Sim: config=config, geo=geo, stepper_builder=linear_theta_method.LinearThetaMethod, + source_models=get_sources(), transport_model=get_transport_model(), ) diff --git a/torax/tests/test_data/test_explicit.py b/torax/tests/test_data/test_explicit.py index b880487a..86f0e07e 100644 --- a/torax/tests/test_data/test_explicit.py +++ b/torax/tests/test_data/test_explicit.py @@ -14,10 +14,13 @@ """Config for test_explicit. Basic test of explicit linear solver.""" +import dataclasses from torax import config as config_lib from torax import geometry from torax import sim as sim_lib -from torax.sources import source_config +from torax.sources import default_sources +from torax.sources import runtime_params as source_runtime_params +from torax.sources import source_models as source_models_lib from torax.tests.test_lib import explicit_stepper from torax.transport_model import constant as constant_transport_model @@ -31,26 +34,10 @@ def get_config() -> config_lib.Config: ), numerics=config_lib.Numerics( dtmult=0.9, - Qei_mult=0, t_final=0.1, - bootstrap_mult=0, # remove bootstrap current ion_heat_eq=True, el_heat_eq=False, ), - Ptot=200.0e6, - # Do not use the fusion heat source. - sources=dict( - fusion_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - generic_ion_el_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.FORMULA_BASED, - is_explicit=True, - ), - ), solver=config_lib.SolverConfig( predictor_corrector=False, use_pereverzev=False, @@ -66,6 +53,30 @@ def get_transport_model() -> constant_transport_model.ConstantTransportModel: return constant_transport_model.ConstantTransportModel() +def get_sources() -> source_models_lib.SourceModels: + """Returns the source models used in the simulation.""" + source_models = default_sources.get_default_sources() + # multiplier for ion-electron heat exchange term for sensitivity + source_models.qei_source.runtime_params.Qei_mult = 0.0 + # remove bootstrap current + source_models.j_bootstrap.runtime_params.bootstrap_mult = 0.0 + source_models.sources['generic_ion_el_heat_source'].runtime_params = ( + dataclasses.replace( + source_models.sources['generic_ion_el_heat_source'].runtime_params, + # total heating (including accounting for radiation) r + Ptot=200.0e6, # pylint: disable=unexpected-keyword-arg + is_explicit=True, + ) + ) + source_models.sources['fusion_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + source_models.sources['ohmic_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + return source_models + + def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most @@ -76,5 +87,6 @@ def get_sim() -> sim_lib.Sim: config=config, geo=geo, stepper_builder=explicit_stepper.ExplicitStepper, + source_models=get_sources(), transport_model=get_transport_model(), ) diff --git a/torax/tests/test_data/test_fixed_dt.py b/torax/tests/test_data/test_fixed_dt.py index 212bd67a..b68dbee4 100644 --- a/torax/tests/test_data/test_fixed_dt.py +++ b/torax/tests/test_data/test_fixed_dt.py @@ -17,7 +17,9 @@ from torax import config as config_lib from torax import geometry from torax import sim as sim_lib -from torax.sources import source_config +from torax.sources import default_sources +from torax.sources import runtime_params as source_runtime_params +from torax.sources import source_models as source_models_lib from torax.stepper import linear_theta_method from torax.time_step_calculator import fixed_time_step_calculator from torax.transport_model import qlknn_wrapper @@ -31,17 +33,8 @@ def get_config() -> config_lib.Config: t_final=2, use_fixed_dt=True, fixed_dt=2e-2, - bootstrap_mult=0, # remove bootstrap current ), # Do not use the fusion heat source. - sources=dict( - fusion_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), solver=config_lib.SolverConfig( predictor_corrector=False, use_pereverzev=True, @@ -57,6 +50,20 @@ def get_transport_model() -> qlknn_wrapper.QLKNNTransportModel: return qlknn_wrapper.QLKNNTransportModel() +def get_sources() -> source_models_lib.SourceModels: + """Returns the source models used in the simulation.""" + source_models = default_sources.get_default_sources() + # remove bootstrap current + source_models.j_bootstrap.runtime_params.bootstrap_mult = 0.0 + source_models.sources['fusion_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + source_models.sources['ohmic_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + return source_models + + def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most @@ -71,6 +78,7 @@ def get_sim() -> sim_lib.Sim: config=sim_config, geo=geo, stepper_builder=linear_theta_method.LinearThetaMethod, + source_models=get_sources(), transport_model=get_transport_model(), time_step_calculator=time_step_calculator, ) diff --git a/torax/tests/test_data/test_frozen_newton_raphson.py b/torax/tests/test_data/test_frozen_newton_raphson.py index 262b6769..f2888016 100644 --- a/torax/tests/test_data/test_frozen_newton_raphson.py +++ b/torax/tests/test_data/test_frozen_newton_raphson.py @@ -22,7 +22,9 @@ from torax import config as config_lib from torax import geometry from torax import sim as sim_lib -from torax.sources import source_config +from torax.sources import default_sources +from torax.sources import runtime_params as source_runtime_params +from torax.sources import source_models as source_models_lib from torax.tests.test_lib import sim_test_case from torax.transport_model import constant as constant_transport_model @@ -33,22 +35,12 @@ def get_config() -> config_lib.Config: set_pedestal=False, ), numerics=config_lib.Numerics( - Qei_mult=0, t_final=1, - bootstrap_mult=0, # remove bootstrap current ), solver=config_lib.SolverConfig( predictor_corrector=False, theta_imp=1.0, ), - sources=dict( - fusion_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), ) @@ -56,6 +48,22 @@ def get_geometry(config: config_lib.Config) -> geometry.Geometry: return geometry.build_circular_geometry(config) +def get_sources() -> source_models_lib.SourceModels: + """Returns the source models used in the simulation.""" + source_models = default_sources.get_default_sources() + # multiplier for ion-electron heat exchange term for sensitivity + source_models.qei_source.runtime_params.Qei_mult = 0.0 + # remove bootstrap current + source_models.j_bootstrap.runtime_params.bootstrap_mult = 0.0 + source_models.sources['fusion_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + source_models.sources['ohmic_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + return source_models + + def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most @@ -70,4 +78,5 @@ def get_sim() -> sim_lib.Sim: config=config, ), transport_model=constant_transport_model.ConstantTransportModel(), + source_models=get_sources(), ) diff --git a/torax/tests/test_data/test_frozen_optimizer.py b/torax/tests/test_data/test_frozen_optimizer.py index c4c84093..f4aab000 100644 --- a/torax/tests/test_data/test_frozen_optimizer.py +++ b/torax/tests/test_data/test_frozen_optimizer.py @@ -19,7 +19,9 @@ from torax import config as config_lib from torax import geometry from torax import sim as sim_lib -from torax.sources import source_config +from torax.sources import default_sources +from torax.sources import runtime_params as source_runtime_params +from torax.sources import source_models as source_models_lib from torax.tests.test_lib import sim_test_case from torax.transport_model import constant as constant_transport_model @@ -30,22 +32,12 @@ def get_config() -> config_lib.Config: set_pedestal=False, ), numerics=config_lib.Numerics( - Qei_mult=0, t_final=1, - bootstrap_mult=0, # remove bootstrap current ), solver=config_lib.SolverConfig( predictor_corrector=False, theta_imp=1.0, ), - sources=dict( - fusion_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), ) @@ -57,6 +49,22 @@ def get_transport_model() -> constant_transport_model.ConstantTransportModel: return constant_transport_model.ConstantTransportModel() +def get_sources() -> source_models_lib.SourceModels: + """Returns the source models used in the simulation.""" + source_models = default_sources.get_default_sources() + # multiplier for ion-electron heat exchange term for sensitivity + source_models.qei_source.runtime_params.Qei_mult = 0.0 + # remove bootstrap current + source_models.j_bootstrap.runtime_params.bootstrap_mult = 0.0 + source_models.sources['fusion_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + source_models.sources['ohmic_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + return source_models + + def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most @@ -73,4 +81,5 @@ def get_sim() -> sim_lib.Sim: transport_params=transport_model.runtime_params, ), transport_model=transport_model, + source_models=get_sources(), ) diff --git a/torax/tests/test_data/test_fusion_power.py b/torax/tests/test_data/test_fusion_power.py index e0341eb0..3c401b77 100644 --- a/torax/tests/test_data/test_fusion_power.py +++ b/torax/tests/test_data/test_fusion_power.py @@ -21,7 +21,9 @@ from torax import config as config_lib from torax import geometry from torax import sim as sim_lib -from torax.sources import source_config +from torax.sources import default_sources +from torax.sources import runtime_params as source_runtime_params +from torax.sources import source_models as source_models_lib from torax.stepper import linear_theta_method from torax.transport_model import critical_gradient as cgm_transport_model @@ -36,35 +38,21 @@ def get_config() -> config_lib.Config: neped=1.0, ), numerics=config_lib.Numerics( - Qei_mult=1, ion_heat_eq=True, el_heat_eq=True, dens_eq=True, current_eq=True, resistivity_mult=100, # to shorten current diffusion time - bootstrap_mult=1, # remove bootstrap current t_final=1, ), # set flat Ohmic current to provide larger range of current evolution for # test nu=0, - S_pellet_tot=1.0e22, - S_puff_tot=0.5e22, - S_nbi_tot=0.3e22, - Ptot=53.0e6, # total external heating solver=config_lib.SolverConfig( predictor_corrector=False, use_pereverzev=True, d_per=0.0, ), - sources=dict( - fusion_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.MODEL_BASED, - ), - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), ) @@ -81,6 +69,28 @@ def get_transport_model() -> cgm_transport_model.CriticalGradientModel: ) +def get_sources() -> source_models_lib.SourceModels: + """Returns the source models used in the simulation.""" + source_models = default_sources.get_default_sources() + # multiplier for ion-electron heat exchange term for sensitivity + source_models.qei_source.runtime_params.Qei_mult = 1.0 + source_models.j_bootstrap.runtime_params.bootstrap_mult = 1.0 + # total pellet particles/s (continuous pellet model) + source_models.sources['pellet_source'].runtime_params.S_pellet_tot = 1.0e22 + # total heating (including accounting for radiation) r + source_models.sources['generic_ion_el_heat_source'].runtime_params.Ptot = ( + 53.0e6 + ) + # total pellet particles/s + source_models.sources['gas_puff_source'].runtime_params.S_puff_tot = 0.5e22 + # NBI total particle source + source_models.sources['nbi_particle_source'].runtime_params.S_nbi_tot = 0.3e22 + source_models.sources['ohmic_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + return source_models + + def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most @@ -91,5 +101,6 @@ def get_sim() -> sim_lib.Sim: config=config, geo=geo, stepper_builder=linear_theta_method.LinearThetaMethod, + source_models=get_sources(), transport_model=get_transport_model(), ) diff --git a/torax/tests/test_data/test_implicit.py b/torax/tests/test_data/test_implicit.py index da5d60b2..47c8ede3 100644 --- a/torax/tests/test_data/test_implicit.py +++ b/torax/tests/test_data/test_implicit.py @@ -17,7 +17,9 @@ from torax import config as config_lib from torax import geometry from torax import sim as sim_lib -from torax.sources import source_config +from torax.sources import default_sources +from torax.sources import runtime_params as source_runtime_params +from torax.sources import source_models as source_models_lib from torax.stepper import linear_theta_method from torax.transport_model import constant as constant_transport_model @@ -28,22 +30,12 @@ def get_config() -> config_lib.Config: set_pedestal=False, ), numerics=config_lib.Numerics( - Qei_mult=0, t_final=1, - bootstrap_mult=0, # remove bootstrap current ), solver=config_lib.SolverConfig( predictor_corrector=False, theta_imp=1.0, ), - sources=dict( - fusion_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), ) @@ -55,6 +47,22 @@ def get_transport_model() -> constant_transport_model.ConstantTransportModel: return constant_transport_model.ConstantTransportModel() +def get_sources() -> source_models_lib.SourceModels: + """Returns the source models used in the simulation.""" + source_models = default_sources.get_default_sources() + # multiplier for ion-electron heat exchange term for sensitivity + source_models.qei_source.runtime_params.Qei_mult = 0.0 + # remove bootstrap current + source_models.j_bootstrap.runtime_params.bootstrap_mult = 0.0 + source_models.sources['fusion_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + source_models.sources['ohmic_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + return source_models + + def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most @@ -65,5 +73,6 @@ def get_sim() -> sim_lib.Sim: config=config, geo=geo, stepper_builder=linear_theta_method.LinearThetaMethod, + source_models=get_sources(), transport_model=get_transport_model(), ) diff --git a/torax/tests/test_data/test_implicit_short_optimizer.py b/torax/tests/test_data/test_implicit_short_optimizer.py index e53efc61..6279ae28 100644 --- a/torax/tests/test_data/test_implicit_short_optimizer.py +++ b/torax/tests/test_data/test_implicit_short_optimizer.py @@ -23,7 +23,9 @@ from torax import config as config_lib from torax import geometry from torax import sim as sim_lib -from torax.sources import source_config +from torax.sources import default_sources +from torax.sources import runtime_params as source_runtime_params +from torax.sources import source_models as source_models_lib from torax.tests.test_lib import sim_test_case from torax.transport_model import constant as constant_transport_model @@ -36,22 +38,12 @@ def get_config() -> config_lib.Config: set_pedestal=False, ), numerics=config_lib.Numerics( - Qei_mult=0, t_final=0.1, - bootstrap_mult=0, # remove bootstrap current ), solver=config_lib.SolverConfig( predictor_corrector=False, theta_imp=1.0, ), - sources=dict( - fusion_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), ) @@ -63,6 +55,22 @@ def get_transport_model() -> constant_transport_model.ConstantTransportModel: return constant_transport_model.ConstantTransportModel() +def get_sources() -> source_models_lib.SourceModels: + """Returns the source models used in the simulation.""" + source_models = default_sources.get_default_sources() + # multiplier for ion-electron heat exchange term for sensitivity + source_models.qei_source.runtime_params.Qei_mult = 0.0 + # remove bootstrap current + source_models.j_bootstrap.runtime_params.bootstrap_mult = 0.0 + source_models.sources['fusion_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + source_models.sources['ohmic_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + return source_models + + def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most @@ -79,4 +87,5 @@ def get_sim() -> sim_lib.Sim: transport_params=transport_model.runtime_params, ), transport_model=transport_model, + source_models=get_sources(), ) diff --git a/torax/tests/test_data/test_iterbaseline_mockup.py b/torax/tests/test_data/test_iterbaseline_mockup.py index 3c36bab9..8d84da73 100644 --- a/torax/tests/test_data/test_iterbaseline_mockup.py +++ b/torax/tests/test_data/test_iterbaseline_mockup.py @@ -14,10 +14,13 @@ """ITER baseline approximately based on Mantica PPCF 2021.""" +import dataclasses from torax import config as config_lib from torax import geometry from torax import sim as sim_lib -from torax.sources import source_config +from torax.sources import default_sources +from torax.sources import runtime_params as source_runtime_params +from torax.sources import source_models as source_models_lib from torax.stepper import linear_theta_method from torax.transport_model import qlknn_wrapper @@ -58,12 +61,6 @@ def get_config() -> config_lib.Config: # 1/multiplication factor for sigma (conductivity) to reduce current # diffusion timescale to be closer to heat diffusion timescale resistivity_mult=200, - # multiplier for ion-electron heat exchange term for sensitivity - # testing - Qei_mult=1, - # Multiplication factor for bootstrap current (note fbs~0.3 in - # original simu) - bootstrap_mult=1, # numerical (e.g. no. of grid points, other info needed by solver) nr=25, # radial grid points maxdt=0.5, @@ -81,55 +78,12 @@ def get_config() -> config_lib.Config: # condtion location if n != largeValue_n=1.0e8, ), - # external heat source parameters - w=0.07280908366127758, # Gaussian width in normalized radial coordinate - # Source Gaussian central location in normalized r - rsource=0.1383372589564274, - Ptot=8.0e6, # total heating (including accounting for radiation) - el_heat_fraction=0.0, # electron heating fraction - # particle source parameters - # pellets behave like a gas puff for this simulation with exponential - # decay therefore use the "puff" structure for pellets - # exponential decay length of gas puff ionization (normalized radial - # coordinate) - puff_decay_length=0.21, - S_puff_tot=2.14e22, # total pellet particles/s - # Gaussian width of pellet deposition (normalized radial coordinate) in - # continuous pellet model - pellet_width=0.1, - # Pellet source Gaussian central location (normalized radial coordinate) - # in continuous pellet model - pellet_deposition_location=0.85, - # total pellet particles/s (continuous pellet model) - S_pellet_tot=0.0e22, - # NBI particle source Gaussian width (normalized radial coordinate) - nbi_particle_width=0.25, - # NBI particle source Gaussian central location (normalized radial - # coordinate) - nbi_deposition_location=0.5, - S_nbi_tot=2.05e20, # NBI total particle source - # external current profiles - fext=0.09, # total "external" current fraction - # width of "external" Gaussian current profile (normalized radial - # coordinate) - wext=0.25, - # radius of "external" Gaussian current profile (normalized radial - # coordinate) - rext=0.35, solver=config_lib.SolverConfig( predictor_corrector=False, use_pereverzev=True, chi_per=20, d_per=10, ), - sources=dict( - fusion_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.MODEL_BASED, - ), - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), ) @@ -174,6 +128,78 @@ def get_transport_model() -> qlknn_wrapper.QLKNNTransportModel: ) +def get_sources() -> source_models_lib.SourceModels: + """Returns the source models used in the simulation.""" + source_models = default_sources.get_default_sources() + # multiplier for ion-electron heat exchange term for sensitivity + source_models.qei_source.runtime_params.Qei_mult = 1.0 + # Multiplication factor for bootstrap current (note fbs~0.3 in original simu) + source_models.j_bootstrap.runtime_params.bootstrap_mult = 1.0 + source_models.jext.runtime_params = dataclasses.replace( + source_models.jext.runtime_params, + # total "external" current fraction + fext=0.09, + # width of "external" Gaussian current profile (normalized radial + # coordinate) + wext=0.25, + # radius of "external" Gaussian current profile (normalized radial + # coordinate) + rext=0.35, + ) + # pytype: disable=unexpected-keyword-arg + # pylint: disable=unexpected-keyword-arg + source_models.sources['generic_ion_el_heat_source'].runtime_params = ( + dataclasses.replace( + source_models.sources['generic_ion_el_heat_source'].runtime_params, + rsource=0.1383372589564274, + # Gaussian width in normalized radial coordinate r + w=0.07280908366127758, + # total heating (including accounting for radiation) r + Ptot=8.0e6, + # electron heating fraction r + el_heat_fraction=0.0, + ) + ) + source_models.sources['gas_puff_source'].runtime_params = dataclasses.replace( + source_models.sources['gas_puff_source'].runtime_params, + # pellets behave like a gas puff for this simulation with exponential + # decay therefore use the puff structure for pellets exponential decay + # length of gas puff ionization (normalized radial coordinate) + puff_decay_length=0.21, + # total pellet particles/s + S_puff_tot=2.14e22, + ) + source_models.sources['pellet_source'].runtime_params = dataclasses.replace( + source_models.sources['pellet_source'].runtime_params, + # total pellet particles/s (continuous pellet model) + S_pellet_tot=0.0e22, + # Gaussian width of pellet deposition (normalized radial coordinate) in + # continuous pellet model + pellet_width=0.1, + # Pellet source Gaussian central location (normalized radial coordinate) + # in continuous pellet model. + pellet_deposition_location=0.85, + ) + source_models.sources['nbi_particle_source'].runtime_params = ( + dataclasses.replace( + source_models.sources['nbi_particle_source'].runtime_params, + # NBI total particle source + S_nbi_tot=2.05e20, + # NBI particle source Gaussian central location (normalized radial + # coordinate) + nbi_deposition_location=0.5, + # NBI particle source Gaussian width (normalized radial coordinate) + nbi_particle_width=0.25, + ) + ) + # pytype: enable=unexpected-keyword-arg + # pylint: enable=unexpected-keyword-arg + source_models.sources['ohmic_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + return source_models + + def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most @@ -184,5 +210,6 @@ def get_sim() -> sim_lib.Sim: config=config, geo=geo, stepper_builder=linear_theta_method.LinearThetaMethod, + source_models=get_sources(), transport_model=get_transport_model(), ) diff --git a/torax/tests/test_data/test_iterhybrid_mockup.py b/torax/tests/test_data/test_iterhybrid_mockup.py index c950a0ba..091b64bc 100644 --- a/torax/tests/test_data/test_iterhybrid_mockup.py +++ b/torax/tests/test_data/test_iterhybrid_mockup.py @@ -14,10 +14,13 @@ """ITER hybrid scenario approximately based on van Mulders NF 2021.""" +import dataclasses from torax import config as config_lib from torax import geometry from torax import sim as sim_lib -from torax.sources import source_config +from torax.sources import default_sources +from torax.sources import runtime_params as source_runtime_params +from torax.sources import source_models as source_models_lib from torax.stepper import linear_theta_method from torax.transport_model import qlknn_wrapper @@ -60,11 +63,6 @@ def get_config() -> config_lib.Config: # 1/multiplication factor for sigma (conductivity) to reduce current # diffusion timescale to be closer to heat diffusion timescale. resistivity_mult=200, - # multiplier for ion-electron heat exchange term for sensitivity - Qei_mult=1, - # Multiplication factor for bootstrap current (note fbs~0.3 in - # original simu) - bootstrap_mult=1, # numerical (e.g. no. of grid points, other info needed by solver) nr=25, # radial grid points ion_heat_eq=True, @@ -83,40 +81,6 @@ def get_config() -> config_lib.Config: # condtion location if n != neped largeValue_n=1.0e8, ), - # external heat source parameters - w=0.07280908366127758, # Gaussian width in normalized radial coordinate - rsource=0.12741589640723575, # Source Gauss peak in normalized r - Ptot=51.0e6, # total heating (including accounting for radiation) - el_heat_fraction=0.68, # electron heating fraction - # particle source parameters - # pellets behave like a gas puff for this simulation with exponential - # decay therefore use the "puff" structure for pellets - # exponential decay length of gas puff ionization (normalized radial - # coordinate) - puff_decay_length=0.3, - S_puff_tot=6.0e21, # total pellet particles/s - # Gaussian width of pellet deposition (normalized radial coordinate) in - # continuous pellet model - pellet_width=0.1, - # Pellet source Gaussian central location (normalized radial coordinate) - # in continuous pellet model - pellet_deposition_location=0.85, - # total pellet particles/s (continuous pellet model) - S_pellet_tot=0.0e22, - # NBI particle source Gaussian width (normalized radial coordinate) - nbi_particle_width=0.25, - # NBI particle source Gaussian central location (normalized radial - # coordinate) - nbi_deposition_location=0.3, - S_nbi_tot=2.05e20, # NBI total particle source - # external current profiles - fext=0.46, # total "external" current fraction - # width of "external" Gaussian current profile (normalized radial - # coordinate) - wext=0.075, - # radius of "external" Gaussian current profile (normalized radial - # coordinate) - rext=0.36, solver=config_lib.SolverConfig( predictor_corrector=False, # (deliberately) large heat conductivity for Pereverzev rule @@ -125,15 +89,6 @@ def get_config() -> config_lib.Config: d_per=15, use_pereverzev=True, ), - sources=dict( - fusion_heat_source=source_config.SourceConfig( - # incorporate fusion heating source in calculation. - source_type=source_config.SourceType.MODEL_BASED, - ), - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), ) @@ -180,6 +135,78 @@ def get_transport_model() -> qlknn_wrapper.QLKNNTransportModel: ) +def get_sources() -> source_models_lib.SourceModels: + """Returns the source models used in the simulation.""" + source_models = default_sources.get_default_sources() + # multiplier for ion-electron heat exchange term for sensitivity + source_models.qei_source.runtime_params.Qei_mult = 1.0 + # Multiplication factor for bootstrap current (note fbs~0.3 in original simu) + source_models.j_bootstrap.runtime_params.bootstrap_mult = 1.0 + source_models.jext.runtime_params = dataclasses.replace( + source_models.jext.runtime_params, + # total "external" current fraction + fext=0.46, + # width of "external" Gaussian current profile (normalized radial + # coordinate) + wext=0.075, + # radius of "external" Gaussian current profile (normalized radial + # coordinate) + rext=0.36, + ) + # pytype: disable=unexpected-keyword-arg + # pylint: disable=unexpected-keyword-arg + source_models.sources['generic_ion_el_heat_source'].runtime_params = ( + dataclasses.replace( + source_models.sources['generic_ion_el_heat_source'].runtime_params, + rsource=0.12741589640723575, + # Gaussian width in normalized radial coordinate r + w=0.07280908366127758, + # total heating (including accounting for radiation) r + Ptot=51.0e6, + # electron heating fraction r + el_heat_fraction=0.68, + ) + ) + source_models.sources['gas_puff_source'].runtime_params = dataclasses.replace( + source_models.sources['gas_puff_source'].runtime_params, + # pellets behave like a gas puff for this simulation with exponential + # decay therefore use the puff structure for pellets exponential decay + # length of gas puff ionization (normalized radial coordinate) + puff_decay_length=0.3, + # total pellet particles/s + S_puff_tot=6.0e21, + ) + source_models.sources['pellet_source'].runtime_params = dataclasses.replace( + source_models.sources['pellet_source'].runtime_params, + # total pellet particles/s (continuous pellet model) + S_pellet_tot=0.0e22, + # Gaussian width of pellet deposition (normalized radial coordinate) in + # continuous pellet model + pellet_width=0.1, + # Pellet source Gaussian central location (normalized radial coordinate) + # in continuous pellet model. + pellet_deposition_location=0.85, + ) + source_models.sources['nbi_particle_source'].runtime_params = ( + dataclasses.replace( + source_models.sources['nbi_particle_source'].runtime_params, + # NBI total particle source + S_nbi_tot=2.05e20, + # NBI particle source Gaussian central location (normalized radial + # coordinate) + nbi_deposition_location=0.3, + # NBI particle source Gaussian width (normalized radial coordinate) + nbi_particle_width=0.25, + ) + ) + # pytype: enable=unexpected-keyword-arg + # pylint: enable=unexpected-keyword-arg + source_models.sources['ohmic_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + return source_models + + def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most @@ -190,5 +217,6 @@ def get_sim() -> sim_lib.Sim: config=config, geo=geo, stepper_builder=linear_theta_method.LinearThetaMethod, + source_models=get_sources(), transport_model=get_transport_model(), ) diff --git a/torax/tests/test_data/test_iterhybrid_newton.py b/torax/tests/test_data/test_iterhybrid_newton.py index ab822f0d..d2a563a8 100644 --- a/torax/tests/test_data/test_iterhybrid_newton.py +++ b/torax/tests/test_data/test_iterhybrid_newton.py @@ -18,10 +18,13 @@ With Newton-Raphson stepper and adaptive timestep (backtracking) """ +import dataclasses from torax import config as config_lib from torax import geometry from torax import sim as sim_lib -from torax.sources import source_config +from torax.sources import default_sources +from torax.sources import runtime_params as source_runtime_params +from torax.sources import source_models as source_models_lib from torax.stepper import nonlinear_theta_method from torax.transport_model import qlknn_wrapper @@ -64,11 +67,6 @@ def get_config() -> config_lib.Config: # 1/multiplication factor for sigma (conductivity) to reduce current # diffusion timescale to be closer to heat diffusion timescale. resistivity_mult=1, - # multiplier for ion-electron heat exchange term for sensitivity - Qei_mult=1, - # Multiplication factor for bootstrap current (note fbs~0.3 in - # original simu) - bootstrap_mult=1, # numerical (e.g. no. of grid points, other info needed by solver) nr=25, # radial grid points ion_heat_eq=True, @@ -87,40 +85,6 @@ def get_config() -> config_lib.Config: # condtion location if n != neped largeValue_n=1.0e8, ), - # external heat source parameters - w=0.07280908366127758, # Gaussian width in normalized radial coordinate - rsource=0.12741589640723575, # Source Gauss peak in normalized r - Ptot=51.0e6, # total heating (including accounting for radiation) - el_heat_fraction=0.68, # electron heating fraction - # particle source parameters - # pellets behave like a gas puff for this simulation with exponential - # decay therefore use the "puff" structure for pellets - # exponential decay length of gas puff ionization (normalized radial - # coordinate) - puff_decay_length=0.3, - S_puff_tot=6.0e21, # total pellet particles/s - # Gaussian width of pellet deposition (normalized radial coordinate) in - # continuous pellet model - pellet_width=0.1, - # Pellet source Gaussian central location (normalized radial coordinate) - # in continuous pellet model - pellet_deposition_location=0.85, - # total pellet particles/s (continuous pellet model) - S_pellet_tot=0.0e22, - # NBI particle source Gaussian width (normalized radial coordinate) - nbi_particle_width=0.25, - # NBI particle source Gaussian central location (normalized radial - # coordinate) - nbi_deposition_location=0.3, - S_nbi_tot=2.05e20, # NBI total particle source - # external current profiles - fext=0.46, # total "external" current fraction - # width of "external" Gaussian current profile (normalized radial - # coordinate) - wext=0.075, - # radius of "external" Gaussian current profile (normalized radial - # coordinate) - rext=0.36, solver=config_lib.SolverConfig( predictor_corrector=False, convection_dirichlet_mode='ghost', @@ -132,15 +96,6 @@ def get_config() -> config_lib.Config: use_pereverzev=True, log_iterations=False, ), - sources=dict( - fusion_heat_source=source_config.SourceConfig( - # incorporate fusion heating source in calculation. - source_type=source_config.SourceType.MODEL_BASED, - ), - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), ) @@ -196,6 +151,78 @@ def get_transport_model() -> qlknn_wrapper.QLKNNTransportModel: ) +def get_sources() -> source_models_lib.SourceModels: + """Returns the source models used in the simulation.""" + source_models = default_sources.get_default_sources() + # multiplier for ion-electron heat exchange term for sensitivity + source_models.qei_source.runtime_params.Qei_mult = 1.0 + # Multiplication factor for bootstrap current (note fbs~0.3 in original simu) + source_models.j_bootstrap.runtime_params.bootstrap_mult = 1.0 + source_models.jext.runtime_params = dataclasses.replace( + source_models.jext.runtime_params, + # total "external" current fraction + fext=0.46, + # width of "external" Gaussian current profile (normalized radial + # coordinate) + wext=0.075, + # radius of "external" Gaussian current profile (normalized radial + # coordinate) + rext=0.36, + ) + # pytype: disable=unexpected-keyword-arg + # pylint: disable=unexpected-keyword-arg + source_models.sources['generic_ion_el_heat_source'].runtime_params = ( + dataclasses.replace( + source_models.sources['generic_ion_el_heat_source'].runtime_params, + rsource=0.12741589640723575, + # Gaussian width in normalized radial coordinate r + w=0.07280908366127758, + # total heating (including accounting for radiation) r + Ptot=51.0e6, + # electron heating fraction r + el_heat_fraction=0.68, + ) + ) + source_models.sources['gas_puff_source'].runtime_params = dataclasses.replace( + source_models.sources['gas_puff_source'].runtime_params, + # pellets behave like a gas puff for this simulation with exponential + # decay therefore use the puff structure for pellets exponential decay + # length of gas puff ionization (normalized radial coordinate) + puff_decay_length=0.3, + # total pellet particles/s + S_puff_tot=6.0e21, + ) + source_models.sources['pellet_source'].runtime_params = dataclasses.replace( + source_models.sources['pellet_source'].runtime_params, + # total pellet particles/s (continuous pellet model) + S_pellet_tot=0.0e22, + # Gaussian width of pellet deposition (normalized radial coordinate) in + # continuous pellet model + pellet_width=0.1, + # Pellet source Gaussian central location (normalized radial coordinate) + # in continuous pellet model. + pellet_deposition_location=0.85, + ) + source_models.sources['nbi_particle_source'].runtime_params = ( + dataclasses.replace( + source_models.sources['nbi_particle_source'].runtime_params, + # NBI total particle source + S_nbi_tot=2.05e20, + # NBI particle source Gaussian central location (normalized radial + # coordinate) + nbi_deposition_location=0.3, + # NBI particle source Gaussian width (normalized radial coordinate) + nbi_particle_width=0.25, + ) + ) + # pytype: enable=unexpected-keyword-arg + # pylint: enable=unexpected-keyword-arg + source_models.sources['ohmic_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + return source_models + + def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most @@ -206,5 +233,6 @@ def get_sim() -> sim_lib.Sim: config=config, geo=geo, stepper_builder=nonlinear_theta_method.NewtonRaphsonThetaMethod, + source_models=get_sources(), transport_model=get_transport_model(), ) diff --git a/torax/tests/test_data/test_iterhybrid_predictor_corrector.py b/torax/tests/test_data/test_iterhybrid_predictor_corrector.py index e0a02c11..bf612e1a 100644 --- a/torax/tests/test_data/test_iterhybrid_predictor_corrector.py +++ b/torax/tests/test_data/test_iterhybrid_predictor_corrector.py @@ -14,10 +14,13 @@ """ITER hybrid scenario based (roughly) on van Mulders NF 2021.""" +import dataclasses from torax import config as config_lib from torax import geometry from torax import sim as sim_lib -from torax.sources import source_config +from torax.sources import default_sources +from torax.sources import runtime_params as source_runtime_params +from torax.sources import source_models as source_models_lib from torax.stepper import linear_theta_method from torax.transport_model import qlknn_wrapper @@ -60,11 +63,6 @@ def get_config() -> config_lib.Config: # 1/multiplication factor for sigma (conductivity) to reduce current # diffusion timescale to be closer to heat diffusion timescale. resistivity_mult=200, - # multiplier for ion-electron heat exchange term for sensitivity - Qei_mult=1, - # Multiplication factor for bootstrap current (note fbs~0.3 in - # original simu) - bootstrap_mult=1, # numerical (e.g. no. of grid points, other info needed by solver) nr=25, # radial grid points ion_heat_eq=True, @@ -83,40 +81,6 @@ def get_config() -> config_lib.Config: # condtion location if n != neped largeValue_n=1.0e8, ), - # external heat source parameters - w=0.07280908366127758, # Gaussian width in normalized radial coordinate - rsource=0.12741589640723575, # Source Gauss peak in normalized r - Ptot=51.0e6, # total heating (including accounting for radiation) - el_heat_fraction=0.68, # electron heating fraction - # particle source parameters - # pellets behave like a gas puff for this simulation with exponential - # decay therefore use the "puff" structure for pellets - # exponential decay length of gas puff ionization (normalized radial - # coordinate) - puff_decay_length=0.3, - S_puff_tot=6.0e21, # total pellet particles/s - # Gaussian width of pellet deposition (normalized radial coordinate) in - # continuous pellet model - pellet_width=0.1, - # Pellet source Gaussian central location (normalized radial coordinate) - # in continuous pellet model - pellet_deposition_location=0.85, - # total pellet particles/s (continuous pellet model) - S_pellet_tot=0.0e22, - # NBI particle source Gaussian width (normalized radial coordinate) - nbi_particle_width=0.25, - # NBI particle source Gaussian central location (normalized radial - # coordinate) - nbi_deposition_location=0.3, - S_nbi_tot=2.05e20, # NBI total particle source - # external current profiles - fext=0.46, # total "external" current fraction - # width of "external" Gaussian current profile (normalized radial - # coordinate) - wext=0.075, - # radius of "external" Gaussian current profile (normalized radial - # coordinate) - rext=0.36, solver=config_lib.SolverConfig( predictor_corrector=True, corrector_steps=1, @@ -126,15 +90,6 @@ def get_config() -> config_lib.Config: d_per=15, use_pereverzev=True, ), - sources=dict( - fusion_heat_source=source_config.SourceConfig( - # incorporate fusion heating source in calculation. - source_type=source_config.SourceType.MODEL_BASED, - ), - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), ) @@ -181,6 +136,78 @@ def get_transport_model() -> qlknn_wrapper.QLKNNTransportModel: ) +def get_sources() -> source_models_lib.SourceModels: + """Returns the source models used in the simulation.""" + source_models = default_sources.get_default_sources() + # multiplier for ion-electron heat exchange term for sensitivity + source_models.qei_source.runtime_params.Qei_mult = 1.0 + # Multiplication factor for bootstrap current (note fbs~0.3 in original simu) + source_models.j_bootstrap.runtime_params.bootstrap_mult = 1.0 + source_models.jext.runtime_params = dataclasses.replace( + source_models.jext.runtime_params, + # total "external" current fraction + fext=0.46, + # width of "external" Gaussian current profile (normalized radial + # coordinate) + wext=0.075, + # radius of "external" Gaussian current profile (normalized radial + # coordinate) + rext=0.36, + ) + # pytype: disable=unexpected-keyword-arg + # pylint: disable=unexpected-keyword-arg + source_models.sources['generic_ion_el_heat_source'].runtime_params = ( + dataclasses.replace( + source_models.sources['generic_ion_el_heat_source'].runtime_params, + rsource=0.12741589640723575, + # Gaussian width in normalized radial coordinate r + w=0.07280908366127758, + # total heating (including accounting for radiation) r + Ptot=51.0e6, + # electron heating fraction r + el_heat_fraction=0.68, + ) + ) + source_models.sources['gas_puff_source'].runtime_params = dataclasses.replace( + source_models.sources['gas_puff_source'].runtime_params, + # pellets behave like a gas puff for this simulation with exponential + # decay therefore use the puff structure for pellets exponential decay + # length of gas puff ionization (normalized radial coordinate) + puff_decay_length=0.3, + # total pellet particles/s + S_puff_tot=6.0e21, + ) + source_models.sources['pellet_source'].runtime_params = dataclasses.replace( + source_models.sources['pellet_source'].runtime_params, + # total pellet particles/s (continuous pellet model) + S_pellet_tot=0.0e22, + # Gaussian width of pellet deposition (normalized radial coordinate) in + # continuous pellet model + pellet_width=0.1, + # Pellet source Gaussian central location (normalized radial coordinate) + # in continuous pellet model. + pellet_deposition_location=0.85, + ) + source_models.sources['nbi_particle_source'].runtime_params = ( + dataclasses.replace( + source_models.sources['nbi_particle_source'].runtime_params, + # NBI total particle source + S_nbi_tot=2.05e20, + # NBI particle source Gaussian central location (normalized radial + # coordinate) + nbi_deposition_location=0.3, + # NBI particle source Gaussian width (normalized radial coordinate) + nbi_particle_width=0.25, + ) + ) + # pytype: enable=unexpected-keyword-arg + # pylint: enable=unexpected-keyword-arg + source_models.sources['ohmic_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + return source_models + + def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most @@ -191,5 +218,6 @@ def get_sim() -> sim_lib.Sim: config=config, geo=geo, stepper_builder=linear_theta_method.LinearThetaMethod, + source_models=get_sources(), transport_model=get_transport_model(), ) diff --git a/torax/tests/test_data/test_iterhybrid_rampup.py b/torax/tests/test_data/test_iterhybrid_rampup.py index 8591caab..ce636ece 100644 --- a/torax/tests/test_data/test_iterhybrid_rampup.py +++ b/torax/tests/test_data/test_iterhybrid_rampup.py @@ -18,10 +18,13 @@ With Newton-Raphson stepper and adaptive timestep (backtracking) """ +import dataclasses from torax import config as config_lib from torax import geometry from torax import sim as sim_lib -from torax.sources import source_config +from torax.sources import default_sources +from torax.sources import runtime_params as source_runtime_params +from torax.sources import source_models as source_models_lib from torax.stepper import nonlinear_theta_method from torax.time_step_calculator import fixed_time_step_calculator from torax.transport_model import qlknn_wrapper @@ -70,10 +73,7 @@ def get_config() -> config_lib.Config: # 1/multiplication factor for sigma (conductivity) to reduce current # diffusion timescale to be closer to heat diffusion timescale. resistivity_mult=1, - # multiplier for ion-electron heat exchange term for sensitivity - Qei_mult=1, # Multiplication factor for bootstrap current - bootstrap_mult=1, # numerical (e.g. no. of grid points, other info needed by solver) nr=25, # radial grid points ion_heat_eq=True, @@ -92,40 +92,6 @@ def get_config() -> config_lib.Config: # condtion location if n != neped largeValue_n=1.0e8, ), - # external heat source parameters - w=0.07280908366127758, # Gaussian width in normalized radial coordinate - rsource=0.12741589640723575, # Source Gauss peak in normalized r - Ptot=20.0e6, # total heating - el_heat_fraction=1.0, # electron heating fraction - # particle source parameters - # pellets behave like a gas puff for this simulation with exponential - # decay therefore use the "puff" structure for pellets - # exponential decay length of gas puff ionization (normalized radial - # coordinate) - puff_decay_length=0.3, - S_puff_tot=0.0e21, # total pellet particles/s - # Gaussian width of pellet deposition (normalized radial coordinate) in - # continuous pellet model - pellet_width=0.1, - # Pellet source Gaussian central location (normalized radial coordinate) - # in continuous pellet model - pellet_deposition_location=0.85, - # total pellet particles/s (continuous pellet model) - S_pellet_tot=0.0e22, - # NBI particle source Gaussian width (normalized radial coordinate) - nbi_particle_width=0.25, - # NBI particle source Gaussian central location (normalized radial - # coordinate) - nbi_deposition_location=0.3, - S_nbi_tot=0.0e20, # NBI total particle source - # external current profiles - fext=0.15, # total "external" current fraction - # width of "external" Gaussian current profile (normalized radial - # coordinate) - wext=0.075, - # radius of "external" Gaussian current profile (normalized radial - # coordinate) - rext=0.36, solver=config_lib.SolverConfig( predictor_corrector=True, corrector_steps=10, @@ -137,15 +103,6 @@ def get_config() -> config_lib.Config: use_pereverzev=True, log_iterations=True, ), - sources=dict( - fusion_heat_source=source_config.SourceConfig( - # incorporate fusion heating source in calculation. - source_type=source_config.SourceType.MODEL_BASED, - ), - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), ) @@ -202,6 +159,75 @@ def get_transport_model() -> qlknn_wrapper.QLKNNTransportModel: ) +def get_sources() -> source_models_lib.SourceModels: + """Returns the source models used in the simulation.""" + source_models = default_sources.get_default_sources() + # multiplier for ion-electron heat exchange term for sensitivity + source_models.qei_source.runtime_params.Qei_mult = 1.0 + source_models.j_bootstrap.runtime_params.bootstrap_mult = 1.0 + source_models.jext.runtime_params = dataclasses.replace( + source_models.jext.runtime_params, + # total "external" current fraction + fext=0.15, + # width of "external" Gaussian current profile (normalized radial + # coordinate) + wext=0.075, + # radius of "external" Gaussian current profile (normalized radial + # coordinate) + rext=0.36, + ) + # pytype: disable=unexpected-keyword-arg + # pylint: disable=unexpected-keyword-arg + source_models.sources['generic_ion_el_heat_source'].runtime_params = ( + dataclasses.replace( + source_models.sources['generic_ion_el_heat_source'].runtime_params, + rsource=0.12741589640723575, + # Gaussian width in normalized radial coordinate r + w=0.07280908366127758, + # total heating (including accounting for radiation) r + Ptot=20.0e6, + # electron heating fraction r + el_heat_fraction=1.0, + ) + ) + source_models.sources['gas_puff_source'].runtime_params = dataclasses.replace( + source_models.sources['gas_puff_source'].runtime_params, + # pellets behave like a gas puff for this simulation with exponential + # decay therefore use the puff structure for pellets exponential decay + # length of gas puff ionization (normalized radial coordinate) + puff_decay_length=0.3, + # total pellet particles/s + S_puff_tot=0.0, + ) + source_models.sources['pellet_source'].runtime_params = dataclasses.replace( + source_models.sources['pellet_source'].runtime_params, + # total pellet particles/s (continuous pellet model) + S_pellet_tot=0.0e22, + # Gaussian width of pellet deposition (normalized radial coordinate) in + # continuous pellet model + pellet_width=0.1, + # Pellet source Gaussian central location (normalized radial coordinate) + # in continuous pellet model. + pellet_deposition_location=0.85, + ) + source_models.sources['nbi_particle_source'].runtime_params = ( + dataclasses.replace( + source_models.sources['nbi_particle_source'].runtime_params, + # NBI total particle source + S_nbi_tot=0.0, + # NBI particle source Gaussian central location (normalized radial + # coordinate) + nbi_deposition_location=0.3, + # NBI particle source Gaussian width (normalized radial coordinate) + nbi_particle_width=0.25, + ) + ) + source_models.sources['ohmic_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + return source_models + + def get_sim() -> sim_lib.Sim: config = get_config() geo = get_geometry(config) @@ -209,6 +235,7 @@ def get_sim() -> sim_lib.Sim: config=config, geo=geo, stepper_builder=nonlinear_theta_method.NewtonRaphsonThetaMethod, + source_models=get_sources(), transport_model=get_transport_model(), time_step_calculator=fixed_time_step_calculator.FixedTimeStepCalculator(), ) diff --git a/torax/tests/test_data/test_ne_qlknn_deff_veff.py b/torax/tests/test_data/test_ne_qlknn_deff_veff.py index 264c7e17..ae26d42d 100644 --- a/torax/tests/test_data/test_ne_qlknn_deff_veff.py +++ b/torax/tests/test_data/test_ne_qlknn_deff_veff.py @@ -22,7 +22,9 @@ from torax import config as config_lib from torax import geometry from torax import sim as sim_lib -from torax.sources import source_config +from torax.sources import default_sources +from torax.sources import runtime_params as source_runtime_params +from torax.sources import source_models as source_models_lib from torax.stepper import linear_theta_method from torax.transport_model import qlknn_wrapper @@ -36,34 +38,20 @@ def get_config() -> config_lib.Config: neped=1.0, ), numerics=config_lib.Numerics( - Qei_mult=1, ion_heat_eq=True, el_heat_eq=True, dens_eq=True, current_eq=True, resistivity_mult=100, # to shorten current diffusion time - bootstrap_mult=1, # remove bootstrap current t_final=2, ), # set flat Ohmic current to provide larger range of current evolution for # test nu=0, - w=0.18202270915319393, - S_pellet_tot=1.0e22, - S_puff_tot=0.5e22, - S_nbi_tot=0.3e22, solver=config_lib.SolverConfig( predictor_corrector=False, use_pereverzev=True, ), - sources=dict( - fusion_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), ) @@ -83,6 +71,31 @@ def get_transport_model() -> qlknn_wrapper.QLKNNTransportModel: ) +def get_sources() -> source_models_lib.SourceModels: + """Returns the source models used in the simulation.""" + source_models = default_sources.get_default_sources() + # multiplier for ion-electron heat exchange term for sensitivity + source_models.qei_source.runtime_params.Qei_mult = 1.0 + source_models.j_bootstrap.runtime_params.bootstrap_mult = 1.0 + # total pellet particles/s (continuous pellet model) + source_models.sources["pellet_source"].runtime_params.S_pellet_tot = 1.0e22 + # Gaussian width in normalized radial coordinate r + source_models.sources["generic_ion_el_heat_source"].runtime_params.w = ( + 0.18202270915319393 + ) + # total pellet particles/s + source_models.sources["gas_puff_source"].runtime_params.S_puff_tot = 0.5e22 + # NBI total particle source + source_models.sources["nbi_particle_source"].runtime_params.S_nbi_tot = 0.3e22 + source_models.sources["fusion_heat_source"].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + source_models.sources["ohmic_heat_source"].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + return source_models + + def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most @@ -93,5 +106,6 @@ def get_sim() -> sim_lib.Sim: config=config, geo=geo, stepper_builder=linear_theta_method.LinearThetaMethod, + source_models=get_sources(), transport_model=get_transport_model(), ) diff --git a/torax/tests/test_data/test_ne_qlknn_defromchie.py b/torax/tests/test_data/test_ne_qlknn_defromchie.py index 9a4590e8..8ad0dcd4 100644 --- a/torax/tests/test_data/test_ne_qlknn_defromchie.py +++ b/torax/tests/test_data/test_ne_qlknn_defromchie.py @@ -22,7 +22,9 @@ from torax import config as config_lib from torax import geometry from torax import sim as sim_lib -from torax.sources import source_config +from torax.sources import default_sources +from torax.sources import runtime_params as source_runtime_params +from torax.sources import source_models as source_models_lib from torax.stepper import linear_theta_method from torax.transport_model import qlknn_wrapper @@ -36,34 +38,20 @@ def get_config() -> config_lib.Config: neped=1.0, ), numerics=config_lib.Numerics( - Qei_mult=1, ion_heat_eq=True, el_heat_eq=True, dens_eq=True, current_eq=True, resistivity_mult=100, # to shorten current diffusion time - bootstrap_mult=1, # remove bootstrap current t_final=2, ), # set flat Ohmic current to provide larger range of current evolution for # test nu=0, - w=0.18202270915319393, - S_pellet_tot=1.0e22, - S_puff_tot=0.5e22, - S_nbi_tot=0.3e22, solver=config_lib.SolverConfig( predictor_corrector=False, use_pereverzev=True, ), - sources=dict( - fusion_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), ) @@ -83,6 +71,31 @@ def get_transport_model() -> qlknn_wrapper.QLKNNTransportModel: ) +def get_sources() -> source_models_lib.SourceModels: + """Returns the source models used in the simulation.""" + source_models = default_sources.get_default_sources() + # multiplier for ion-electron heat exchange term for sensitivity + source_models.qei_source.runtime_params.Qei_mult = 1.0 + source_models.j_bootstrap.runtime_params.bootstrap_mult = 1.0 + # total pellet particles/s (continuous pellet model) + source_models.sources["pellet_source"].runtime_params.S_pellet_tot = 1.0e22 + # Gaussian width in normalized radial coordinate r + source_models.sources["generic_ion_el_heat_source"].runtime_params.w = ( + 0.18202270915319393 + ) + # total pellet particles/s + source_models.sources["gas_puff_source"].runtime_params.S_puff_tot = 0.5e22 + # NBI total particle source + source_models.sources["nbi_particle_source"].runtime_params.S_nbi_tot = 0.3e22 + source_models.sources["fusion_heat_source"].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + source_models.sources["ohmic_heat_source"].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + return source_models + + def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most @@ -93,5 +106,6 @@ def get_sim() -> sim_lib.Sim: config=config, geo=geo, stepper_builder=linear_theta_method.LinearThetaMethod, + source_models=get_sources(), transport_model=get_transport_model(), ) diff --git a/torax/tests/test_data/test_newton_raphson_zeroiter.py b/torax/tests/test_data/test_newton_raphson_zeroiter.py index 4837bac7..47896bae 100644 --- a/torax/tests/test_data/test_newton_raphson_zeroiter.py +++ b/torax/tests/test_data/test_newton_raphson_zeroiter.py @@ -29,7 +29,8 @@ from torax import geometry from torax import sim as sim_lib from torax.fvm import enums -from torax.sources import source_config +from torax.sources import default_sources +from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib from torax.stepper import nonlinear_theta_method from torax.stepper import stepper as stepper_lib @@ -79,7 +80,6 @@ def get_config() -> config_lib.Config: # to shorten current diffusion time for the test resistivity_mult=100, t_final=2, - bootstrap_mult=0, # remove bootstrap current ), # set flat Ohmic current to provide larger range of current evolution for # test @@ -88,14 +88,6 @@ def get_config() -> config_lib.Config: predictor_corrector=False, use_pereverzev=True, ), - sources=dict( - fusion_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), ) @@ -107,6 +99,20 @@ def get_transport_model() -> qlknn_wrapper.QLKNNTransportModel: return qlknn_wrapper.QLKNNTransportModel() +def get_sources() -> source_models_lib.SourceModels: + """Returns the source models used in the simulation.""" + source_models = default_sources.get_default_sources() + # remove bootstrap current + source_models.j_bootstrap.runtime_params.bootstrap_mult = 0.0 + source_models.sources['fusion_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + source_models.sources['ohmic_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + return source_models + + def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most @@ -117,5 +123,6 @@ def get_sim() -> sim_lib.Sim: config=config, geo=geo, stepper_builder=make_linear_newton_raphson_stepper, + source_models=get_sources(), transport_model=get_transport_model(), ) diff --git a/torax/tests/test_data/test_ohmic_power.py b/torax/tests/test_data/test_ohmic_power.py index 8478e04f..42994025 100644 --- a/torax/tests/test_data/test_ohmic_power.py +++ b/torax/tests/test_data/test_ohmic_power.py @@ -20,7 +20,9 @@ from torax import config as config_lib from torax import geometry from torax import sim as sim_lib -from torax.sources import source_config +from torax.sources import default_sources +from torax.sources import runtime_params as source_runtime_params +from torax.sources import source_models as source_models_lib from torax.stepper import linear_theta_method from torax.transport_model import constant as constant_transport_model @@ -35,20 +37,10 @@ def get_config() -> config_lib.Config: numerics=config_lib.Numerics( t_final=1, resistivity_mult=100, - bootstrap_mult=0, # remove bootstrap current ), - w=0.18202270915319393, - S_pellet_tot=0, - S_puff_tot=0, - S_nbi_tot=0, solver=config_lib.SolverConfig( predictor_corrector=False, ), - sources=dict( - fusion_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), ) @@ -64,6 +56,27 @@ def get_transport_model() -> constant_transport_model.ConstantTransportModel: return constant_transport_model.ConstantTransportModel() +def get_sources() -> source_models_lib.SourceModels: + """Returns the source models used in the simulation.""" + source_models = default_sources.get_default_sources() + # remove bootstrap current + source_models.j_bootstrap.runtime_params.bootstrap_mult = 0.0 + # total pellet particles/s (continuous pellet model) + source_models.sources["pellet_source"].runtime_params.S_pellet_tot = 0.0 + # Gaussian width in normalized radial coordinate r + source_models.sources["generic_ion_el_heat_source"].runtime_params.w = ( + 0.18202270915319393 + ) + # total pellet particles/s + source_models.sources["gas_puff_source"].runtime_params.S_puff_tot = 0 + # NBI total particle source + source_models.sources["nbi_particle_source"].runtime_params.S_nbi_tot = 0.0 + source_models.sources["fusion_heat_source"].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + return source_models + + def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most @@ -74,5 +87,6 @@ def get_sim() -> sim_lib.Sim: config=config, geo=geo, stepper_builder=linear_theta_method.LinearThetaMethod, + source_models=get_sources(), transport_model=get_transport_model(), ) diff --git a/torax/tests/test_data/test_optimizer_zeroiter.py b/torax/tests/test_data/test_optimizer_zeroiter.py index 2ed8471e..446cd17c 100644 --- a/torax/tests/test_data/test_optimizer_zeroiter.py +++ b/torax/tests/test_data/test_optimizer_zeroiter.py @@ -23,7 +23,8 @@ from torax import geometry from torax import sim as sim_lib from torax.fvm import enums -from torax.sources import source_config +from torax.sources import default_sources +from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib from torax.stepper import nonlinear_theta_method from torax.stepper import stepper as stepper_lib @@ -72,7 +73,6 @@ def get_config() -> config_lib.Config: # to shorten current diffusion time for the test resistivity_mult=100, t_final=2, - bootstrap_mult=0, # remove bootstrap current ), # set flat Ohmic current to provide larger range of current evolution for # test @@ -81,14 +81,6 @@ def get_config() -> config_lib.Config: predictor_corrector=False, use_pereverzev=True, ), - sources=dict( - fusion_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), ) @@ -100,6 +92,20 @@ def get_transport_model() -> qlknn_wrapper.QLKNNTransportModel: return qlknn_wrapper.QLKNNTransportModel() +def get_sources() -> source_models_lib.SourceModels: + """Returns the source models used in the simulation.""" + source_models = default_sources.get_default_sources() + # remove bootstrap current + source_models.j_bootstrap.runtime_params.bootstrap_mult = 0.0 + source_models.sources['fusion_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + source_models.sources['ohmic_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + return source_models + + def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most @@ -110,5 +116,6 @@ def get_sim() -> sim_lib.Sim: config=config, geo=geo, stepper_builder=make_linear_optimizer_stepper, + source_models=get_sources(), transport_model=get_transport_model(), ) diff --git a/torax/tests/test_data/test_particle_sources_cgm.py b/torax/tests/test_data/test_particle_sources_cgm.py index cfd7b00e..3a284d8d 100644 --- a/torax/tests/test_data/test_particle_sources_cgm.py +++ b/torax/tests/test_data/test_particle_sources_cgm.py @@ -20,7 +20,9 @@ from torax import config as config_lib from torax import geometry from torax import sim as sim_lib -from torax.sources import source_config +from torax.sources import default_sources +from torax.sources import runtime_params as source_runtime_params +from torax.sources import source_models as source_models_lib from torax.stepper import linear_theta_method from torax.transport_model import critical_gradient as cgm_transport_model @@ -34,34 +36,21 @@ def get_config() -> config_lib.Config: neped=1.0, ), numerics=config_lib.Numerics( - Qei_mult=1, ion_heat_eq=True, el_heat_eq=True, dens_eq=True, current_eq=True, resistivity_mult=100, # to shorten current diffusion time - bootstrap_mult=1, # remove bootstrap current t_final=2, ), # set flat Ohmic current to provide larger range of current evolution for # test nu=0, - S_pellet_tot=1.0e22, - S_puff_tot=0.5e22, - S_nbi_tot=0.3e22, solver=config_lib.SolverConfig( predictor_corrector=False, use_pereverzev=True, d_per=0.0, ), - sources=dict( - fusion_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), ) @@ -78,6 +67,27 @@ def get_transport_model() -> cgm_transport_model.CriticalGradientModel: ) +def get_sources() -> source_models_lib.SourceModels: + """Returns the source models used in the simulation.""" + source_models = default_sources.get_default_sources() + # multiplier for ion-electron heat exchange term for sensitivity + source_models.qei_source.runtime_params.Qei_mult = 1.0 + source_models.j_bootstrap.runtime_params.bootstrap_mult = 1.0 + # total pellet particles/s (continuous pellet model) + source_models.sources['pellet_source'].runtime_params.S_pellet_tot = 1.0e22 + # total pellet particles/s + source_models.sources['gas_puff_source'].runtime_params.S_puff_tot = 0.5e22 + # NBI total particle source + source_models.sources['nbi_particle_source'].runtime_params.S_nbi_tot = 0.3e22 + source_models.sources['fusion_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + source_models.sources['ohmic_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + return source_models + + def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most @@ -88,5 +98,6 @@ def get_sim() -> sim_lib.Sim: config=config, geo=geo, stepper_builder=linear_theta_method.LinearThetaMethod, + source_models=get_sources(), transport_model=get_transport_model(), ) diff --git a/torax/tests/test_data/test_particle_sources_constant.py b/torax/tests/test_data/test_particle_sources_constant.py index 2fb5d702..97456bb9 100644 --- a/torax/tests/test_data/test_particle_sources_constant.py +++ b/torax/tests/test_data/test_particle_sources_constant.py @@ -20,7 +20,9 @@ from torax import config as config_lib from torax import geometry from torax import sim as sim_lib -from torax.sources import source_config +from torax.sources import default_sources +from torax.sources import runtime_params as source_runtime_params +from torax.sources import source_models as source_models_lib from torax.stepper import linear_theta_method from torax.transport_model import constant as constant_transport_model @@ -32,33 +34,20 @@ def get_config() -> config_lib.Config: nbar=0.85, # initial density (Greenwald fraction units) ), numerics=config_lib.Numerics( - Qei_mult=1, ion_heat_eq=True, el_heat_eq=True, dens_eq=True, current_eq=True, # to shorten current diffusion time for the test resistivity_mult=100, - bootstrap_mult=1, # remove bootstrap current t_final=2, ), # set flat Ohmic current to provide larger range of current evolution for # test nu=0, - S_pellet_tot=2.0e22, - S_puff_tot=1.0e22, - S_nbi_tot=0.0, solver=config_lib.SolverConfig( predictor_corrector=False, ), - sources=dict( - fusion_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), ) @@ -77,6 +66,27 @@ def get_transport_model() -> constant_transport_model.ConstantTransportModel: ) +def get_sources() -> source_models_lib.SourceModels: + """Returns the source models used in the simulation.""" + source_models = default_sources.get_default_sources() + # multiplier for ion-electron heat exchange term for sensitivity + source_models.qei_source.runtime_params.Qei_mult = 1.0 + source_models.j_bootstrap.runtime_params.bootstrap_mult = 1.0 + # total pellet particles/s (continuous pellet model) + source_models.sources['pellet_source'].runtime_params.S_pellet_tot = 2.0e22 + # total pellet particles/s + source_models.sources['gas_puff_source'].runtime_params.S_puff_tot = 1.0e22 + # NBI total particle source + source_models.sources['nbi_particle_source'].runtime_params.S_nbi_tot = 0.0 + source_models.sources['fusion_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + source_models.sources['ohmic_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + return source_models + + def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most @@ -87,5 +97,6 @@ def get_sim() -> sim_lib.Sim: config=config, geo=geo, stepper_builder=linear_theta_method.LinearThetaMethod, + source_models=get_sources(), transport_model=get_transport_model(), ) diff --git a/torax/tests/test_data/test_pc_method_ne.py b/torax/tests/test_data/test_pc_method_ne.py index 615a0435..2cb315ef 100644 --- a/torax/tests/test_data/test_pc_method_ne.py +++ b/torax/tests/test_data/test_pc_method_ne.py @@ -19,10 +19,13 @@ scaled from chi_e """ +import dataclasses from torax import config as config_lib from torax import geometry from torax import sim as sim_lib -from torax.sources import source_config +from torax.sources import default_sources +from torax.sources import runtime_params as source_runtime_params +from torax.sources import source_models as source_models_lib from torax.stepper import linear_theta_method from torax.transport_model import qlknn_wrapper @@ -38,33 +41,21 @@ def get_config() -> config_lib.Config: neped=1.0, ), numerics=config_lib.Numerics( - Qei_mult=1, ion_heat_eq=True, el_heat_eq=True, dens_eq=True, current_eq=True, # to shorten current diffusion time for the test resistivity_mult=100, - bootstrap_mult=1, # remove bootstrap current t_final=2.0, ), # set flat Ohmic current to provide larger range of current evolution for # test nu=0, - w=0.18202270915319393, - S_pellet_tot=1.0e22, - S_puff_tot=0.5e22, - S_nbi_tot=0.3e22, - Ptot=53.0e6, # total external heating solver=config_lib.SolverConfig( predictor_corrector=False, use_pereverzev=True, ), - sources=dict( - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), ) @@ -84,6 +75,35 @@ def get_transport_model() -> qlknn_wrapper.QLKNNTransportModel: ) +def get_sources() -> source_models_lib.SourceModels: + """Returns the source models used in the simulation.""" + source_models = default_sources.get_default_sources() + # multiplier for ion-electron heat exchange term for sensitivity + source_models.qei_source.runtime_params.Qei_mult = 1.0 + source_models.j_bootstrap.runtime_params.bootstrap_mult = 1.0 + # pylint: disable=unexpected-keyword-arg + source_models.sources["generic_ion_el_heat_source"].runtime_params = ( + dataclasses.replace( + source_models.sources["generic_ion_el_heat_source"].runtime_params, + # Gaussian width in normalized radial coordinate r + w=0.18202270915319393, + # total heating (including accounting for radiation) r + Ptot=53.0e6, + ) + ) + # pylint: enable=unexpected-keyword-arg + # total pellet particles/s (continuous pellet model) + source_models.sources["pellet_source"].runtime_params.S_pellet_tot = 1.0e22 + # total pellet particles/s + source_models.sources["gas_puff_source"].runtime_params.S_puff_tot = 0.5e22 + # NBI total particle source + source_models.sources["nbi_particle_source"].runtime_params.S_nbi_tot = 0.3e22 + source_models.sources["ohmic_heat_source"].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + return source_models + + def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most @@ -94,5 +114,6 @@ def get_sim() -> sim_lib.Sim: config=config, geo=geo, stepper_builder=linear_theta_method.LinearThetaMethod, + source_models=get_sources(), transport_model=get_transport_model(), ) diff --git a/torax/tests/test_data/test_pedestal.py b/torax/tests/test_data/test_pedestal.py index 3ed12fc5..4010212c 100644 --- a/torax/tests/test_data/test_pedestal.py +++ b/torax/tests/test_data/test_pedestal.py @@ -20,7 +20,9 @@ from torax import config as config_lib from torax import geometry from torax import sim as sim_lib -from torax.sources import source_config +from torax.sources import default_sources +from torax.sources import runtime_params as source_runtime_params +from torax.sources import source_models as source_models_lib from torax.stepper import linear_theta_method from torax.transport_model import constant as constant_transport_model @@ -29,19 +31,10 @@ def get_config() -> config_lib.Config: return config_lib.Config( numerics=config_lib.Numerics( t_final=1, - bootstrap_mult=0, # remove bootstrap current ), solver=config_lib.SolverConfig( predictor_corrector=False, ), - sources=dict( - fusion_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), ) @@ -53,6 +46,20 @@ def get_transport_model() -> constant_transport_model.ConstantTransportModel: return constant_transport_model.ConstantTransportModel() +def get_sources() -> source_models_lib.SourceModels: + """Returns the source models used in the simulation.""" + source_models = default_sources.get_default_sources() + # remove bootstrap current + source_models.j_bootstrap.runtime_params.bootstrap_mult = 0.0 + source_models.sources['fusion_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + source_models.sources['ohmic_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + return source_models + + def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most @@ -63,5 +70,6 @@ def get_sim() -> sim_lib.Sim: config=config, geo=geo, stepper_builder=linear_theta_method.LinearThetaMethod, + source_models=get_sources(), transport_model=get_transport_model(), ) diff --git a/torax/tests/test_data/test_prescribed_timedependent_ne.py b/torax/tests/test_data/test_prescribed_timedependent_ne.py index 30f915c4..ca8430d6 100644 --- a/torax/tests/test_data/test_prescribed_timedependent_ne.py +++ b/torax/tests/test_data/test_prescribed_timedependent_ne.py @@ -19,10 +19,13 @@ pedestal, mocking up current-overshoot and an LH transition """ +import dataclasses from torax import config as config_lib from torax import geometry from torax import sim as sim_lib -from torax.sources import source_config +from torax.sources import default_sources +from torax.sources import runtime_params as source_runtime_params +from torax.sources import source_models as source_models_lib from torax.stepper import linear_theta_method from torax.transport_model import qlknn_wrapper @@ -39,29 +42,15 @@ def get_config() -> config_lib.Config: numerics=config_lib.Numerics( current_eq=True, resistivity_mult=50, # to shorten current diffusion time for the test - bootstrap_mult=0, # remove bootstrap current dtmult=150, maxdt=0.5, t_final=10, enable_prescribed_profile_evolution=True, ), - w=0.18202270915319393, - S_pellet_tot=0, - S_puff_tot=0, - S_nbi_tot=0, - Ptot={0: 20e6, 9: 20e6, 10: 120e6, 15: 120e6}, solver=config_lib.SolverConfig( predictor_corrector=False, use_pereverzev=True, ), - sources=dict( - fusion_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), ) @@ -84,6 +73,41 @@ def get_transport_model() -> qlknn_wrapper.QLKNNTransportModel: ) +def get_sources() -> source_models_lib.SourceModels: + """Returns the source models used in the simulation.""" + source_models = default_sources.get_default_sources() + # remove bootstrap current + source_models.j_bootstrap.runtime_params.bootstrap_mult = 0.0 + # pylint: disable=unexpected-keyword-arg + source_models.sources["generic_ion_el_heat_source"].runtime_params = ( + dataclasses.replace( + source_models.sources["generic_ion_el_heat_source"].runtime_params, + # Gaussian width in normalized radial coordinate r + w=0.18202270915319393, + # total heating (including accounting for radiation) r + Ptot={ + 0: 20e6, + 9: 20e6, + 10: 120e6, + 15: 120e6, + }, # in W + ) + ) + # total pellet particles/s (continuous pellet model) + source_models.sources["pellet_source"].runtime_params.S_pellet_tot = 0.0 + # total pellet particles/s + source_models.sources["gas_puff_source"].runtime_params.S_puff_tot = 0 + # NBI total particle source + source_models.sources["nbi_particle_source"].runtime_params.S_nbi_tot = 0.0 + source_models.sources["fusion_heat_source"].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + source_models.sources["ohmic_heat_source"].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + return source_models + + def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most @@ -94,5 +118,6 @@ def get_sim() -> sim_lib.Sim: config=config, geo=geo, stepper_builder=linear_theta_method.LinearThetaMethod, + source_models=get_sources(), transport_model=get_transport_model(), ) diff --git a/torax/tests/test_data/test_psi_and_heat.py b/torax/tests/test_data/test_psi_and_heat.py index 1fec834a..b029ffcc 100644 --- a/torax/tests/test_data/test_psi_and_heat.py +++ b/torax/tests/test_data/test_psi_and_heat.py @@ -21,7 +21,9 @@ from torax import config as config_lib from torax import geometry from torax import sim as sim_lib -from torax.sources import source_config +from torax.sources import default_sources +from torax.sources import runtime_params as source_runtime_params +from torax.sources import source_models as source_models_lib from torax.stepper import linear_theta_method from torax.transport_model import qlknn_wrapper @@ -36,7 +38,6 @@ def get_config() -> config_lib.Config: current_eq=True, resistivity_mult=100, # to shorten current diffusion time t_final=2, - bootstrap_mult=0, # remove bootstrap current ), # set flat Ohmic current to provide larger range of current evolution for # test @@ -45,14 +46,6 @@ def get_config() -> config_lib.Config: predictor_corrector=False, use_pereverzev=True, ), - sources=dict( - fusion_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), ) @@ -64,6 +57,20 @@ def get_transport_model() -> qlknn_wrapper.QLKNNTransportModel: return qlknn_wrapper.QLKNNTransportModel() +def get_sources() -> source_models_lib.SourceModels: + """Returns the source models used in the simulation.""" + source_models = default_sources.get_default_sources() + # remove bootstrap current + source_models.j_bootstrap.runtime_params.bootstrap_mult = 0.0 + source_models.sources['fusion_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + source_models.sources['ohmic_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + return source_models + + def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most @@ -74,5 +81,6 @@ def get_sim() -> sim_lib.Sim: config=config, geo=geo, stepper_builder=linear_theta_method.LinearThetaMethod, + source_models=get_sources(), transport_model=get_transport_model(), ) diff --git a/torax/tests/test_data/test_psi_heat_dens.py b/torax/tests/test_data/test_psi_heat_dens.py index e6853fa1..a238edb3 100644 --- a/torax/tests/test_data/test_psi_heat_dens.py +++ b/torax/tests/test_data/test_psi_heat_dens.py @@ -20,7 +20,9 @@ from torax import config as config_lib from torax import geometry from torax import sim as sim_lib -from torax.sources import source_config +from torax.sources import default_sources +from torax.sources import runtime_params as source_runtime_params +from torax.sources import source_models as source_models_lib from torax.stepper import linear_theta_method from torax.transport_model import constant as constant_transport_model @@ -32,32 +34,19 @@ def get_config() -> config_lib.Config: nbar=0.85, # initial density (in Greenwald fraction units) ), numerics=config_lib.Numerics( - Qei_mult=1, ion_heat_eq=True, el_heat_eq=True, dens_eq=True, current_eq=True, resistivity_mult=100, # to shorten current diffusion time - bootstrap_mult=1, # remove bootstrap current t_final=2, ), # set flat Ohmic current to provide larger range of current evolution for # test nu=0, - S_pellet_tot=0.0, - S_puff_tot=0.0, - S_nbi_tot=0.0, solver=config_lib.SolverConfig( predictor_corrector=False, ), - sources=dict( - fusion_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), ) @@ -76,6 +65,27 @@ def get_transport_model() -> constant_transport_model.ConstantTransportModel: ) +def get_sources() -> source_models_lib.SourceModels: + """Returns the source models used in the simulation.""" + source_models = default_sources.get_default_sources() + # multiplier for ion-electron heat exchange term for sensitivity + source_models.qei_source.runtime_params.Qei_mult = 1.0 + source_models.j_bootstrap.runtime_params.bootstrap_mult = 1.0 + # total pellet particles/s (continuous pellet model) + source_models.sources['pellet_source'].runtime_params.S_pellet_tot = 0.0 + # total pellet particles/s + source_models.sources['gas_puff_source'].runtime_params.S_puff_tot = 0.0 + # NBI total particle source + source_models.sources['nbi_particle_source'].runtime_params.S_nbi_tot = 0.0 + source_models.sources['fusion_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + source_models.sources['ohmic_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + return source_models + + def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most @@ -86,5 +96,6 @@ def get_sim() -> sim_lib.Sim: config=config, geo=geo, stepper_builder=linear_theta_method.LinearThetaMethod, + source_models=get_sources(), transport_model=get_transport_model(), ) diff --git a/torax/tests/test_data/test_psichease_ip_chease.py b/torax/tests/test_data/test_psichease_ip_chease.py index 2c2161e0..e38ace80 100644 --- a/torax/tests/test_data/test_psichease_ip_chease.py +++ b/torax/tests/test_data/test_psichease_ip_chease.py @@ -20,7 +20,9 @@ from torax import config as config_lib from torax import geometry from torax import sim as sim_lib -from torax.sources import source_config +from torax.sources import default_sources +from torax.sources import runtime_params as source_runtime_params +from torax.sources import source_models as source_models_lib from torax.stepper import linear_theta_method from torax.transport_model import constant as constant_transport_model @@ -31,29 +33,15 @@ def get_config() -> config_lib.Config: set_pedestal=False, ), numerics=config_lib.Numerics( - Qei_mult=0, ion_heat_eq=False, el_heat_eq=False, current_eq=True, resistivity_mult=100, # to shorten current diffusion time - bootstrap_mult=0, # remove bootstrap current t_final=3, ), - w=0.18202270915319393, - S_pellet_tot=0, - S_puff_tot=0, - S_nbi_tot=0, solver=config_lib.SolverConfig( predictor_corrector=False, ), - sources=dict( - fusion_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), ) @@ -69,6 +57,32 @@ def get_transport_model() -> constant_transport_model.ConstantTransportModel: return constant_transport_model.ConstantTransportModel() +def get_sources() -> source_models_lib.SourceModels: + """Returns the source models used in the simulation.""" + source_models = default_sources.get_default_sources() + # multiplier for ion-electron heat exchange term for sensitivity + source_models.qei_source.runtime_params.Qei_mult = 0.0 + # remove bootstrap current + source_models.j_bootstrap.runtime_params.bootstrap_mult = 0.0 + # total pellet particles/s (continuous pellet model) + source_models.sources["pellet_source"].runtime_params.S_pellet_tot = 0.0 + # Gaussian width in normalized radial coordinate r + source_models.sources["generic_ion_el_heat_source"].runtime_params.w = ( + 0.18202270915319393 + ) + # total pellet particles/s + source_models.sources["gas_puff_source"].runtime_params.S_puff_tot = 0 + # NBI total particle source + source_models.sources["nbi_particle_source"].runtime_params.S_nbi_tot = 0.0 + source_models.sources["fusion_heat_source"].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + source_models.sources["ohmic_heat_source"].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + return source_models + + def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most @@ -79,5 +93,6 @@ def get_sim() -> sim_lib.Sim: config=config, geo=geo, stepper_builder=linear_theta_method.LinearThetaMethod, + source_models=get_sources(), transport_model=get_transport_model(), ) diff --git a/torax/tests/test_data/test_psichease_ip_parameters.py b/torax/tests/test_data/test_psichease_ip_parameters.py index 6bbc092f..296d8f27 100644 --- a/torax/tests/test_data/test_psichease_ip_parameters.py +++ b/torax/tests/test_data/test_psichease_ip_parameters.py @@ -20,7 +20,9 @@ from torax import config as config_lib from torax import geometry from torax import sim as sim_lib -from torax.sources import source_config +from torax.sources import default_sources +from torax.sources import runtime_params as source_runtime_params +from torax.sources import source_models as source_models_lib from torax.stepper import linear_theta_method from torax.transport_model import constant as constant_transport_model @@ -31,29 +33,15 @@ def get_config() -> config_lib.Config: set_pedestal=False, ), numerics=config_lib.Numerics( - Qei_mult=0, ion_heat_eq=False, el_heat_eq=False, current_eq=True, resistivity_mult=100, # to shorten current diffusion time - bootstrap_mult=0, # remove bootstrap current t_final=3, ), - w=0.18202270915319393, - S_pellet_tot=0, - S_puff_tot=0, - S_nbi_tot=0, solver=config_lib.SolverConfig( predictor_corrector=False, ), - sources=dict( - fusion_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), ) @@ -69,6 +57,32 @@ def get_transport_model() -> constant_transport_model.ConstantTransportModel: return constant_transport_model.ConstantTransportModel() +def get_sources() -> source_models_lib.SourceModels: + """Returns the source models used in the simulation.""" + source_models = default_sources.get_default_sources() + # multiplier for ion-electron heat exchange term for sensitivity + source_models.qei_source.runtime_params.Qei_mult = 0.0 + # remove bootstrap current + source_models.j_bootstrap.runtime_params.bootstrap_mult = 0.0 + # total pellet particles/s (continuous pellet model) + source_models.sources["pellet_source"].runtime_params.S_pellet_tot = 0.0 + # Gaussian width in normalized radial coordinate r + source_models.sources["generic_ion_el_heat_source"].runtime_params.w = ( + 0.18202270915319393 + ) + # total pellet particles/s + source_models.sources["gas_puff_source"].runtime_params.S_puff_tot = 0 + # NBI total particle source + source_models.sources["nbi_particle_source"].runtime_params.S_nbi_tot = 0.0 + source_models.sources["fusion_heat_source"].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + source_models.sources["ohmic_heat_source"].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + return source_models + + def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most @@ -79,5 +93,6 @@ def get_sim() -> sim_lib.Sim: config=config, geo=geo, stepper_builder=linear_theta_method.LinearThetaMethod, + source_models=get_sources(), transport_model=get_transport_model(), ) diff --git a/torax/tests/test_data/test_psichease_prescribed_johm.py b/torax/tests/test_data/test_psichease_prescribed_johm.py index f27f9bb5..f4757bc0 100644 --- a/torax/tests/test_data/test_psichease_prescribed_johm.py +++ b/torax/tests/test_data/test_psichease_prescribed_johm.py @@ -20,7 +20,9 @@ from torax import config as config_lib from torax import geometry from torax import sim as sim_lib -from torax.sources import source_config +from torax.sources import default_sources +from torax.sources import runtime_params as source_runtime_params +from torax.sources import source_models as source_models_lib from torax.stepper import linear_theta_method from torax.transport_model import constant as constant_transport_model @@ -31,32 +33,18 @@ def get_config() -> config_lib.Config: set_pedestal=False, ), numerics=config_lib.Numerics( - Qei_mult=0, ion_heat_eq=False, el_heat_eq=False, current_eq=True, resistivity_mult=100, # to shorten current diffusion time - bootstrap_mult=0, # remove bootstrap current t_final=3, ), initial_psi_from_j=True, initial_j_is_total_current=False, nu=2, - w=0.18202270915319393, - S_pellet_tot=0, - S_puff_tot=0, - S_nbi_tot=0, solver=config_lib.SolverConfig( predictor_corrector=False, ), - sources=dict( - fusion_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), ) @@ -72,6 +60,32 @@ def get_transport_model() -> constant_transport_model.ConstantTransportModel: return constant_transport_model.ConstantTransportModel() +def get_sources() -> source_models_lib.SourceModels: + """Returns the source models used in the simulation.""" + source_models = default_sources.get_default_sources() + # multiplier for ion-electron heat exchange term for sensitivity + source_models.qei_source.runtime_params.Qei_mult = 0.0 + # remove bootstrap current + source_models.j_bootstrap.runtime_params.bootstrap_mult = 0.0 + # total pellet particles/s (continuous pellet model) + source_models.sources["pellet_source"].runtime_params.S_pellet_tot = 0.0 + # Gaussian width in normalized radial coordinate r + source_models.sources["generic_ion_el_heat_source"].runtime_params.w = ( + 0.18202270915319393 + ) + # total pellet particles/s + source_models.sources["gas_puff_source"].runtime_params.S_puff_tot = 0 + # NBI total particle source + source_models.sources["nbi_particle_source"].runtime_params.S_nbi_tot = 0.0 + source_models.sources["fusion_heat_source"].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + source_models.sources["ohmic_heat_source"].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + return source_models + + def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most @@ -82,5 +96,6 @@ def get_sim() -> sim_lib.Sim: config=config, geo=geo, stepper_builder=linear_theta_method.LinearThetaMethod, + source_models=get_sources(), transport_model=get_transport_model(), ) diff --git a/torax/tests/test_data/test_psichease_prescribed_jtot.py b/torax/tests/test_data/test_psichease_prescribed_jtot.py index a0d32af8..d39737fc 100644 --- a/torax/tests/test_data/test_psichease_prescribed_jtot.py +++ b/torax/tests/test_data/test_psichease_prescribed_jtot.py @@ -20,7 +20,9 @@ from torax import config as config_lib from torax import geometry from torax import sim as sim_lib -from torax.sources import source_config +from torax.sources import default_sources +from torax.sources import runtime_params as source_runtime_params +from torax.sources import source_models as source_models_lib from torax.stepper import linear_theta_method from torax.transport_model import constant as constant_transport_model @@ -31,32 +33,18 @@ def get_config() -> config_lib.Config: set_pedestal=False, ), numerics=config_lib.Numerics( - Qei_mult=0, ion_heat_eq=False, el_heat_eq=False, current_eq=True, resistivity_mult=100, # to shorten current diffusion time - bootstrap_mult=0, # remove bootstrap current t_final=3, ), initial_psi_from_j=True, initial_j_is_total_current=True, nu=2, - w=0.18202270915319393, - S_pellet_tot=0, - S_puff_tot=0, - S_nbi_tot=0, solver=config_lib.SolverConfig( predictor_corrector=False, ), - sources=dict( - fusion_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), ) @@ -72,6 +60,32 @@ def get_transport_model() -> constant_transport_model.ConstantTransportModel: return constant_transport_model.ConstantTransportModel() +def get_sources() -> source_models_lib.SourceModels: + """Returns the source models used in the simulation.""" + source_models = default_sources.get_default_sources() + # multiplier for ion-electron heat exchange term for sensitivity + source_models.qei_source.runtime_params.Qei_mult = 0.0 + # remove bootstrap current + source_models.j_bootstrap.runtime_params.bootstrap_mult = 0.0 + # total pellet particles/s (continuous pellet model) + source_models.sources["pellet_source"].runtime_params.S_pellet_tot = 0.0 + # Gaussian width in normalized radial coordinate r + source_models.sources["generic_ion_el_heat_source"].runtime_params.w = ( + 0.18202270915319393 + ) + # total pellet particles/s + source_models.sources["gas_puff_source"].runtime_params.S_puff_tot = 0 + # NBI total particle source + source_models.sources["nbi_particle_source"].runtime_params.S_nbi_tot = 0.0 + source_models.sources["fusion_heat_source"].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + source_models.sources["ohmic_heat_source"].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + return source_models + + def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most @@ -82,5 +96,6 @@ def get_sim() -> sim_lib.Sim: config=config, geo=geo, stepper_builder=linear_theta_method.LinearThetaMethod, + source_models=get_sources(), transport_model=get_transport_model(), ) diff --git a/torax/tests/test_data/test_psiequation.py b/torax/tests/test_data/test_psiequation.py index 5f724378..7a2de056 100644 --- a/torax/tests/test_data/test_psiequation.py +++ b/torax/tests/test_data/test_psiequation.py @@ -17,7 +17,9 @@ from torax import config as config_lib from torax import geometry from torax import sim as sim_lib -from torax.sources import source_config +from torax.sources import default_sources +from torax.sources import runtime_params as source_runtime_params +from torax.sources import source_models as source_models_lib from torax.stepper import linear_theta_method from torax.transport_model import constant as constant_transport_model @@ -28,13 +30,11 @@ def get_config() -> config_lib.Config: set_pedestal=False, ), numerics=config_lib.Numerics( - Qei_mult=0, ion_heat_eq=False, el_heat_eq=False, current_eq=True, resistivity_mult=100, # to shorten current diffusion time t_final=3, - bootstrap_mult=0, # remove bootstrap current ), # set flat Ohmic current to provide larger range of current evolution for # test @@ -42,14 +42,6 @@ def get_config() -> config_lib.Config: solver=config_lib.SolverConfig( predictor_corrector=False, ), - sources=dict( - fusion_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), ) @@ -61,6 +53,22 @@ def get_transport_model() -> constant_transport_model.ConstantTransportModel: return constant_transport_model.ConstantTransportModel() +def get_sources() -> source_models_lib.SourceModels: + """Returns the source models used in the simulation.""" + source_models = default_sources.get_default_sources() + # multiplier for ion-electron heat exchange term for sensitivity + source_models.qei_source.runtime_params.Qei_mult = 0.0 + # remove bootstrap current + source_models.j_bootstrap.runtime_params.bootstrap_mult = 0.0 + source_models.sources['fusion_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + source_models.sources['ohmic_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + return source_models + + def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most @@ -71,5 +79,6 @@ def get_sim() -> sim_lib.Sim: config=config, geo=geo, stepper_builder=linear_theta_method.LinearThetaMethod, + source_models=get_sources(), transport_model=get_transport_model(), ) diff --git a/torax/tests/test_data/test_qei.py b/torax/tests/test_data/test_qei.py index cd668d2d..736e3ad9 100644 --- a/torax/tests/test_data/test_qei.py +++ b/torax/tests/test_data/test_qei.py @@ -20,7 +20,9 @@ from torax import config as config_lib from torax import geometry from torax import sim as sim_lib -from torax.sources import source_config +from torax.sources import default_sources +from torax.sources import runtime_params as source_runtime_params +from torax.sources import source_models as source_models_lib from torax.stepper import linear_theta_method from torax.transport_model import constant as constant_transport_model @@ -33,19 +35,10 @@ def get_config() -> config_lib.Config: ), numerics=config_lib.Numerics( t_final=1, - bootstrap_mult=0, # remove bootstrap current ), solver=config_lib.SolverConfig( predictor_corrector=False, ), - sources=dict( - fusion_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), ) @@ -57,6 +50,20 @@ def get_transport_model() -> constant_transport_model.ConstantTransportModel: return constant_transport_model.ConstantTransportModel() +def get_sources() -> source_models_lib.SourceModels: + """Returns the source models used in the simulation.""" + source_models = default_sources.get_default_sources() + # remove bootstrap current + source_models.j_bootstrap.runtime_params.bootstrap_mult = 0.0 + source_models.sources['fusion_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + source_models.sources['ohmic_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + return source_models + + def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most @@ -67,5 +74,6 @@ def get_sim() -> sim_lib.Sim: config=config, geo=geo, stepper_builder=linear_theta_method.LinearThetaMethod, + source_models=get_sources(), transport_model=get_transport_model(), ) diff --git a/torax/tests/test_data/test_qei_chease_highdens.py b/torax/tests/test_data/test_qei_chease_highdens.py index b91b7781..a29cac21 100644 --- a/torax/tests/test_data/test_qei_chease_highdens.py +++ b/torax/tests/test_data/test_qei_chease_highdens.py @@ -20,7 +20,9 @@ from torax import config as config_lib from torax import geometry from torax import sim as sim_lib -from torax.sources import source_config +from torax.sources import default_sources +from torax.sources import runtime_params as source_runtime_params +from torax.sources import source_models as source_models_lib from torax.stepper import linear_theta_method from torax.transport_model import constant as constant_transport_model @@ -33,23 +35,10 @@ def get_config() -> config_lib.Config: ), numerics=config_lib.Numerics( t_final=1, - bootstrap_mult=0, # remove bootstrap current ), - w=0.18202270915319393, - S_pellet_tot=0, - S_puff_tot=0, - S_nbi_tot=0, solver=config_lib.SolverConfig( predictor_corrector=False, ), - sources=dict( - fusion_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), ) @@ -65,6 +54,30 @@ def get_transport_model() -> constant_transport_model.ConstantTransportModel: return constant_transport_model.ConstantTransportModel() +def get_sources() -> source_models_lib.SourceModels: + """Returns the source models used in the simulation.""" + source_models = default_sources.get_default_sources() + # remove bootstrap current + source_models.j_bootstrap.runtime_params.bootstrap_mult = 0.0 + # total pellet particles/s (continuous pellet model) + source_models.sources["pellet_source"].runtime_params.S_pellet_tot = 0.0 + # Gaussian width in normalized radial coordinate r + source_models.sources["generic_ion_el_heat_source"].runtime_params.w = ( + 0.18202270915319393 + ) + # total pellet particles/s + source_models.sources["gas_puff_source"].runtime_params.S_puff_tot = 0 + # NBI total particle source + source_models.sources["nbi_particle_source"].runtime_params.S_nbi_tot = 0.0 + source_models.sources["fusion_heat_source"].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + source_models.sources["ohmic_heat_source"].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + return source_models + + def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most @@ -75,5 +88,6 @@ def get_sim() -> sim_lib.Sim: config=config, geo=geo, stepper_builder=linear_theta_method.LinearThetaMethod, + source_models=get_sources(), transport_model=get_transport_model(), ) diff --git a/torax/tests/test_data/test_qlknnheat.py b/torax/tests/test_data/test_qlknnheat.py index 5d4ab78c..5944a316 100644 --- a/torax/tests/test_data/test_qlknnheat.py +++ b/torax/tests/test_data/test_qlknnheat.py @@ -21,7 +21,9 @@ from torax import config as config_lib from torax import geometry from torax import sim as sim_lib -from torax.sources import source_config +from torax.sources import default_sources +from torax.sources import runtime_params as source_runtime_params +from torax.sources import source_models as source_models_lib from torax.stepper import linear_theta_method from torax.transport_model import qlknn_wrapper @@ -30,20 +32,11 @@ def get_config() -> config_lib.Config: return config_lib.Config( numerics=config_lib.Numerics( t_final=2, - bootstrap_mult=0, # remove bootstrap current ), solver=config_lib.SolverConfig( predictor_corrector=False, use_pereverzev=True, ), - sources=dict( - fusion_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), ) @@ -55,6 +48,20 @@ def get_transport_model() -> qlknn_wrapper.QLKNNTransportModel: return qlknn_wrapper.QLKNNTransportModel() +def get_sources() -> source_models_lib.SourceModels: + """Returns the source models used in the simulation.""" + source_models = default_sources.get_default_sources() + # remove bootstrap current + source_models.j_bootstrap.runtime_params.bootstrap_mult = 0.0 + source_models.sources['fusion_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + source_models.sources['ohmic_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + return source_models + + def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most @@ -65,5 +72,6 @@ def get_sim() -> sim_lib.Sim: config=config, geo=geo, stepper_builder=linear_theta_method.LinearThetaMethod, + source_models=get_sources(), transport_model=get_transport_model(), ) diff --git a/torax/tests/test_data/test_semiimplicit_convection.py b/torax/tests/test_data/test_semiimplicit_convection.py index 68954d68..2b378af4 100644 --- a/torax/tests/test_data/test_semiimplicit_convection.py +++ b/torax/tests/test_data/test_semiimplicit_convection.py @@ -22,7 +22,9 @@ from torax import config as config_lib from torax import geometry from torax import sim as sim_lib -from torax.sources import source_config +from torax.sources import default_sources +from torax.sources import runtime_params as source_runtime_params +from torax.sources import source_models as source_models_lib from torax.stepper import linear_theta_method from torax.transport_model import critical_gradient as cgm_transport_model @@ -38,7 +40,6 @@ def get_config() -> config_lib.Config: ), numerics=config_lib.Numerics( t_final=0.5, - bootstrap_mult=0, # remove bootstrap current ), solver=config_lib.SolverConfig( predictor_corrector=False, @@ -48,14 +49,6 @@ def get_config() -> config_lib.Config: convection_neumann_mode='semi-implicit', use_pereverzev=True, ), - sources=dict( - fusion_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), ) @@ -67,6 +60,20 @@ def get_transport_model() -> cgm_transport_model.CriticalGradientModel: return cgm_transport_model.CriticalGradientModel() +def get_sources() -> source_models_lib.SourceModels: + """Returns the source models used in the simulation.""" + source_models = default_sources.get_default_sources() + # remove bootstrap current + source_models.j_bootstrap.runtime_params.bootstrap_mult = 0.0 + source_models.sources['fusion_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + source_models.sources['ohmic_heat_source'].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + return source_models + + def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most @@ -77,5 +84,6 @@ def get_sim() -> sim_lib.Sim: config=config, geo=geo, stepper_builder=linear_theta_method.LinearThetaMethod, + source_models=get_sources(), transport_model=get_transport_model(), ) diff --git a/torax/tests/test_data/test_timedependence.py b/torax/tests/test_data/test_timedependence.py index 64ee9200..44fb36e6 100644 --- a/torax/tests/test_data/test_timedependence.py +++ b/torax/tests/test_data/test_timedependence.py @@ -22,7 +22,9 @@ from torax import config as config_lib from torax import geometry from torax import sim as sim_lib -from torax.sources import source_config +from torax.sources import default_sources +from torax.sources import runtime_params as source_runtime_params +from torax.sources import source_models as source_models_lib from torax.stepper import linear_theta_method from torax.transport_model import qlknn_wrapper @@ -39,29 +41,15 @@ def get_config() -> config_lib.Config: numerics=config_lib.Numerics( current_eq=True, resistivity_mult=50, # to shorten current diffusion time for the test - bootstrap_mult=0, # remove bootstrap current dtmult=150, maxdt=0.5, t_final=10, enable_prescribed_profile_evolution=False, ), - w=0.18202270915319393, - S_pellet_tot=0, - S_puff_tot=0, - S_nbi_tot=0, - Ptot={0: 20e6, 9: 20e6, 10: 120e6, 15: 120e6}, solver=config_lib.SolverConfig( predictor_corrector=False, use_pereverzev=True, ), - sources=dict( - fusion_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ohmic_heat_source=source_config.SourceConfig( - source_type=source_config.SourceType.ZERO, - ), - ), ) @@ -84,6 +72,36 @@ def get_transport_model() -> qlknn_wrapper.QLKNNTransportModel: ) +def get_sources() -> source_models_lib.SourceModels: + """Returns the source models used in the simulation.""" + source_models = default_sources.get_default_sources() + # remove bootstrap current + source_models.j_bootstrap.runtime_params.bootstrap_mult = 0.0 + # total pellet particles/s (continuous pellet model) + source_models.sources["pellet_source"].runtime_params.S_pellet_tot = 0.0 + # Gaussian width in normalized radial coordinate r + source_models.sources["generic_ion_el_heat_source"].runtime_params.w = ( + 0.18202270915319393 + ) + source_models.sources["generic_ion_el_heat_source"].runtime_params.Ptot = { + 0: 20e6, + 9: 20e6, + 10: 120e6, + 15: 120e6, + } # in W + # total pellet particles/s + source_models.sources["gas_puff_source"].runtime_params.S_puff_tot = 0 + # NBI total particle source + source_models.sources["nbi_particle_source"].runtime_params.S_nbi_tot = 0.0 + source_models.sources["fusion_heat_source"].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + source_models.sources["ohmic_heat_source"].runtime_params.mode = ( + source_runtime_params.Mode.ZERO + ) + return source_models + + def get_sim() -> sim_lib.Sim: # This approach is currently lightweight because so many objects require # config for construction, but over time we expect to transition to most @@ -94,5 +112,6 @@ def get_sim() -> sim_lib.Sim: config=config, geo=geo, stepper_builder=linear_theta_method.LinearThetaMethod, + source_models=get_sources(), transport_model=get_transport_model(), ) diff --git a/torax/tests/test_lib/explicit_stepper.py b/torax/tests/test_lib/explicit_stepper.py index 693e48c4..5e3fc368 100644 --- a/torax/tests/test_lib/explicit_stepper.py +++ b/torax/tests/test_lib/explicit_stepper.py @@ -80,8 +80,9 @@ def __call__( consts = constants.CONSTANTS - true_ni = core_profiles_t.ni.value * dynamic_config_slice_t.nref - true_ni_face = core_profiles_t.ni.face_value() * dynamic_config_slice_t.nref + nref = dynamic_config_slice_t.numerics.nref + true_ni = core_profiles_t.ni.value * nref + true_ni_face = core_profiles_t.ni.face_value() * nref # Transient term coefficient vectors for ion heat equation # (has radial dependence through r, n) @@ -149,7 +150,6 @@ def __call__( s_face=s_face, ), source_models.build_all_zero_profiles( - dynamic_config_slice=dynamic_config_slice_t, geo=geo, source_models=self.source_models, ), diff --git a/torax/tests/test_lib/sim_test_case.py b/torax/tests/test_lib/sim_test_case.py index b273a11b..2d992e0b 100644 --- a/torax/tests/test_lib/sim_test_case.py +++ b/torax/tests/test_lib/sim_test_case.py @@ -300,6 +300,7 @@ def make_frozen_optimizer_stepper( dynamic_config_slice = config_slice.build_dynamic_config_slice( config=config, transport=transport_params, + sources=source_models.runtime_params, ) callback_builder = functools.partial( sim_lib.FrozenCoeffsCallback, diff --git a/torax/tests/test_lib/torax_refs.py b/torax/tests/test_lib/torax_refs.py index b1560d28..0c4bb77e 100644 --- a/torax/tests/test_lib/torax_refs.py +++ b/torax/tests/test_lib/torax_refs.py @@ -59,9 +59,6 @@ def circular_references() -> References: 'q_correction_factor': 1.0, }, 'nu': 3, - 'fext': 0.2, - 'wext': 0.05, - 'rext': 0.4, }, ) geo = geometry.build_circular_geometry( @@ -214,9 +211,6 @@ def chease_references_Ip_from_chease() -> References: # pylint: disable=invalid 'q_correction_factor': 1.0, }, 'nu': 3, - 'fext': 0.2, - 'wext': 0.05, - 'rext': 0.4, }, ) geo = geometry.build_chease_geometry( @@ -370,9 +364,6 @@ def chease_references_Ip_from_config() -> References: # pylint: disable=invalid 'q_correction_factor': 1.0, }, 'nu': 3, - 'fext': 0.2, - 'wext': 0.05, - 'rext': 0.4, }, ) geo = geometry.build_chease_geometry( diff --git a/torax/transport_model/qlknn_wrapper.py b/torax/transport_model/qlknn_wrapper.py index 658deb80..331a950d 100644 --- a/torax/transport_model/qlknn_wrapper.py +++ b/torax/transport_model/qlknn_wrapper.py @@ -172,7 +172,7 @@ def from_config_slice( ) -> '_QLKNNRuntimeConfigInputs': assert isinstance(dynamic_config_slice.transport, DynamicRuntimeParams) return _QLKNNRuntimeConfigInputs( - nref=dynamic_config_slice.nref, + nref=dynamic_config_slice.numerics.nref, Ai=dynamic_config_slice.plasma_composition.Ai, Zeff=dynamic_config_slice.plasma_composition.Zeff, transport=dynamic_config_slice.transport, diff --git a/torax/transport_model/tests/qlknn_wrapper.py b/torax/transport_model/tests/qlknn_wrapper.py index 0c4bc01e..45a51a81 100644 --- a/torax/transport_model/tests/qlknn_wrapper.py +++ b/torax/transport_model/tests/qlknn_wrapper.py @@ -21,6 +21,7 @@ from torax import config_slice from torax import core_profile_setters from torax import geometry +from torax.sources import source_models as source_models_lib from torax.transport_model import qlknn_wrapper @@ -36,13 +37,18 @@ def test_qlknn_wrapper_cache_works(self): qlknn_jitted = jax.jit(qlknn) config = config_lib.Config() geo = geometry.build_circular_geometry(config) + source_models = source_models_lib.SourceModels() dynamic_config_slice = config_slice.build_dynamic_config_slice( config=config, transport=qlknn.runtime_params, + sources=source_models.runtime_params, ) static_config_slice = config_slice.build_static_config_slice(config) core_profiles = core_profile_setters.initial_core_profiles( - static_config_slice, dynamic_config_slice, geo + static_config_slice=static_config_slice, + dynamic_config_slice=dynamic_config_slice, + geo=geo, + source_models=source_models, ) qlknn_jitted(dynamic_config_slice, geo, core_profiles) # The call should be cached. If there was an error, the cache size would be diff --git a/torax/transport_model/tests/transport_model.py b/torax/transport_model/tests/transport_model.py index 0a3492b1..7ad8d140 100644 --- a/torax/transport_model/tests/transport_model.py +++ b/torax/transport_model/tests/transport_model.py @@ -48,7 +48,11 @@ def test_smoothing(self): smoothing_sigma=0.05, ) ) - dynamic_config_slice = config_slice.build_dynamic_config_slice(config) + dynamic_config_slice = config_slice.build_dynamic_config_slice( + config, + transport=transport_model.runtime_params, + sources=source_models.runtime_params, + ) static_config_slice = config_slice.build_static_config_slice(config) time_calculator = fixed_time_step_calculator.FixedTimeStepCalculator() input_state = sim_lib.get_initial_state(