Skip to content

Commit

Permalink
general differentiable field transformations
Browse files Browse the repository at this point in the history
  • Loading branch information
lehner committed Jul 21, 2024
1 parent bdf6778 commit e86e018
Show file tree
Hide file tree
Showing 6 changed files with 231 additions and 2 deletions.
7 changes: 6 additions & 1 deletion lib/gpt/ad/reverse/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ def __init__(
self._container = _container
self._backward = _backward
self._children = _children
if len(_children) > 0:
with_gradient = any([c.with_gradient for c in _children])
self.with_gradient = with_gradient
self.infinitesimal_to_cartesian = infinitesimal_to_cartesian
self.gradient = None
Expand Down Expand Up @@ -236,7 +238,10 @@ def __add__(x, y):
if not isinstance(y, node_base):
y = node_base(y, with_gradient=False)

assert x._container == y._container
if not x._container.accumulate_compatible(y._container):
raise Exception(
f"Containers incompatible in addition: {x._container} and {y._container}"
)
_container = x._container

def _forward():
Expand Down
19 changes: 18 additions & 1 deletion lib/gpt/ad/reverse/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,23 @@ def set_otype(self, otype):
else:
raise Exception("Container does not have an otype")

def accumulate_compatible(self, other):
    """
    Decide whether gradients represented by this container may be
    accumulated into *other*'s container.

    When both tags describe composite objects (length > 1), the containers
    are compatible if the tags have equal length, the grids agree (only
    checked for tags longer than 2), and the otypes -- after resolving any
    data alias -- carry the same name.  Otherwise fall back to strict
    container equality.
    """
    tag_a, tag_b = self.tag, other.tag

    # scalar-like on either side: require exact container equality
    if len(tag_a) <= 1 or len(tag_b) <= 1:
        return self.__eq__(other)

    if len(tag_a) != len(tag_b):
        return False

    # grids are only compared for the longer tags
    if len(tag_a) > 2 and self.get_grid().obj != other.get_grid().obj:
        return False

    def _resolved_otype_name(c):
        # follow the data_alias indirection to the storage otype, if any
        ot = c.get_otype()
        if ot.data_alias is not None:
            ot = ot.data_alias()
        return ot.__name__

    return _resolved_otype_name(self) == _resolved_otype_name(other)

def zero(self):
r = self.representative()
if isinstance(r, g.lattice):
Expand Down Expand Up @@ -142,7 +159,7 @@ def convert_container(v, x, y, operand):
r = g.expr(operand(rx, ry))
c = container(r.container())

if v._container == c:
if v._container.accumulate_compatible(c):
return v

# conversions from tensor to matrix
Expand Down
1 change: 1 addition & 0 deletions lib/gpt/qcd/gauge/smear/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@
from gpt.qcd.gauge.smear.stout import stout, differentiable_stout
from gpt.qcd.gauge.smear.local_stout import local_stout
from gpt.qcd.gauge.smear.wilson_flow import wilson_flow
from gpt.qcd.gauge.smear.differentiable import differentiable_field_transformation
132 changes: 132 additions & 0 deletions lib/gpt/qcd/gauge/smear/differentiable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
#
# GPT - Grid Python Toolkit
# Copyright (C) 2020-24 Christoph Lehner ([email protected], https://github.com/lehner/gpt)
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
import gpt as g
from gpt.core.group import diffeomorphism, differentiable_functional


class dft_diffeomorphism(diffeomorphism):
    """
    Wrap a differentiable field transformation ``ft`` as a diffeomorphism.

    The reverse-mode AD graph of ``ft`` is built once at construction time
    (on lattices of the same layout as ``U``) and re-used for every
    jacobian application.
    """

    def __init__(self, U, ft):
        # build the reverse-AD graph once: self.aU are the graph inputs,
        # self.aUft the transformed outputs
        rad = g.ad.reverse
        self.ft = ft
        self.aU = [rad.node(u.new()) for u in U]
        self.aUft = ft(self.aU)

    def __call__(self, fields):
        # apply the plain (non-AD) transformation and evaluate the result
        res = self.ft(fields)
        return [g(x) for x in res]

    def jacobian(self, fields, fields_prime, dfields):
        N = len(fields_prime)
        assert len(fields) == N
        # fix: the original repeated the len(fields) check three times;
        # the remaining check was clearly meant for dfields
        assert len(dfields) == N
        aU_prime = [g(2j * dfields[mu] * fields_prime[mu]) for mu in range(N)]
        for mu in range(N):
            self.aU[mu].value = fields[mu]
        gradient = [None] * N
        for mu in range(N):
            # backpropagate the seed through graph output mu and
            # accumulate the resulting gradients of each input nu
            self.aUft[mu](initial_gradient=aU_prime[mu])
            for nu in range(N):
                if gradient[nu] is None:
                    gradient[nu] = self.aU[nu].gradient
                else:
                    gradient[nu] = g(gradient[nu] + self.aU[nu].gradient)

        return gradient


class dft_action_log_det_jacobian(differentiable_functional):
    """
    Differentiable functional built from the jacobian of a field
    transformation: evaluates sum_mu |J(U) mom_mu|^2 together with its
    gradient, and provides a heat-bath draw of transformed momenta.

    NOTE(review): per its name this represents the log-det-jacobian part
    of a field-transformed action -- confirm against callers.  Fields
    passed to __call__/gradient/draw are the concatenation U + mom
    (2*N entries).
    """

    def __init__(self, U, ft, dfm, inverter):
        # dfm: diffeomorphism wrapping ft; inverter: linear solver used in
        # draw() to apply the inverse jacobian
        self.dfm = dfm
        self.inverter = inverter
        self.N = len(U)
        mom = [g.group.cartesian(u) for u in U]
        rad = g.ad.reverse

        # build the reverse-AD graph of act = sum_mu |J(U) mom_mu|^2 once;
        # it is differentiable with respect to both U and mom
        _U = [rad.node(g.copy(u)) for u in U]
        _mom = [rad.node(g.copy(u)) for u in mom]
        _Up = dfm(_U)
        momp = dfm.jacobian(_U, _Up, _mom)

        act = None
        for mu in range(self.N):
            if mu == 0:
                act = g.norm2(momp[mu])
            else:
                act = g(act + g.norm2(momp[mu]))

        self.action = act.functional(*(_U + _mom))

    def __call__(self, fields):
        # evaluate the action on fields = U + mom
        return self.action(fields)

    def gradient(self, fields, dfields):
        # gradient of the action with respect to the listed dfields
        return self.action.gradient(fields, dfields)

    def draw(self, fields, rng):
        """
        Sample gaussian momenta, then map them through the inverse
        jacobian.  Returns the action value of the gaussian draw; the
        momentum fields in *fields* are updated in place.
        """
        U = fields[0 : self.N]
        mom = fields[self.N :]
        assert len(mom) == self.N
        assert len(U) == self.N

        rng.normal_element(mom, scale=1.0)

        U_prime = self.dfm(U)

        # wrap the jacobian as a matrix acting on the N momenta merged
        # into a single higher-dimensional field, as the inverter expects
        def _mat(dst_5d, src_5d):
            src = g.separate(src_5d, dimension=0)
            dst = self.dfm.jacobian(U, U_prime, src)
            dst_5d @= g.merge(dst, dimension=0)

        mom_xd = g.merge(mom, dimension=0)

        # mom_prime = J^{-1} mom
        mom_prime_xd = self.inverter(_mat)(mom_xd)
        mom_prime = g.separate(mom_prime_xd, dimension=0)

        # accumulate the action of the gaussian draw while overwriting
        # mom in place with the transformed momenta
        act = 0.0
        for mu in range(self.N):
            act += g.norm2(mom[mu])
            mom[mu] @= mom_prime[mu]

        return act


class differentiable_field_transformation:
    """
    Bundle a differentiable field transformation with the solvers needed
    to use it in practice: a diffeomorphism, a numerical inverse, and the
    associated log-det-jacobian action.
    """

    def __init__(self, U, ft, inverter, optimizer):
        self.U = U
        self.ft = ft
        self.inverter = inverter
        self.optimizer = optimizer
        # diffeomorphism (and its AD graph) is built once up front
        self.dfm = dft_diffeomorphism(self.U, self.ft)

    def diffeomorphism(self):
        # hand out the diffeomorphism constructed at initialization
        return self.dfm

    def inverse(self, Uft):
        """
        Numerically invert the transformation: find U such that
        ft(U) == Uft, by minimizing |ft(U) - Uft|^2 with the optimizer.
        """
        rad = g.ad.reverse
        sources = [rad.node(g.copy(u)) for u in Uft]
        targets = [rad.node(u, with_gradient=False) for u in Uft]
        mapped = self.ft(sources)
        # squared distance between transformed guess and target links
        distance = sum(g.norm2(t - m) for t, m in zip(targets, mapped))
        fnc = distance.functional(*sources)
        guess = g.copy(Uft)
        self.optimizer(fnc)(guess, guess)
        return guess

    def action_log_det_jacobian(self):
        return dft_action_log_det_jacobian(self.U, self.ft, self.dfm, self.inverter)
11 changes: 11 additions & 0 deletions tests/ad/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,17 @@ def real(x):
g.message("Force test", mu, eps2)
assert eps2 < 1e-20

# keep some links unchanged, save time in gradient calculation
for u in U_2[1:]:
    u.with_gradient = False

# NOTE(review): only U_2[0] tracks a gradient now, so the numerical
# gradient check below is restricted to the single link U[0]
A = g.qcd.gauge.action.differentiable_iwasaki(2.5)(U_2)
a1p = A.functional(*U_2)
a1p.assert_gradient_error(rng, U, [U[0]], 1e-4, 1e-10)

# restore gradient tracking for subsequent tests
for u in U_2:
    u.with_gradient = True


#####################################
# forward AD tests
Expand Down
63 changes: 63 additions & 0 deletions tests/qcd/gauge.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,3 +301,66 @@
eps2 += g.norm2(U[nu] - U0[nu]) / g.norm2(U0[nu])
g.message(eps2)
assert eps2 < 1e-28


# test general differentiable field transformation framework
ft_stout = g.qcd.gauge.smear.differentiable_stout(rho=0.05)

fr = g.algorithms.optimize.fletcher_reeves
ls2 = g.algorithms.optimize.line_search_quadratic

dft = g.qcd.gauge.smear.differentiable_field_transformation(
    U,
    ft_stout,
    # alternative solver: g.algorithms.inverter.fgmres(eps=1e-15, maxiter=1000, restartlen=60)
    g.algorithms.inverter.fgcr(eps=1e-15, maxiter=1000, restartlen=60),
    g.algorithms.optimize.non_linear_cg(
        maxiter=1000, eps=1e-15, step=1e-1, line_search=ls2, beta=fr
    ),
)

dfm = dft.diffeomorphism()
ald = dft.action_log_det_jacobian()

# test diffeomorphism of stout against reference implementation
dfm_ref = g.qcd.gauge.smear.stout(rho=0.05)
Uft = dfm(U)
Uft_ref = dfm_ref(U)
for mu in range(len(U)):
    eps2 = g.norm2(Uft[mu] - Uft_ref[mu]) / g.norm2(Uft[mu])
    g.message("Test ft:", eps2)
    assert eps2 < 1e-25

# test jacobian application against the reference implementation
mom = [g.group.cartesian(u) for u in U]
mom_prime = g.copy(mom)
rng.normal_element(mom_prime)
t0 = g.time()
mom = dfm.jacobian(U, Uft, mom_prime)
t1 = g.time()
mom_ref = dfm_ref.jacobian(U, Uft, mom_prime)
t2 = g.time()
for mu in range(len(U)):
    eps2 = g.norm2(mom[mu] - mom_ref[mu]) / g.norm2(mom[mu])
    g.message("Test jacobian:", eps2)
    assert eps2 < 1e-25

g.message("Time for dfm.jacobian", t1 - t0, "seconds")
g.message("Time for dfm_ref.jacobian", t2 - t1, "seconds")

mom2 = g.copy(mom)
g.message("Action log det jac:", ald(U + mom2))

ald.assert_gradient_error(rng, U + mom2, U + mom2, 1e-3, 1e-7)

# a momentum draw must reproduce the action value of the gaussian sample
act = ald.draw(U + mom2, rng)
act2 = ald(U + mom2)
eps = abs(act / act2 - 1)
g.message("Draw from log det action:", eps)
assert eps < 1e-10

# invertibility check: the numerical inverse of the transformation must
# recover the original links (the redundant "if True:" debug toggle that
# wrapped this block has been removed)
U0 = dft.inverse(Uft)
for mu in range(len(U)):
    eps2 = g.norm2(U0[mu] - U[mu]) / g.norm2(U[mu])
    g.message("Test invertibility:", eps2)
    assert eps2 < 1e-25

0 comments on commit e86e018

Please sign in to comment.