Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Increment form for implicit RK added and tested #566

Merged
merged 19 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions gusto/time_discretisation/explicit_runge_kutta.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,9 @@ def __init__(self, domain, butcher_matrix, field_name=None,
solver_parameters=solver_parameters,
limiter=limiter, options=options)
self.butcher_matrix = butcher_matrix
self.nbutcher = int(np.shape(self.butcher_matrix)[0])
self.nStages = int(np.shape(self.butcher_matrix)[0])
self.rk_formulation = rk_formulation

@property
def nStages(self):
return self.nbutcher

def setup(self, equation, apply_bcs=True, *active_labels):
"""
Set up the time discretisation based on the equation.
Expand Down
177 changes: 143 additions & 34 deletions gusto/time_discretisation/implicit_runge_kutta.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
import numpy as np

from firedrake import (Function, split, NonlinearVariationalProblem,
NonlinearVariationalSolver)
NonlinearVariationalSolver, Constant)
from firedrake.fml import replace_subject, all_terms, drop
from firedrake.utils import cached_property

from gusto.core.labels import time_derivative
from gusto.time_discretisation.time_discretisation import (
TimeDiscretisation, wrapper_apply
)
from gusto.time_discretisation.explicit_runge_kutta import RungeKuttaFormulation


__all__ = ["ImplicitRungeKutta", "ImplicitMidpoint", "QinZhang"]
Expand Down Expand Up @@ -56,6 +57,7 @@ class ImplicitRungeKutta(TimeDiscretisation):
# ---------------------------------------------------------------------------
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it worth us making the predictor and increment forms clear in the docstrings?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've made it a bit more clear, describing what we are solving for


def __init__(self, domain, butcher_matrix, field_name=None,
rk_formulation=RungeKuttaFormulation.increment,
solver_parameters=None, options=None,):
"""
Args:
Expand All @@ -66,6 +68,9 @@ def __init__(self, domain, butcher_matrix, field_name=None,
discretisation.
field_name (str, optional): name of the field to be evolved.
Defaults to None.
rk_formulation (:class:`RungeKuttaFormulation`, optional):
an enumerator object, describing the formulation of the Runge-
Kutta scheme. Defaults to the increment form.
solver_parameters (dict, optional): dictionary of parameters to
pass to the underlying solver. Defaults to None.
options (:class:`AdvectionOptions`, optional): an object containing
Expand All @@ -78,6 +83,7 @@ def __init__(self, domain, butcher_matrix, field_name=None,
options=options)
self.butcher_matrix = butcher_matrix
self.nStages = int(np.shape(self.butcher_matrix)[1])
self.rk_formulation = rk_formulation

def setup(self, equation, apply_bcs=True, *active_labels):
"""
Expand All @@ -91,31 +97,108 @@ def setup(self, equation, apply_bcs=True, *active_labels):

super().setup(equation, apply_bcs, *active_labels)

self.k = [Function(self.fs) for i in range(self.nStages)]
if self.rk_formulation == RungeKuttaFormulation.predictor:
self.xs = [Function(self.fs) for _ in range(self.nStages)]
elif self.rk_formulation == RungeKuttaFormulation.increment:
self.k = [Function(self.fs) for _ in range(self.nStages)]
elif self.rk_formulation == RungeKuttaFormulation.linear:
raise NotImplementedError(
'Linear Implicit Runge-Kutta formulation is not implemented'
)
else:
raise NotImplementedError(
'Runge-Kutta formulation is not implemented'
)

def lhs(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you know we don't set up the lhs and rhs variables? If these are inherited from the base TimeDiscretisation class, are they wrong? They don't appear to be used anywhere either. I wonder if we should either:

  • use them
  • set their values to be None
  • remove them and stop them being an @abstractproperty

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have removed lhs & rhs and stopped them being an abstract property

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead they are now just a property

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for making this change. @jshipton are you happy with this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have now removed all lhs & rhs. Each time discretisation just has a res (residual).

return super().lhs

def rhs(self):
return super().rhs

def solver(self, stage):
residual = self.residual.label_map(
lambda t: t.has_label(time_derivative),
map_if_true=drop,
map_if_false=replace_subject(self.xnph, self.idx),
)
def res(self, stage):
"""Set up the discretisation's residual for a given stage."""
atb1995 marked this conversation as resolved.
Show resolved Hide resolved
# Add time derivative terms y_s - y^n for stage s
mass_form = self.residual.label_map(
lambda t: t.has_label(time_derivative),
map_if_false=drop)
residual += mass_form.label_map(all_terms,
replace_subject(self.x_out, self.idx))
residual = mass_form.label_map(all_terms,
map_if_true=replace_subject(self.x_out, old_idx=self.idx))
residual -= mass_form.label_map(all_terms,
map_if_true=replace_subject(self.x1, old_idx=self.idx))
# Loop through stages up to s-1 and calcualte/sum
atb1995 marked this conversation as resolved.
Show resolved Hide resolved
# dt*(a_s1*F(y_1) + a_s2*F(y_2)+ ... + a_{s,s-1}*F(y_{s-1}))
for i in range(stage):
r_imp = self.residual.label_map(
lambda t: not t.has_label(time_derivative),
map_if_true=replace_subject(self.xs[i], old_idx=self.idx),
map_if_false=drop)
r_imp = r_imp.label_map(
all_terms,
map_if_true=lambda t: Constant(self.butcher_matrix[stage, i])*self.dt*t)
residual += r_imp
# Calculate and add on dt*a_ss*F(y_s)
r_imp = self.residual.label_map(
lambda t: not t.has_label(time_derivative),
map_if_true=replace_subject(self.x_out, old_idx=self.idx),
map_if_false=drop)
r_imp = r_imp.label_map(
all_terms,
map_if_true=lambda t: Constant(self.butcher_matrix[stage, stage])*self.dt*t)
residual += r_imp
return residual.form

@property
def final_res(self):
"""Set up the discretisation's final residual."""
# Add time derivative terms y^{n+1} - y^n
atb1995 marked this conversation as resolved.
Show resolved Hide resolved
mass_form = self.residual.label_map(lambda t: t.has_label(time_derivative),
map_if_false=drop)
residual = mass_form.label_map(all_terms,
map_if_true=replace_subject(self.x_out, old_idx=self.idx))
residual -= mass_form.label_map(all_terms,
map_if_true=replace_subject(self.x1, old_idx=self.idx))
# Loop through stages up to s-1 and calcualte/sum
atb1995 marked this conversation as resolved.
Show resolved Hide resolved
# dt*(b_1*F(y_1) + b_2*F(y_2) + .... + b_s*F(y_s))
for i in range(self.nStages):
r_imp = self.residual.label_map(
lambda t: not t.has_label(time_derivative),
map_if_true=replace_subject(self.xs[i], old_idx=self.idx),
map_if_false=drop)
r_imp = r_imp.label_map(
all_terms,
map_if_true=lambda t: Constant(self.butcher_matrix[self.nStages, i])*self.dt*t)
residual += r_imp
return residual.form

problem = NonlinearVariationalProblem(residual.form, self.x_out, bcs=self.bcs)
def solver(self, stage):
if self.rk_formulation == RungeKuttaFormulation.increment:
residual = self.residual.label_map(
lambda t: t.has_label(time_derivative),
map_if_true=drop,
map_if_false=replace_subject(self.xnph, self.idx),
)
mass_form = self.residual.label_map(
lambda t: t.has_label(time_derivative),
map_if_false=drop)
residual += mass_form.label_map(all_terms,
replace_subject(self.x_out, self.idx))

problem = NonlinearVariationalProblem(residual.form, self.x_out, bcs=self.bcs)

elif self.rk_formulation == RungeKuttaFormulation.predictor:
problem = NonlinearVariationalProblem(self.res(stage), self.x_out, bcs=self.bcs)

solver_name = self.field_name+self.__class__.__name__ + "%s" % (stage)
return NonlinearVariationalSolver(problem, solver_parameters=self.solver_parameters,
options_prefix=solver_name)
return NonlinearVariationalSolver(problem, solver_parameters=self.solver_parameters, options_prefix=solver_name)

@cached_property
def final_solver(self):
"""Set up a solver for the final solve to evaluate time level n+1."""
# setup solver using lhs and rhs defined in derived class
problem = NonlinearVariationalProblem(self.final_res, self.x_out, bcs=self.bcs)
solver_name = self.field_name+self.__class__.__name__
return NonlinearVariationalSolver(problem, solver_parameters=self.solver_parameters, options_prefix=solver_name)

@cached_property
def solvers(self):
Expand All @@ -126,32 +209,48 @@ def solvers(self):

def solve_stage(self, x0, stage):
self.x1.assign(x0)
for i in range(stage):
self.x1.assign(self.x1 + self.butcher_matrix[stage, i]*self.dt*self.k[i])
if self.rk_formulation == RungeKuttaFormulation.increment:
for i in range(stage):
self.x1.assign(self.x1 + self.butcher_matrix[stage, i]*self.dt*self.k[i])

if self.idx is None and len(self.fs) > 1:
self.xnph = tuple([self.dt*self.butcher_matrix[stage, stage]*a + b
for a, b in zip(split(self.x_out), split(self.x1))])
else:
self.xnph = self.x1 + self.butcher_matrix[stage, stage]*self.dt*self.x_out
solver = self.solvers[stage]
# Set initial guess for solver
if (stage > 0):
self.x_out.assign(self.k[stage-1])
if self.idx is None and len(self.fs) > 1:
self.xnph = tuple(
self.dt * self.butcher_matrix[stage, stage] * a + b
for a, b in zip(split(self.x_out), split(self.x1))
)
else:
self.xnph = self.x1 + self.butcher_matrix[stage, stage]*self.dt*self.x_out

solver = self.solvers[stage]

solver.solve()
# Set initial guess for solver
if (stage > 0):
self.x_out.assign(self.k[stage-1])

self.k[stage].assign(self.x_out)
solver.solve()
self.k[stage].assign(self.x_out)

elif self.rk_formulation == RungeKuttaFormulation.predictor:
if (stage > 0):
self.x_out.assign(self.xs[stage-1])
solver = self.solvers[stage]
solver.solve()

self.xs[stage].assign(self.x_out)

@wrapper_apply
def apply(self, x_out, x_in):

self.x_out.assign(x_in)
for i in range(self.nStages):
self.solve_stage(x_in, i)

x_out.assign(x_in)
for i in range(self.nStages):
x_out.assign(x_out + self.butcher_matrix[self.nStages, i]*self.dt*self.k[i])
if self.rk_formulation == RungeKuttaFormulation.increment:
x_out.assign(x_in)
for i in range(self.nStages):
x_out.assign(x_out + self.butcher_matrix[self.nStages, i]*self.dt*self.k[i])
elif self.rk_formulation == RungeKuttaFormulation.predictor:
self.final_solver.solve()
x_out.assign(self.x_out)


class ImplicitMidpoint(ImplicitRungeKutta):
Expand All @@ -164,14 +263,18 @@ class ImplicitMidpoint(ImplicitRungeKutta):
k0 = F[y^n + 0.5*dt*k0] \n
y^(n+1) = y^n + dt*k0 \n
"""
def __init__(self, domain, field_name=None, solver_parameters=None,
options=None):
def __init__(self, domain, field_name=None,
rk_formulation=RungeKuttaFormulation.increment,
solver_parameters=None, options=None):
"""
Args:
domain (:class:`Domain`): the model's domain object, containing the
mesh and the compatible function spaces.
field_name (str, optional): name of the field to be evolved.
Defaults to None.
rk_formulation (:class:`RungeKuttaFormulation`, optional):
an enumerator object, describing the formulation of the Runge-
Kutta scheme. Defaults to the increment form.
solver_parameters (dict, optional): dictionary of parameters to
pass to the underlying solver. Defaults to None.
options (:class:`AdvectionOptions`, optional): an object containing
Expand All @@ -181,6 +284,7 @@ def __init__(self, domain, field_name=None, solver_parameters=None,
"""
butcher_matrix = np.array([[0.5], [1.]])
super().__init__(domain, butcher_matrix, field_name,
rk_formulation=rk_formulation,
solver_parameters=solver_parameters,
options=options)

Expand All @@ -196,14 +300,18 @@ class QinZhang(ImplicitRungeKutta):
k1 = F[y^n + 0.5*dt*k0 + 0.25*dt*k1] \n
y^(n+1) = y^n + 0.5*dt*(k0 + k1) \n
"""
def __init__(self, domain, field_name=None, solver_parameters=None,
options=None):
def __init__(self, domain, field_name=None,
rk_formulation=RungeKuttaFormulation.increment,
solver_parameters=None, options=None):
"""
Args:
domain (:class:`Domain`): the model's domain object, containing the
mesh and the compatible function spaces.
field_name (str, optional): name of the field to be evolved.
Defaults to None.
rk_formulation (:class:`RungeKuttaFormulation`, optional):
an enumerator object, describing the formulation of the Runge-
Kutta scheme. Defaults to the increment form.
solver_parameters (dict, optional): dictionary of parameters to
pass to the underlying solver. Defaults to None.
options (:class:`AdvectionOptions`, optional): an object containing
Expand All @@ -213,5 +321,6 @@ def __init__(self, domain, field_name=None, solver_parameters=None,
"""
butcher_matrix = np.array([[0.25, 0], [0.5, 0.25], [0.5, 0.5]])
super().__init__(domain, butcher_matrix, field_name,
rk_formulation=rk_formulation,
solver_parameters=solver_parameters,
options=options)
11 changes: 7 additions & 4 deletions integration-tests/model/test_time_discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@ def run(timestepper, tmax, f_end):

@pytest.mark.parametrize(
"scheme", [
"ssprk3_increment", "TrapeziumRule", "ImplicitMidpoint", "QinZhang",
"ssprk3_increment", "TrapeziumRule", "ImplicitMidpoint",
"QinZhang_increment", "QinZhang_predictor",
"RK4", "Heun", "BDF2", "TR_BDF2", "AdamsBashforth", "Leapfrog",
"AdamsMoulton", "AdamsMoulton", "ssprk3_predictor", "ssprk3_linear"
"AdamsMoulton", "ssprk3_predictor", "ssprk3_linear"
]
)
def test_time_discretisation(tmpdir, scheme, tracer_setup):
Expand Down Expand Up @@ -40,8 +41,10 @@ def test_time_discretisation(tmpdir, scheme, tracer_setup):
transport_scheme = TrapeziumRule(domain)
elif scheme == "ImplicitMidpoint":
transport_scheme = ImplicitMidpoint(domain)
elif scheme == "QinZhang":
transport_scheme = QinZhang(domain)
elif scheme == "QinZhang_increment":
transport_scheme = QinZhang(domain, rk_formulation=RungeKuttaFormulation.increment)
elif scheme == "QinZhang_predictor":
transport_scheme = QinZhang(domain, rk_formulation=RungeKuttaFormulation.predictor)
elif scheme == "RK4":
transport_scheme = RK4(domain)
elif scheme == "Heun":
Expand Down
Loading