Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adjoint #592

Closed
wants to merge 34 commits into from
Closed
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
eb360ac
make dt and t Functions on the Real space
jshipton Sep 12, 2024
93ad1cf
sort out value of t in timestepper
jshipton Sep 12, 2024
b54f3bd
more Functions in R instead of Constants
jshipton Sep 12, 2024
998deb6
adjoint diffusion example
jshipton Sep 12, 2024
209d056
try adjoint with shallow water
jshipton Sep 12, 2024
1682f88
Merge branch 'main' of https://github.com/firedrakeproject/gusto into…
jshipton Sep 12, 2024
e49b006
use SIQN for adjoint shallow water
jshipton Sep 12, 2024
2b38eea
some small changes to enable moist thermal shallow water adjoint
jshipton Sep 12, 2024
b0055c8
Merge branch 'main' into adjoint
jshipton Dec 16, 2024
a7e6de1
Add adjoint tests
Ig-dolci Dec 16, 2024
3bc9a15
flake8
Ig-dolci Dec 16, 2024
4cfd4bb
wip
Ig-dolci Dec 16, 2024
7b9d24d
flake8
Ig-dolci Dec 16, 2024
31d5c59
Testing
Ig-dolci Dec 17, 2024
96becfe
Minor changer
Ig-dolci Dec 17, 2024
9f9b8eb
Test all controls
Ig-dolci Dec 17, 2024
61ca2dc
Check the blocks are empty
Ig-dolci Dec 17, 2024
6da7518
Add a notebook
Ig-dolci Dec 17, 2024
14371fd
Small changes
Ig-dolci Dec 17, 2024
9b8b0ed
dd
Ig-dolci Dec 18, 2024
b088504
Remove adjoint examples; enhance the notebook text; fix the tests
Ig-dolci Dec 18, 2024
7680ebc
flake8
Ig-dolci Dec 18, 2024
af8908b
Merge branch 'main' into adjoint
Ig-dolci Dec 18, 2024
c6eb3b2
Match with the main branch
Ig-dolci Dec 18, 2024
719a194
wip
Ig-dolci Dec 19, 2024
542d33e
Add convert_parameters_to_real_space function
Ig-dolci Dec 19, 2024
efc7e99
wip
Ig-dolci Dec 19, 2024
f89173a
flake8
Ig-dolci Dec 19, 2024
84814ce
replace deprecated decorator
Ig-dolci Dec 19, 2024
06de329
fix error
Ig-dolci Jan 6, 2025
145f111
solve conflict
Ig-dolci Jan 6, 2025
ba4fc95
more fixes
Ig-dolci Jan 6, 2025
52ed722
wip
Ig-dolci Jan 6, 2025
0601968
Revert "solve conflict"
Ig-dolci Jan 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
768 changes: 768 additions & 0 deletions docs/notebook/shallow_water_adjoint.ipynb

Large diffs are not rendered by default.

8 changes: 6 additions & 2 deletions gusto/core/configuration.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Some simple tools for configuring the model."""
from abc import ABCMeta, abstractproperty
from enum import Enum
from firedrake import sqrt, Constant
from firedrake import sqrt


__all__ = [
Expand Down Expand Up @@ -80,7 +80,11 @@ def __setattr__(self, name, value):
# Almost all parameters should be Constants -- but there are some
# specific exceptions which should be kept as integers
if type(value) in [float, int] and name not in ['dumpfreq', 'pddumpfreq', 'chkptfreq']:
object.__setattr__(self, name, Constant(value))
object.__setattr__(self, name, value)
# DO NOT MERGE
# Adjoint is a bit twitchy with Constants so let's not make them
# for now
# object.__setattr__(self, name, Constant(value))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we should use Function in Real space instead using Constant.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

However, we have to have access to the mesh at this point for using Function in Real Space.

else:
object.__setattr__(self, name, value)

Expand Down
7 changes: 4 additions & 3 deletions gusto/core/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,16 @@ def __init__(self, mesh, dt, family, degree=None,
# -------------------------------------------------------------------- #

# Store central dt for use in the rest of the model
R = FunctionSpace(mesh, "R", 0)
if type(dt) is Constant:
self.dt = dt
self.dt = Function(R, val=float(dt))
elif type(dt) in (float, int):
self.dt = Constant(dt)
self.dt = Function(R, val=dt)
else:
raise TypeError(f'dt must be a Constant, float or int, not {type(dt)}')

# Make a placeholder for the time
self.t = Constant(0.0)
self.t = Function(R, val=0.0)

# -------------------------------------------------------------------- #
# Build compatible function spaces
Expand Down
12 changes: 7 additions & 5 deletions gusto/physics/shallow_water_microphysics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

from firedrake import (
Interpolator, conditional, Function, dx, min_value, max_value, Constant
Interpolator, conditional, Function, dx, min_value, max_value, FunctionSpace
)
from firedrake.fml import subject
from gusto.core.logging import logger
Expand Down Expand Up @@ -93,13 +93,14 @@ def __init__(self, equation, saturation_curve,
self.water_v = Function(Vv)
self.source = Function(Vv)

R = FunctionSpace(equation.domain.mesh, "R", 0)
# tau is the timescale for conversion (may or may not be the timestep)
if tau is not None:
self.set_tau_to_dt = False
self.tau = tau
self.tau = Function(R).assign(tau)
else:
self.set_tau_to_dt = True
self.tau = Constant(0)
self.tau = Function(R)
logger.info("Timescale for rain conversion has been set to dt. If this is not the intention then provide a tau parameter as an argument to InstantRain.")

if self.time_varying_saturation:
Expand Down Expand Up @@ -269,12 +270,13 @@ def __init__(self, equation, saturation_curve,
V_idxs.append(self.Vb_idx)

# tau is the timescale for condensation/evaporation (may or may not be the timestep)
R = FunctionSpace(equation.domain.mesh, "R", 0)
if tau is not None:
self.set_tau_to_dt = False
self.tau = tau
self.tau = Function(R).assign(tau)
else:
self.set_tau_to_dt = True
self.tau = Constant(0)
self.tau = Function(R)
logger.info("Timescale for moisture conversion between vapour and cloud has been set to dt. If this is not the intention then provide a tau parameter as an argument to SWSaturationAdjustment.")

if self.time_varying_saturation:
Expand Down
3 changes: 2 additions & 1 deletion gusto/spatial_methods/diffusion_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,4 +160,5 @@ def __init__(self, equation, variable, diffusion_parameters):
kappa = diffusion_parameters.kappa
self.form = diffusion(kappa * self.test.dx(0) * self.field.dx(0) * dx)
else:
raise NotImplementedError("CG diffusion only implemented in 1D")
kappa = diffusion_parameters.kappa
self.form = diffusion(kappa * inner(grad(self.test), grad(self.field)) * dx)
8 changes: 3 additions & 5 deletions gusto/time_discretisation/time_discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import math

from firedrake import (Function, TestFunction, TestFunctions, DirichletBC,
Constant, NonlinearVariationalProblem,
NonlinearVariationalProblem,
NonlinearVariationalSolver)
from firedrake.fml import (replace_subject, replace_test_function, Term,
all_terms, drop)
Expand Down Expand Up @@ -71,10 +71,8 @@ def __init__(self, domain, field_name=None, subcycling_options=None,
self.field_name = field_name
self.equation = None

self.dt = Constant(0.0)
self.dt.assign(domain.dt)
self.original_dt = Constant(0.0)
self.original_dt.assign(self.dt)
self.dt = Function(domain.dt.function_space(), val=domain.dt)
self.original_dt = Function(domain.dt.function_space(), val=self.dt)
self.options = options
self.limiter = limiter
self.courant_max = None
Expand Down
2 changes: 1 addition & 1 deletion gusto/timestepping/timestepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def run(self, t, tmax, pick_up=False):

self.timestep()

self.t.assign(self.t + self.dt)
self.t.assign(float(self.t) + float(self.dt))
self.step += 1

with timed_stage("Dump output"):
Expand Down
78 changes: 78 additions & 0 deletions integration-tests/adjoints/test_diffusion_sensitivity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import pytest
import numpy as np

from firedrake import *
from firedrake.adjoint import *
from pyadjoint import get_working_tape
from gusto import *


@pytest.fixture(autouse=True)
def handle_taping():
yield
tape = get_working_tape()
tape.clear_tape()


@pytest.fixture(autouse=True, scope="module")
def handle_annotation():
from firedrake.adjoint import annotate_tape, continue_annotation
if not annotate_tape():
continue_annotation()
yield
# Ensure annotation is paused when we finish.
annotate = annotate_tape()
if annotate:
pause_annotation()


@pytest.mark.parametrize("nu_is_control", [True, False])
def test_diffusion_sensitivity(nu_is_control, tmpdir):
assert get_working_tape()._blocks == []
n = 30
mesh = PeriodicUnitSquareMesh(n, n)
output = OutputParameters(dirname=str(tmpdir))
dt = 0.01
domain = Domain(mesh, 10*dt, family="BDM", degree=1)
io = IO(domain, output)

V = VectorFunctionSpace(mesh, "CG", 2)
domain.spaces.add_space("vecCG", V)

R = FunctionSpace(mesh, "R", 0)
# We need to define nu as a function in order to have a control variable.
nu = Function(R, val=0.0001)
diffusion_params = DiffusionParameters(kappa=nu)
eqn = DiffusionEquation(domain, V, "f", diffusion_parameters=diffusion_params)

diffusion_scheme = BackwardEuler(domain)
diffusion_methods = [CGDiffusion(eqn, "f", diffusion_params)]
timestepper = Timestepper(eqn, diffusion_scheme, io, spatial_methods=diffusion_methods)

x = SpatialCoordinate(mesh)
fexpr = as_vector((sin(2*pi*x[0]), cos(2*pi*x[1])))
timestepper.fields("f").interpolate(fexpr)

end = 0.1
timestepper.run(0., end)

u = timestepper.fields("f")
J = assemble(inner(u, u)*dx)

if nu_is_control:
control = Control(nu)
h = Function(R, val=0.0001) # the direction of the perturbation
else:
control = Control(u)
# the direction of the perturbation
h = Function(V).interpolate(fexpr * np.random.rand())

# the functional as a pure function of nu
Jhat = ReducedFunctional(J, control)

if nu_is_control:
assert np.allclose(J, Jhat(nu))
assert taylor_test(Jhat, nu, h) > 1.95
else:
assert np.allclose(J, Jhat(u))
assert taylor_test(Jhat, u, h) > 1.95
Loading