Skip to content

Commit

Permalink
Intermediate stage testing Linear vs Nonlinear solvers, no difference…
Browse files Browse the repository at this point in the history
… in results or timings
  • Loading branch information
atb1995 committed Dec 18, 2024
1 parent 704ba3f commit 249dfd3
Showing 1 changed file with 89 additions and 1 deletion.
90 changes: 89 additions & 1 deletion gusto/time_discretisation/explicit_runge_kutta.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

from enum import Enum
from firedrake import (Function, Constant, NonlinearVariationalProblem,
NonlinearVariationalSolver)
NonlinearVariationalSolver, LinearVariationalProblem,
LinearVariationalSolver, TrialFunction)
from firedrake.fml import replace_subject, all_terms, drop, keep, Term
from firedrake.utils import cached_property
from firedrake.formmanipulation import split_form
Expand Down Expand Up @@ -41,6 +42,7 @@ class RungeKuttaFormulation(Enum):

increment = 1595712
predictor = 8234900
predictor_lin_solve = 1386823
linear = 269207


Expand Down Expand Up @@ -141,6 +143,11 @@ def setup(self, equation, apply_bcs=True, *active_labels):

if self.rk_formulation == RungeKuttaFormulation.predictor:
self.field_i = [Function(self.fs) for _ in range(self.nStages+1)]
elif self.rk_formulation == RungeKuttaFormulation.predictor_lin_solve:
self.field_i = [Function(self.fs) for _ in range(self.nStages+1)]
# self.field_i_inc = [Function(self.fs) for _ in range(self.nStages+1)]
self.df_trial = TrialFunction(self.fs)
self.df = Function(self.fs)
elif self.rk_formulation == RungeKuttaFormulation.increment:
self.k = [Function(self.fs) for _ in range(self.nStages)]
elif self.rk_formulation == RungeKuttaFormulation.linear:
Expand Down Expand Up @@ -171,6 +178,22 @@ def solver(self):
)
solver_list.append(solver)
return solver_list
elif self.rk_formulation == RungeKuttaFormulation.predictor_lin_solve:
solver_list = []

for stage in range(self.nStages):
# setup linear solver using lhs and rhs defined in derived class
problem = LinearVariationalProblem(
self.lhs[stage].form, self.rhs[stage].form,
self.df, bcs=self.bcs
)
solver_name = self.field_name+self.__class__.__name__+str(stage)
solver = LinearVariationalSolver(
problem, solver_parameters=self.solver_parameters,
options_prefix=solver_name
)
solver_list.append(solver)
return solver_list

elif self.rk_formulation == RungeKuttaFormulation.linear:
problem = NonlinearVariationalProblem(
Expand Down Expand Up @@ -222,6 +245,17 @@ def lhs(self):

return lhs_list

elif self.rk_formulation == RungeKuttaFormulation.predictor_lin_solve:
lhs_list = []
for stage in range(self.nStages):
l = self.residual.label_map(
lambda t: t.has_label(time_derivative),
map_if_true=replace_subject(self.df_trial, self.idx),
map_if_false=drop)
lhs_list.append(l)

return lhs_list

if self.rk_formulation == RungeKuttaFormulation.linear:
l = self.residual.label_map(
lambda t: t.has_label(time_derivative),
Expand Down Expand Up @@ -289,6 +323,32 @@ def rhs(self):

return rhs_list

elif self.rk_formulation == RungeKuttaFormulation.predictor_lin_solve:
rhs_list = []

for stage in range(self.nStages):
r = self.residual.label_map(
all_terms,
map_if_true=replace_subject(self.field_i[0], old_idx=self.idx))

r = r.label_map(
lambda t: t.has_label(time_derivative),
map_if_true=drop,
map_if_false=lambda t: -self.butcher_matrix[stage, 0]*self.dt*t)

for i in range(1, stage+1):
r_i = self.residual.label_map(
lambda t: t.has_label(time_derivative),
map_if_true=drop,
map_if_false=replace_subject(self.field_i[i], old_idx=self.idx)
)

r -= self.butcher_matrix[stage, i]*self.dt*r_i

rhs_list.append(r)

return rhs_list

elif self.rk_formulation == RungeKuttaFormulation.linear:

r = self.residual.label_map(
Expand Down Expand Up @@ -391,6 +451,34 @@ def solve_stage(self, x0, stage):
if self.limiter is not None:
self.limiter.apply(self.x1)

elif self.rk_formulation == RungeKuttaFormulation.predictor_lin_solve:
# Set initial field
if stage == 0:
self.field_i[0].assign(x0)

# Use previous stage value as a first guess (otherwise may not converge)
self.field_i[stage+1].assign(self.field_i[stage])

# Update field_i for physics / limiters
for evaluate in self.evaluate_source:
# TODO: not implemented! Here we need to evaluate the m-th term
# in the i-th RHS with field_m
raise NotImplementedError(
'Physics not implemented with RK schemes that use the '
+ 'predictor form')
if self.limiter is not None:
self.limiter.apply(self.field_i[stage])

# Obtain field_ip1 = field_n - dt* sum_m{a_im*F[field_m]}
self.solver[stage].solve()

self.field_i[stage+1].assign(self.field_i[0] + self.df)

if (stage == self.nStages - 1):
self.x1.assign(self.field_i[stage+1])
if self.limiter is not None:
self.limiter.apply(self.x1)

elif self.rk_formulation == RungeKuttaFormulation.linear:

# Set combined index of stage and subcycle
Expand Down

0 comments on commit 249dfd3

Please sign in to comment.