Skip to content

Commit

Permalink
Code update
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 624313060
  • Loading branch information
The swirl_lm Authors authored and john-qingwang committed Apr 12, 2024
1 parent b2a8227 commit 11d4d88
Show file tree
Hide file tree
Showing 72 changed files with 3,120 additions and 1,398 deletions.
77 changes: 72 additions & 5 deletions swirl_lm/base/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from swirl_lm.core import simulation
from swirl_lm.linalg import poisson_solver
from swirl_lm.utility import common_ops
from swirl_lm.utility import stretched_grid
from swirl_lm.utility import tpu_util
from swirl_lm.utility import types
import tensorflow as tf
Expand Down Expand Up @@ -87,6 +88,7 @@
FLAGS = flags.FLAGS

CKPT_DIR_FORMAT = '{filename_prefix}-ckpts/'
COMPLETION_FILE = 'DONE'
_MAX_UVW_CFL = 'max_uvw_cfl'

Array = Any
Expand Down Expand Up @@ -137,7 +139,13 @@ def get_checkpoint_manager(
)


def _get_state_keys(params):
def _write_completion_file(output_dir: str) -> None:
  """Creates an empty completion marker file inside `output_dir`.

  The marker's file name is taken from the module-level `COMPLETION_FILE`
  constant.

  Args:
    output_dir: Directory in which the empty marker file is written.
  """
  marker_path = f'{output_dir}/{COMPLETION_FILE}'
  with tf.io.gfile.GFile(marker_path, 'w') as marker:
    # The file's existence is the signal; its content is intentionally empty.
    marker.write('')


def _get_state_keys(params: parameters_lib.SwirlLMParameters):
"""Returns essential, additional and helper var state keys."""
# Essential flow field variables:
# u: velocity in dimension 0;
Expand All @@ -154,6 +162,15 @@ def _get_state_keys(params):
params.helper_var_keys if params.helper_var_keys else []
)

# Add additional and helper_var keys required for stretched grids.
stretched_grid_additional_keys, stretched_grid_helper_var_keys = (
stretched_grid.additional_and_helper_var_keys(
params.use_stretched_grid, params.use_3d_tf_tensor
)
)
additional_keys.extend(stretched_grid_additional_keys)
helper_var_keys.extend(stretched_grid_helper_var_keys)

# Check to make sure we don't have keys duplicating / overwriting each other.
if len(set(essential_keys)) + len(set(additional_keys)) + len(
set(helper_var_keys)
Expand Down Expand Up @@ -188,6 +205,13 @@ def init_fn(
states = {}


# Add helper variables for stretched grids.
states.update(
stretched_grid.local_stretched_grid_vars_from_global_xyz(
params, coordinates
)
)

# Add helper variables from Poisson solver.
poisson_solver_helper_var_fn = (
poisson_solver.poisson_solver_helper_variable_init_fn(params)
Expand Down Expand Up @@ -238,7 +262,7 @@ def _process_at_step_id(
Args:
process_fn: Function accepting `essential_states` and `additional_states`,
and returning the updated values of individual states in a dictionary.
essential_states: The essential states, corresponds to the`states` keyword
essential_states: The essential states, corresponds to the `states` keyword
argument of `process_fn`.
additional_states: The additional states, corresponds to the
`additional_states` keyword argument of `process_fn`.
Expand Down Expand Up @@ -591,6 +615,43 @@ def solver(
considered a new simulation and will save its `state` after initialization
as a checkpoint.
Categories of tensors in `state`:
The tensors in `state` consist of 3 categories, which `driver._one_cycle()`
treats differently. There are `essential_keys`, `additional_keys`, and
`helper_var_keys`. Tensors corresponding to `essential_keys` and
`additional_keys` are unstacked from tensors into lists of tensors at the
beginning of a cycle and restacked into tensors at the end of a cycle.
Tensors corresponding to `helper_var_keys` are left as is.
When passing the tensors to the simulation model's step function, the
tensors corresponding to `essential_keys` are put into one dictionary,
whereas the tensors corresponding to `additional_keys` and `helper_var_keys`
are put into a separate dictionary.
Termination:
In the normal case, the solver will return after it has run the simulation
for the requested number of cycles. The output directory will contain data
for `num_cycles + 1` steps, more specifically, for step numbers [0, 1 *
num_steps, 2 * num_steps, ..., num_cycles * num_steps].
If the simulation reaches a state where any variable has a non-finite value
(NaN or Inf), then the simulation will stop early and save the state one
step before the state that contains non-finite values. As a result, there
will be fewer steps saved than `num_cycles + 1` and the final saved step
number will not necessarily be a multiple of num_steps.
In both of these cases (`num_cycles` reached or non-finite value seen),
the solver will write an empty `DONE` file to the output directory to
indicate that the simulation is complete.
The simulation can also terminate by raising an exception (e.g., because of
input errors, resource issues, etc.). In this case, the solver will exit
before writing an empty `DONE` file.
In case of restarts (e.g., by increasing `num_cycles` to continue a
simulation), the `DONE` file should be removed before starting the solver.
Args:
customized_init_fn: The function that initializes the flow field. The
function needs to be replica dependent.
Expand Down Expand Up @@ -836,13 +897,11 @@ def write_state_and_sync(
params.num_steps,
)

kernel_op = params.kernel_op

# Get the model that defines the concrete simulation calculation procedures.
# Since we are allowing some model object's methods to be decorated with
# `tf.function`, calling `_get_model` outside the loop ensures that these
# methods are traced only once.
model = _get_model(kernel_op, params)
model = _get_model(params.kernel_op, params)

while step_id_value() < (
params.start_step + params.num_steps * params.num_cycles
Expand Down Expand Up @@ -883,6 +942,10 @@ def write_state_and_sync(
# materialized, we can guarantee that the actual write actions are
# completed, and we update the saved completed step here.
ckpt_manager.save()
# Wait for checkpoint manager before marking completion.
ckpt_manager.sync()
# Mark simulation as complete.
_write_completion_file(output_dir)
raise ValueError(
f'Non-convergence detected. Early exit from cycle {cycle} at step '
f'{step_id_value() + 1}. The last valid state at step '
Expand Down Expand Up @@ -938,4 +1001,8 @@ def write_state_and_sync(
t2 = time.time()
logging.info('Writing output & checkpoint took %f secs.', t2 - t1)

# Wait for checkpoint manager before marking completion.
ckpt_manager.sync()
_write_completion_file(output_dir)

return strategy.experimental_local_results(state)
81 changes: 63 additions & 18 deletions swirl_lm/base/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""Library for initializing the variables on TPU cores."""

import enum
from typing import Callable, List, Optional, Sequence, Text, Tuple, Union
from typing import Callable, List, Literal, Optional, Sequence, Text, Tuple, Union

import numpy as np
from swirl_lm.utility import grid_parametrization
Expand Down Expand Up @@ -50,7 +50,8 @@ def partial_mesh_for_core(
perm: Optional[ThreeIntTuple] = DEFAULT_PERMUTATION,
pad_mode: Optional[Text] = _DEFAULT_PAD_MODE,
num_boundary_points: int = DEFAULT_NUM_BOUNDARY_POINTS,
mesh_choice: MeshChoice = MeshChoice.DERIVED) -> tf.Tensor:
mesh_choice: MeshChoice = MeshChoice.DERIVED,
) -> tf.Tensor:
"""Generates a partial mesh of a given value function for a core.
The full grid spec is provided by `params`. The value function `value_fn`
Expand All @@ -68,10 +69,10 @@ def partial_mesh_for_core(
coordinate: A vector/sequence of integer with length 3 representing the
logical coordinate of the core in the logical mesh [x, y, z].
value_fn: A function that takes the local mesh_grid tensor for the core (in
order x, y, z), the global characteristic length floats (in order x, y,
z) and the local core coordinate, and returns a 3-D tensor representing
the value for the local core (without including the margin/overlap between
the cores).
order x, y, z), the global characteristic length floats (in order x, y, z)
and the local core coordinate, and returns a 3-D tensor representing the
value for the local core (without including the margin/overlap between the
cores).
perm: A 3-tuple that defines the permutation ordering for the returned
tensor. The default is (2, 0, 1). If `None`, permutation is not applied.
pad_mode: Defines the padding applied to the returned tensor. Must be
Expand All @@ -92,7 +93,7 @@ def partial_mesh_for_core(
ValueError: If arguments are incorrect.
"""

def get_slice_in_dim(core_n, length, num_cores, core_id, mesh_from_params):
def get_slice_in_dim(core_n, length, num_cores, core_id, provided_mesh):
"""Returns the portion of the (sub) grid in the given dimension.
Note that on each side we pad one grid point regardless of the halo width.
Expand All @@ -104,26 +105,28 @@ def get_slice_in_dim(core_n, length, num_cores, core_id, mesh_from_params):
length: The spatial extent of the grid.
num_cores: The total number of cores.
core_id: The index of the core in {0, 1, ... num_cores - 1}.
mesh_from_params: Mesh directly obtaind from `params`.
provided_mesh: Global mesh, provided from `params`.
Returns:
The subgrid corresponding to the portion of the grid in the given
dimension assigned to the `core_id`.
"""
if not core_n:
return [_NP_DTYPE(0.0)]

if mesh_choice == MeshChoice.DERIVED:
linspace = tf.linspace(
_NP_DTYPE(0.0), _NP_DTYPE(length),
num_cores * core_n + num_boundary_points)
mesh = tf.linspace(
_NP_DTYPE(0.0),
_NP_DTYPE(length),
num_cores * core_n + num_boundary_points,
)
else:
linspace = mesh_from_params
mesh = provided_mesh

boundary_offset = num_boundary_points // 2
return linspace[core_id * core_n + boundary_offset:(core_id + 1) * core_n +
boundary_offset]
start = core_id * core_n + boundary_offset
end = start + core_n
return mesh[start:end]

lx = params.lx
ly = params.ly
Expand Down Expand Up @@ -168,9 +171,9 @@ def get_slice_in_dim(core_n, length, num_cores, core_id, mesh_from_params):
'Invalid subgrid coordinate specified with z core index. Must be '
'smaller than total number of core partitioning in z direction.')

xs = get_slice_in_dim(core_nx, lx, cx, gx, params.x)
ys = get_slice_in_dim(core_ny, ly, cy, gy, params.y)
zs = get_slice_in_dim(core_nz, lz, cz, gz, params.z)
xs = get_slice_in_dim(core_nx, lx, cx, gx, params.global_xyz[0])
ys = get_slice_in_dim(core_ny, ly, cy, gy, params.global_xyz[1])
zs = get_slice_in_dim(core_nz, lz, cz, gz, params.global_xyz[2])

xx, yy, zz = tf.meshgrid(xs, ys, zs, indexing='ij')
val = value_fn(xx, yy, zz, _NP_DTYPE(lx), _NP_DTYPE(ly), _NP_DTYPE(lz), # pytype: disable=wrong-arg-types # numpy-scalars
Expand All @@ -189,6 +192,48 @@ def get_slice_in_dim(core_n, length, num_cores, core_id, mesh_from_params):
return val


def reshape_to_broadcastable(
    f_1d: tf.Tensor, dim: Literal[0, 1, 2]
) -> tf.Tensor:
  """Reshapes a rank-1 tensor to a form broadcastable against 3D fields.

  This function is appropriate for initialization and storing of 1D arrays, to
  be used later on in the simulation.

  Note: do not use this function inside of `partial_mesh_for_core`. That
  function expects dimensions to be ordered (x,y,z), whereas this function
  outputs dimensions with order (z,x,y).

  Here, `dim` is 0, 1, or 2, corresponding to dimension x, y, or z respectively.
  The rank-1 tensor `f_1d` will be reshaped such that it represents a 3D field
  whose values vary only along dimension `dim`. However, for memory efficiency,
  the number of elements does not change. The output can be used in operations
  with 3D fields, with broadcasting occurring.

  The number of elements of `f_1d` must be correct on input (this is NOT
  checked). That is, if `dim`==0, 1, or 2, then len(f_1d) must equal nx, ny, or
  nz, respectively, where `nx`, `ny`, `nz` are the corresponding sizes of 3D
  fields.

  Args:
    f_1d: A rank-1 tensor.
    dim: The dimension of variation of the input tensor `f_1d`. Must be 0, 1,
      or 2.

  Returns:
    The reshaped tensor that can be broadcast against a 3D field.

  Raises:
    ValueError: If `dim` is not 0, 1, or 2.
  """
  assert (
      f_1d.ndim == 1
  ), f'Expecting rank-1 tensor, got rank-{f_1d.ndim} tensor.'
  if dim == 0:
    return f_1d[tf.newaxis, :, tf.newaxis]  # Set tensor shape to (1, nx, 1).
  elif dim == 1:
    return f_1d[tf.newaxis, tf.newaxis, :]  # Set tensor shape to (1, 1, ny).
  elif dim == 2:
    return f_1d[:, tf.newaxis, tf.newaxis]  # Set tensor shape to (nz, 1, 1).
  # Reject anything else explicitly rather than silently treating an invalid
  # `dim` (e.g. `None` or 3) as the z dimension.
  raise ValueError(f'`dim` must be 0, 1, or 2, but got {dim!r}.')


# Below are convenience wrappers of some initialization functions.
def gen_circular_u(params, coordinate, omega=0.1, perm=DEFAULT_PERMUTATION):
"""A simple wrapper for generating circular U field."""
Expand Down
63 changes: 60 additions & 3 deletions swirl_lm/base/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
import copy
import os
import os.path
from typing import Callable, List, Mapping, Optional, Sequence, Tuple
from typing import Callable, List, Literal, Mapping, Optional, Sequence, Tuple, TypeAlias

from absl import flags
from absl import logging
import numpy as np
Expand All @@ -38,6 +39,9 @@

from google.protobuf import text_format

FlowFieldVal: TypeAlias = types.FlowFieldVal
FlowFieldMap: TypeAlias = types.FlowFieldMap

# The threshold of the difference between the absolute value of the
# gravitational vector along a dimension and one. Below this threshold the
# corresponding dimension is the gravity (vertical) dimension.
Expand Down Expand Up @@ -174,8 +178,19 @@ def __init__(
self,
config: parameters_pb2.SwirlLMParameters,
grid_params: Optional[
grid_parametrization_pb2.GridParametrization] = None,
grid_parametrization_pb2.GridParametrization
] = None,
):
"""Initializes the SwirlLMParameters object.
Args:
config: An instance of the `SwirlLMParameters` proto.
grid_params: An instance of the `GridParametrization` proto.
Raises:
ValueError: If the kernel operator type or scheme used for discretizing
the diffusion term is not recognized.
"""
super(SwirlLMParameters, self).__init__(grid_params)

self.swirl_lm_parameters_proto = config
Expand Down Expand Up @@ -273,7 +288,9 @@ def __init__(
'Gravity dimension is ambiguous if it is not aligned with an axis.'
f' {g_dim} is provided.'
)
self.g_dim = g_dim.item() if len(g_dim) == 1 else None
self.g_dim: Literal[0, 1, 2] | None = (
g_dim.item() if len(g_dim) == 1 else None
)

# Get the scalar related quantities if scalars are solved as a
# `List[SwirlLMParameters.Scalar]`.
Expand Down Expand Up @@ -398,6 +415,9 @@ def __init__(
f' DIFFUSION_SCHEME_CENTRAL_3 supports the Monin-Obukhov '
f'similarity theory.')

if any(self.use_stretched_grid):
_validate_config_for_stretched_grid(config)

# Toggle whether to run in debug mode.
self.dbg = FLAGS.simulation_debug

Expand Down Expand Up @@ -968,3 +988,40 @@ def params_from_config_file_flag() -> SwirlLMParameters:
raise ValueError('Flag --config_filepath is not set.')

return SwirlLMParameters.config_from_proto(FLAGS.config_filepath)


def _validate_config_for_stretched_grid(
    config: parameters_pb2.SwirlLMParameters,
) -> None:
  """Rejects configs that enable features incompatible with stretched grids.

  Caution: only features already *known* to be unsupported are checked here,
  so passing this validation does not guarantee that every enabled feature
  works with a stretched grid.

  Args:
    config: An instance of the `SwirlLMParameters` proto.

  Raises:
    NotImplementedError: If `config` enables the immersed boundary model, the
      Rhie-Chow correction, or the `DIFFUSION_SCHEME_STENCIL_3` diffusion
      scheme.
  """
  # Immersed boundary: only inspect the sub-message when it is actually set.
  uses_immersed_boundary = config.HasField(
      'boundary_models'
  ) and config.boundary_models.HasField('ib')
  if uses_immersed_boundary:
    raise NotImplementedError(
        'Immersed boundary method is not yet supported with stretched grid.'
    )

  if config.enable_rhie_chow_correction:
    raise NotImplementedError(
        'Rhie-Chow correction is not supported with stretched grid.'
    )

  unsupported_scheme = numerics_pb2.DiffusionScheme.DIFFUSION_SCHEME_STENCIL_3
  if config.diffusion_scheme == unsupported_scheme:
    raise NotImplementedError(
        f'Diffusion scheme {config.diffusion_scheme} is not supported with'
        ' stretched grid.'
    )
Loading

0 comments on commit 11d4d88

Please sign in to comment.