Skip to content

Commit

Permalink
Revert "Intermediate stage testing Linear vs Nonlinear solvers, no di…
Browse files Browse the repository at this point in the history
…fference in results or timings"

This reverts commit 249dfd3.
  • Loading branch information
atb1995 committed Dec 18, 2024
1 parent 249dfd3 commit ff9f524
Showing 1 changed file with 1 addition and 89 deletions.
90 changes: 1 addition & 89 deletions gusto/time_discretisation/explicit_runge_kutta.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@

from enum import Enum
from firedrake import (Function, Constant, NonlinearVariationalProblem,
NonlinearVariationalSolver, LinearVariationalProblem,
LinearVariationalSolver, TrialFunction)
NonlinearVariationalSolver)
from firedrake.fml import replace_subject, all_terms, drop, keep, Term
from firedrake.utils import cached_property
from firedrake.formmanipulation import split_form
Expand Down Expand Up @@ -42,7 +41,6 @@ class RungeKuttaFormulation(Enum):

increment = 1595712
predictor = 8234900
predictor_lin_solve = 1386823
linear = 269207


Expand Down Expand Up @@ -143,11 +141,6 @@ def setup(self, equation, apply_bcs=True, *active_labels):

if self.rk_formulation == RungeKuttaFormulation.predictor:
self.field_i = [Function(self.fs) for _ in range(self.nStages+1)]
elif self.rk_formulation == RungeKuttaFormulation.predictor_lin_solve:
self.field_i = [Function(self.fs) for _ in range(self.nStages+1)]
# self.field_i_inc = [Function(self.fs) for _ in range(self.nStages+1)]
self.df_trial = TrialFunction(self.fs)
self.df = Function(self.fs)
elif self.rk_formulation == RungeKuttaFormulation.increment:
self.k = [Function(self.fs) for _ in range(self.nStages)]
elif self.rk_formulation == RungeKuttaFormulation.linear:
Expand Down Expand Up @@ -178,22 +171,6 @@ def solver(self):
)
solver_list.append(solver)
return solver_list
elif self.rk_formulation == RungeKuttaFormulation.predictor_lin_solve:
solver_list = []

for stage in range(self.nStages):
# setup linear solver using lhs and rhs defined in derived class
problem = LinearVariationalProblem(
self.lhs[stage].form, self.rhs[stage].form,
self.df, bcs=self.bcs
)
solver_name = self.field_name+self.__class__.__name__+str(stage)
solver = LinearVariationalSolver(
problem, solver_parameters=self.solver_parameters,
options_prefix=solver_name
)
solver_list.append(solver)
return solver_list

elif self.rk_formulation == RungeKuttaFormulation.linear:
problem = NonlinearVariationalProblem(
Expand Down Expand Up @@ -245,17 +222,6 @@ def lhs(self):

return lhs_list

elif self.rk_formulation == RungeKuttaFormulation.predictor_lin_solve:
lhs_list = []
for stage in range(self.nStages):
l = self.residual.label_map(
lambda t: t.has_label(time_derivative),
map_if_true=replace_subject(self.df_trial, self.idx),
map_if_false=drop)
lhs_list.append(l)

return lhs_list

if self.rk_formulation == RungeKuttaFormulation.linear:
l = self.residual.label_map(
lambda t: t.has_label(time_derivative),
Expand Down Expand Up @@ -323,32 +289,6 @@ def rhs(self):

return rhs_list

elif self.rk_formulation == RungeKuttaFormulation.predictor_lin_solve:
rhs_list = []

for stage in range(self.nStages):
r = self.residual.label_map(
all_terms,
map_if_true=replace_subject(self.field_i[0], old_idx=self.idx))

r = r.label_map(
lambda t: t.has_label(time_derivative),
map_if_true=drop,
map_if_false=lambda t: -self.butcher_matrix[stage, 0]*self.dt*t)

for i in range(1, stage+1):
r_i = self.residual.label_map(
lambda t: t.has_label(time_derivative),
map_if_true=drop,
map_if_false=replace_subject(self.field_i[i], old_idx=self.idx)
)

r -= self.butcher_matrix[stage, i]*self.dt*r_i

rhs_list.append(r)

return rhs_list

elif self.rk_formulation == RungeKuttaFormulation.linear:

r = self.residual.label_map(
Expand Down Expand Up @@ -451,34 +391,6 @@ def solve_stage(self, x0, stage):
if self.limiter is not None:
self.limiter.apply(self.x1)

elif self.rk_formulation == RungeKuttaFormulation.predictor_lin_solve:
# Set initial field
if stage == 0:
self.field_i[0].assign(x0)

# Use previous stage value as a first guess (otherwise may not converge)
self.field_i[stage+1].assign(self.field_i[stage])

# Update field_i for physics / limiters
for evaluate in self.evaluate_source:
# TODO: not implemented! Here we need to evaluate the m-th term
# in the i-th RHS with field_m
raise NotImplementedError(
'Physics not implemented with RK schemes that use the '
+ 'predictor form')
if self.limiter is not None:
self.limiter.apply(self.field_i[stage])

# Obtain field_ip1 = field_n - dt* sum_m{a_im*F[field_m]}
self.solver[stage].solve()

self.field_i[stage+1].assign(self.field_i[0] + self.df)

if (stage == self.nStages - 1):
self.x1.assign(self.field_i[stage+1])
if self.limiter is not None:
self.limiter.apply(self.x1)

elif self.rk_formulation == RungeKuttaFormulation.linear:

# Set combined index of stage and subcycle
Expand Down

0 comments on commit ff9f524

Please sign in to comment.