Skip to content

Commit

Permalink
Change the API and runtime configuration for all sources.
Browse files Browse the repository at this point in the history
This moves the source runtime parameter configuration out of `Config` and into the individual sources. This also cleans up some old issues with the Source API around naming.

PiperOrigin-RevId: 628214715
  • Loading branch information
araju authored and Torax team committed Apr 25, 2024
1 parent bf40b44 commit 3b63368
Show file tree
Hide file tree
Showing 101 changed files with 3,620 additions and 2,618 deletions.
1 change: 0 additions & 1 deletion .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ jobs:
torax/sources/tests/generic_ion_el_heat_source.py \
torax/sources/tests/ion_el_heat_sources.py \
torax/sources/tests/qei_source.py \
torax/sources/tests/source_config.py \
torax/sources/tests/source_models.py \
torax/sources/tests/source.py \
torax/spectators/tests/plotting.py \
Expand Down
14 changes: 10 additions & 4 deletions run_simulation_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,11 @@ def change_config(
new_config = config_module.get_config()
new_geo = config_module.get_geometry(new_config)
new_transport_model = config_module.get_transport_model()
source_models = config_module.get_sources()
new_source_params = {
name: source.runtime_params
for name, source in source_models.sources.items()
}
# Make sure the transport model has not changed.
# TODO(b/330172917): Improve the check for updated configs.
if not isinstance(new_transport_model, type(sim.transport_model)):
Expand All @@ -210,10 +215,11 @@ def change_config(
' this option, you cannot change the transport model.'
)
sim = simulation_app.update_sim(
sim,
new_config,
new_geo,
new_transport_model.runtime_params,
sim=sim,
config=new_config,
geo=new_geo,
transport_runtime_params=new_transport_model.runtime_params,
source_runtime_params=new_source_params,
)
return sim, new_config, config_module_str

Expand Down
3 changes: 2 additions & 1 deletion torax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,18 +74,19 @@
'dynamic_config_slice_t',
'dynamic_config_slice_t_plus_dt',
'unused_config',
'dynamic_source_runtime_params',
'geo',
'x_old',
'state',
'unused_state',
'core_profiles',
'psi',
'transport_model',
'time_step_calculator',
'source_profiles',
'source_profile',
'explicit_source_profiles',
'source_models',
'time_step_calculator',
'coeffs_callback',
'evolving_names',
'spectator',
Expand Down
44 changes: 30 additions & 14 deletions torax/calc_coeffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,12 @@ def calculate_pereverzev_flux(
"""Adds Pereverzev-Corrigan flux to diffusion terms."""

consts = constants.CONSTANTS
true_ne_face = core_profiles.ne.face_value() * dynamic_config_slice.nref
true_ni_face = core_profiles.ni.face_value() * dynamic_config_slice.nref
true_ne_face = (
core_profiles.ne.face_value() * dynamic_config_slice.numerics.nref
)
true_ni_face = (
core_profiles.ni.face_value() * dynamic_config_slice.numerics.nref
)

geo_factor = jnp.concatenate(
[jnp.ones(1), geo.g1_over_vpr_face[1:] / geo.g0_face[1:]]
Expand Down Expand Up @@ -277,22 +281,22 @@ def _calc_coeffs_full(
# Decide which values to use depending on whether the source is explicit or
# implicit.
sigma = jax_utils.select(
dynamic_config_slice.sources[source_models.j_bootstrap.name].is_explicit,
dynamic_config_slice.sources[source_models.j_bootstrap_name].is_explicit,
explicit_source_profiles.j_bootstrap.sigma,
implicit_source_profiles.j_bootstrap.sigma,
)
j_bootstrap = jax_utils.select(
dynamic_config_slice.sources[source_models.j_bootstrap.name].is_explicit,
dynamic_config_slice.sources[source_models.j_bootstrap_name].is_explicit,
explicit_source_profiles.j_bootstrap.j_bootstrap,
implicit_source_profiles.j_bootstrap.j_bootstrap,
)
j_bootstrap_face = jax_utils.select(
dynamic_config_slice.sources[source_models.j_bootstrap.name].is_explicit,
dynamic_config_slice.sources[source_models.j_bootstrap_name].is_explicit,
explicit_source_profiles.j_bootstrap.j_bootstrap_face,
implicit_source_profiles.j_bootstrap.j_bootstrap_face,
)
I_bootstrap = jax_utils.select( # pylint: disable=invalid-name
dynamic_config_slice.sources[source_models.j_bootstrap.name].is_explicit,
dynamic_config_slice.sources[source_models.j_bootstrap_name].is_explicit,
explicit_source_profiles.j_bootstrap.I_bootstrap,
implicit_source_profiles.j_bootstrap.I_bootstrap,
)
Expand Down Expand Up @@ -328,13 +332,21 @@ def _calc_coeffs_full(
source_models,
)

true_ne_face = core_profiles.ne.face_value() * dynamic_config_slice.nref
true_ni_face = core_profiles.ni.face_value() * dynamic_config_slice.nref
true_ne_face = (
core_profiles.ne.face_value() * dynamic_config_slice.numerics.nref
)
true_ni_face = (
core_profiles.ni.face_value() * dynamic_config_slice.numerics.nref
)

# Transient term coefficient vector (has radial dependence through r, n)
toc_temp_ion = 1.5 * geo.vpr * consts.keV2J * dynamic_config_slice.nref
toc_temp_ion = (
1.5 * geo.vpr * consts.keV2J * dynamic_config_slice.numerics.nref
)
tic_temp_ion = core_profiles.ni.value
toc_temp_el = 1.5 * geo.vpr * consts.keV2J * dynamic_config_slice.nref
toc_temp_el = (
1.5 * geo.vpr * consts.keV2J * dynamic_config_slice.numerics.nref
)
tic_temp_el = core_profiles.ne.value
toc_psi = (
1.0
Expand Down Expand Up @@ -516,7 +528,7 @@ def _calc_coeffs_full(
dynamic_config_slice.profile_conditions.Ip
/ (jnp.pi * geo.Rmin**2)
* 1e20
/ dynamic_config_slice.nref
/ dynamic_config_slice.numerics.nref
)
# pylint: enable=invalid-name
neped_unnorm = jnp.where(
Expand Down Expand Up @@ -551,7 +563,9 @@ def _calc_coeffs_full(
) = jax.lax.cond(
use_pereverzev,
lambda: calculate_pereverzev_flux(
dynamic_config_slice, geo, core_profiles,
dynamic_config_slice,
geo,
core_profiles,
),
lambda: tuple([jnp.zeros_like(geo.r_face)] * 6),
)
Expand All @@ -563,9 +577,11 @@ def _calc_coeffs_full(

# Ion and electron heat sources.
qei = source_models.qei_source.get_qei(
dynamic_config_slice.sources[source_models.qei_source.name].source_type,
dynamic_config_slice=dynamic_config_slice,
static_config_slice=static_config_slice,
dynamic_config_slice=dynamic_config_slice,
dynamic_source_runtime_params=dynamic_config_slice.sources[
source_models.qei_source_name
],
geo=geo,
# For Qei, always use the current set of core profiles.
# In the linear solver, core_profiles is the set of profiles at time t (at
Expand Down
71 changes: 5 additions & 66 deletions torax/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,13 @@
parameters.
"""

from collections.abc import Iterable, Mapping
from collections.abc import Iterable
import dataclasses
import enum
import typing
from typing import Any
import chex
from torax import interpolated_param
from torax.sources import source_config


# Type-alias for clarity. While the InterpolatedParams can vary across any
Expand Down Expand Up @@ -205,10 +204,10 @@ class Numerics:
# 1/multiplication factor for sigma (conductivity) to reduce current
# diffusion timescale to be closer to heat diffusion timescale
resistivity_mult: TimeDependentField = 1.0
# Multiplication factor for bootstrap current
bootstrap_mult: float = 1.0
# multiplier for ion-electron heat exchange term for sensitivity testing
Qei_mult: float = 1.0

# density profile info
# Reference value for normalization
nref: float = 1e20

# numerical (e.g. no. of grid points, other info needed by solver)
# radial grid points (num cells)
Expand All @@ -234,43 +233,6 @@ class Config:
)
numerics: Numerics = dataclasses.field(default_factory=Numerics)

# TODO(b/330172917): Move the source parameters into `sources`.

# density profile info
# Reference value for normalization
nref: float = 1e20

# external heat source parameters
w: TimeDependentField = 0.25 # Gaussian width in normalized radial coordinate
# Source Gaussian central location (in normalized r)
rsource: TimeDependentField = 0.0
Ptot: TimeDependentField = 120e6 # total heating
el_heat_fraction: TimeDependentField = 0.66666 # electron heating fraction

# particle source parameters
# Gaussian width of pellet deposition [normalized radial coord],
# (continuous pellet model)
pellet_width: TimeDependentField = 0.1
# Pellet source Gaussian central location [normalized radial coord]
# (continuous pellet model)
pellet_deposition_location: TimeDependentField = 0.85
# total pellet particles/s (continuous pellet model)
# TODO(b/326578331): improve numerical strategy, avoid these large numbers
S_pellet_tot: TimeDependentField = 2e22

# exponential decay length of gas puff ionization [normalized radial coord]
puff_decay_length: TimeDependentField = 0.05
# total gas puff particles/s
# TODO(b/326578331): improve numerical strategy, avoid these large numbers
S_puff_tot: TimeDependentField = 1e22

# NBI particle source Gaussian width in normalized radial coord
nbi_particle_width: TimeDependentField = 0.25
# NBI particle source Gaussian central location in normalized radial coord
nbi_deposition_location: TimeDependentField = 0.0
# NBI total particle source
S_nbi_tot: TimeDependentField = 1e22

# current profiles (broad "Ohmic" + localized "external" currents)
# peaking factor of "Ohmic" current: johm = j0*(1 - r^2/a^2)^nu
nu: float = 3.0
Expand All @@ -281,16 +243,6 @@ class Config:
# or from the psi available in the numerical geometry file. This setting is
# ignored for the ad-hoc circular geometry, which has no numerical geometry.
initial_psi_from_j: bool = False
# toggles if external current is provided absolutely or as a fraction of Ip
use_absolute_jext: bool = False
# total "external" current in MA. Used if use_absolute_jext=True.
Iext: TimeDependentField = 3.0
# total "external" current fraction. Used if use_absolute_jext=False.
fext: TimeDependentField = 0.2
# width of "external" Gaussian current profile
wext: TimeDependentField = 0.05
# normalized radius of "external" Gaussian current profile
rext: TimeDependentField = 0.4

# solver parameters
solver: SolverConfig = dataclasses.field(default_factory=SolverConfig)
Expand All @@ -301,13 +253,6 @@ class Config:

# pylint: enable=invalid-name

# Runtime configs for all source/sink terms.
# Note that the sources field is overridden in the __post_init__. See impl for
# details on how this field is updated.
sources: Mapping[str, source_config.SourceConfig] = dataclasses.field(
default_factory=source_config.get_default_sources_config
)

def sanity_check(self) -> None:
"""Checks that various configuration parameters are valid."""
# TODO(b/330172917) do more extensive config parameter sanity checking
Expand All @@ -319,12 +264,6 @@ def sanity_check(self) -> None:
assert isinstance(self.numerics, Numerics)

def __post_init__(self):
# The sources config should have the default values from
# source_config.get_default_sources_config. The additional values provided
# via the config constructor should OVERRIDE these defaults.
sources = dict(source_config.get_default_sources_config())
sources.update(self.sources) # Update with the user inputs.
self.sources = sources
self.sanity_check()


Expand Down
Loading

0 comments on commit 3b63368

Please sign in to comment.