Skip to content

Commit

Permalink
Rename Qualikiz and QLKNN wrapper files to transport_model.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 698844052
  • Loading branch information
hamelphi authored and Torax team committed Nov 21, 2024
1 parent 4e91dbd commit 5aa54a9
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 41 deletions.
6 changes: 3 additions & 3 deletions run_simulation_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from torax.config import build_sim
from torax.config import config_loader
from torax.plotting import plotruns_lib
from torax.transport_model import qlknn_wrapper
from torax.transport_model import qlknn_transport_model


# String used when prompting the user to make a choice of command
Expand Down Expand Up @@ -104,9 +104,9 @@
'Path to the qlknn model network parameters (if using a QLKNN transport'
' model). If not set, then it will use the value from the config in the'
' "model_path" field in the qlknn_params. If that is not set, it will look'
f' for the "{qlknn_wrapper.MODEL_PATH_ENV_VAR}" env variable.'
f' for the "{qlknn_transport_model.MODEL_PATH_ENV_VAR}" env variable.'
' Finally, if this is also not set, it uses a hardcoded default path'
f' "{qlknn_wrapper.DEFAULT_MODEL_PATH}".',
f' "{qlknn_transport_model.DEFAULT_MODEL_PATH}".',
)

_OUTPUT_DIR = flags.DEFINE_string(
Expand Down
18 changes: 9 additions & 9 deletions torax/config/build_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@
from torax.transport_model import bohm_gyrobohm as bohm_gyrobohm_transport
from torax.transport_model import constant as constant_transport
from torax.transport_model import critical_gradient as critical_gradient_transport
from torax.transport_model import qlknn_wrapper
from torax.transport_model import qlknn_transport_model
# pylint: disable=g-import-not-at-top
try:
from torax.transport_model import qualikiz_wrapper
from torax.transport_model import qualikiz_transport_model
_QUALIKIZ_TRANSPORT_MODEL_AVAILABLE = True
except ImportError:
_QUALIKIZ_TRANSPORT_MODEL_AVAILABLE = False
Expand Down Expand Up @@ -442,8 +442,8 @@ def build_transport_model_builder_from_config(
- `qlknn`: QLKNN transport.
- See `transport_model.qlknn_wrapper.RuntimeParams` for model-specific
params.
- See `transport_model.qlknn_transport_model.RuntimeParams` for
model-specific params.
- `constant`: Constant transport
Expand Down Expand Up @@ -513,16 +513,16 @@ def build_transport_model_builder_from_config(
if 'model_path' in qlknn_params:
model_path = qlknn_params.pop('model_path')
else:
model_path = qlknn_wrapper.get_default_model_path()
model_path = qlknn_transport_model.get_default_model_path()
qlknn_params.update(transport_config)
# Remove params from the other models, if present.
qlknn_params.pop('constant_params', None)
qlknn_params.pop('cgm_params', None)
qlknn_params.pop('bohm-gyrobohm_params', None)
qlknn_params.pop('qualikiz_params', None)
return qlknn_wrapper.QLKNNTransportModelBuilder(
return qlknn_transport_model.QLKNNTransportModelBuilder(
runtime_params=config_args.recursive_replace(
qlknn_wrapper.get_default_runtime_params_from_model_path(
qlknn_transport_model.get_default_runtime_params_from_model_path(
model_path
),
**qlknn_params,
Expand Down Expand Up @@ -591,9 +591,9 @@ def build_transport_model_builder_from_config(
qualikiz_params.pop('constant_params', None)
qualikiz_params.pop('bohm-gyrobohm_params', None)
# pylint: disable=undefined-variable
return qualikiz_wrapper.QualikizTransportModelBuilder(
return qualikiz_transport_model.QualikizTransportModelBuilder(
runtime_params=config_args.recursive_replace(
qualikiz_wrapper.RuntimeParams(),
qualikiz_transport_model.RuntimeParams(),
**qualikiz_params,
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""A wrapper around qlknn_10d.
The wrapper calls the pretrained models trained on QuaLiKiz heat and
particle transport. The wrapper calculates qlknn_10d inputs, infers the
model, carries out post-processing, and returns a CoreTransport object
with turbulent transport coefficients.
"""
"""A transport model that uses a QLKNN model."""

from __future__ import annotations

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""A wrapper around QuaLiKiz.
"""A transport model that calls QuaLiKiz.
The wrapper calls QuaLiKiz itself. Must be run with
TORAX_COMPILATION_ENABLED=False. Used for generating ground truth for QLKNN11D
evaluation. Kept as an internal model.
Must be run with TORAX_COMPILATION_ENABLED=False. Used for generating ground
truth for surrogate model evaluations.
"""

from __future__ import annotations
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for torax.transport_model.qlknn_wrapper."""
"""Unit tests for torax.transport_model.qlknn_transport_model."""

from absl.testing import absltest
from absl.testing import parameterized
Expand All @@ -23,18 +23,18 @@
from torax.config import runtime_params as general_runtime_params
from torax.config import runtime_params_slice
from torax.sources import source_models as source_models_lib
from torax.transport_model import qlknn_wrapper
from torax.transport_model import qlknn_transport_model


class QlknnWrapperTest(parameterized.TestCase):
"""Unit tests for the `torax.transport_model.qlknn_wrapper` module."""
class QlknnTransportModelTest(parameterized.TestCase):
"""Unit tests for the `torax.transport_model.qlknn_transport_model` module."""

def test_qlknn_wrapper_cache_works(self):
def test_qlknn_transport_model_cache_works(self):
"""Tests that QLKNN calls are properly cached."""
# This test can uncover and changes to the data structures which break the
# caching.
qlknn = qlknn_wrapper.QLKNNTransportModel(
qlknn_wrapper.get_default_model_path()
qlknn = qlknn_transport_model.QLKNNTransportModel(
qlknn_transport_model.get_default_model_path()
)
runtime_params = general_runtime_params.GeneralRuntimeParams()
geo = geometry.build_circular_geometry()
Expand All @@ -43,7 +43,7 @@ def test_qlknn_wrapper_cache_works(self):
dynamic_runtime_params_slice = (
runtime_params_slice.DynamicRuntimeParamsSliceProvider(
runtime_params=runtime_params,
transport=qlknn_wrapper.RuntimeParams(),
transport=qlknn_transport_model.RuntimeParams(),
sources=source_models_builder.runtime_params,
torax_mesh=geo.torax_mesh,
)(
Expand Down Expand Up @@ -71,16 +71,16 @@ def test_qlknn_wrapper_cache_works(self):
def test_hash_and_eq(self):
# Test that hash and eq are invariant to copying, so that they will work
# correctly with jax's persistent cache
qlknn_1 = qlknn_wrapper.QLKNNTransportModel('foo')
qlknn_2 = qlknn_wrapper.QLKNNTransportModel('foo')
qlknn_1 = qlknn_transport_model.QLKNNTransportModel('foo')
qlknn_2 = qlknn_transport_model.QLKNNTransportModel('foo')
self.assertEqual(qlknn_1, qlknn_2)
self.assertEqual(hash(qlknn_1), hash(qlknn_2))
mock_persistent_jax_cache = set([qlknn_1])
self.assertIn(qlknn_2, mock_persistent_jax_cache)

def test_hash_and_eq_different(self):
qlknn_1 = qlknn_wrapper.QLKNNTransportModel('foo')
qlknn_2 = qlknn_wrapper.QLKNNTransportModel('bar')
qlknn_1 = qlknn_transport_model.QLKNNTransportModel('foo')
qlknn_2 = qlknn_transport_model.QLKNNTransportModel('bar')
self.assertNotEqual(qlknn_1, qlknn_2)
self.assertNotEqual(hash(qlknn_1), hash(qlknn_2))
mock_persistent_jax_cache = set([qlknn_1])
Expand All @@ -102,7 +102,7 @@ def test_filter_model_output(self, include_dict):
model_output = dict(
[(k, jnp.ones(shape)) for k in itg_keys + tem_keys + etg_keys]
)
filtered_model_output = qlknn_wrapper.filter_model_output(
filtered_model_output = qlknn_transport_model.filter_model_output(
model_output=model_output,
include_ITG=include_dict.get('itg', True),
include_TEM=include_dict.get('tem', True),
Expand Down Expand Up @@ -143,15 +143,15 @@ def test_clip_inputs(self):
[1.0, 2.625, 2.375, 12.6, 2.85, 6.0, 7.0, 8.0, 9.0],
[1.0, 2.8, 2.0, 12.6, 2.85, 6.0, 7.0, 8.0, 9.0],
])
clipped_feature_scan = qlknn_wrapper.clip_inputs(
clipped_feature_scan = qlknn_transport_model.clip_inputs(
feature_scan=feature_scan,
inputs_and_ranges=inputs_and_ranges,
clip_margin=clip_margin,
)
npt.assert_allclose(clipped_feature_scan, expected)

def test_runtime_params_builds_dynamic_params(self):
runtime_params = qlknn_wrapper.RuntimeParams()
runtime_params = qlknn_transport_model.RuntimeParams()
geo = geometry.build_circular_geometry()
provider = runtime_params.make_provider(geo.torax_mesh)
provider.build_dynamic_params(t=0.0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for qualikiz wrapper transport model."""
"""Tests for qualikiz transport_model transport model."""
from absl.testing import absltest
from torax import geometry
# pylint: disable=g-import-not-at-top
try:
from torax.transport_model import qualikiz_wrapper
from torax.transport_model import qualikiz_transport_model
_QUALIKIZ_TRANSPORT_MODEL_AVAILABLE = True
except ImportError:
_QUALIKIZ_TRANSPORT_MODEL_AVAILABLE = False
Expand All @@ -28,7 +28,7 @@ class RuntimeParamsTest(absltest.TestCase):
def test_runtime_params_builds_dynamic_params(self):
if not _QUALIKIZ_TRANSPORT_MODEL_AVAILABLE:
self.skipTest('Qualikiz transport model is not available.')
runtime_params = qualikiz_wrapper.RuntimeParams()
runtime_params = qualikiz_transport_model.RuntimeParams()
geo = geometry.build_circular_geometry()
provider = runtime_params.make_provider(geo.torax_mesh)
provider.build_dynamic_params(t=0.0)
Expand Down

0 comments on commit 5aa54a9

Please sign in to comment.