Skip to content

Commit

Permalink
speed and lint
Browse files Browse the repository at this point in the history
  • Loading branch information
lehner committed Jun 15, 2024
1 parent a686600 commit 8a169ce
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 71 deletions.
13 changes: 8 additions & 5 deletions lib/gpt/algorithms/group/locally_coherent_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import gpt as g
from gpt.core.group import differentiable_functional


class locally_coherent_functional(differentiable_functional):
def __init__(self, inner, block):
self.inner = inner
Expand All @@ -29,10 +30,10 @@ def reduce(self, fields):
assert n % 2 == 0
n //= 2

right = self.block.embed(fields[n:2*n])
right = self.block.embed(fields[n : 2 * n])

return [g(g.group.compose(x, y)) for x, y in zip(fields[0:n], right)]

return [g(g.group.compose(x,y)) for x, y in zip(fields[0:n], right)]

def __call__(self, fields):
return self.inner(self.reduce(fields))

Expand All @@ -41,14 +42,16 @@ def gradient(self, fields, dfields):
assert n % 2 == 0
n //= 2
left = fields[0:n]
right = fields[n:2*n]
right = fields[n : 2 * n]

# f(left right)
# left derivative is like original: f(idA left right)
# right derivative is: f(left idA right ) = f(idA2 left right)
# with dA2 = left dA left^dag

indices = [mu for mu in range(n) if dfields[mu] in fields or dfields[mu + n] in fields]
indices = [
mu for mu in range(n) if dfields[mu] in fields or dfields[mu + n] in fields
]

r = self.reduce(fields)

Expand Down
14 changes: 10 additions & 4 deletions lib/gpt/algorithms/group/symmetric_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import gpt as g
from gpt.core.group import differentiable_functional


class symmetric_functional(differentiable_functional):
def __init__(self, inner):
self.inner = inner
Expand All @@ -27,8 +28,11 @@ def reduce(self, fields):
n = len(fields)
assert n % 2 == 0
n //= 2
return [g(g.group.compose(x, g.group.inverse(y))) for x, y in zip(fields[0:n], fields[n:2*n])]

return [
g(g.group.compose(x, g.group.inverse(y)))
for x, y in zip(fields[0:n], fields[n : 2 * n])
]

def __call__(self, fields):
return self.inner(self.reduce(fields))

Expand All @@ -37,14 +41,16 @@ def gradient(self, fields, dfields):
assert n % 2 == 0
n //= 2
left = fields[0:n]
right = fields[n:2*n]
right = fields[n : 2 * n]

# f(left right^dag)
# left derivative is like original: f(idA left right^dag)
# right derivative is: f(left right^dag (-i)dA ) = f(idA2 left right^dag)
# with dA2 = -left right^dag dA right left^dag = -r dA r^dag

indices = [mu for mu in range(n) if dfields[mu] in fields or dfields[mu + n] in fields]
indices = [
mu for mu in range(n) if dfields[mu] in fields or dfields[mu + n] in fields
]

r = self.reduce(fields)

Expand Down
4 changes: 2 additions & 2 deletions lib/gpt/core/checkerboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,13 @@ def str_to_cb(s):
assert 0


def pick_checkerboard(cb, dst, src = None):
def pick_checkerboard(cb, dst, src=None):
if src is None:
src = dst
dst = gpt.lattice(src.grid.checkerboarded(gpt.redblack), src.otype)
pick_checkerboard(cb, dst, src)
return dst

assert len(src.v_obj) == len(dst.v_obj)
for i in src.otype.v_idx:
cgpt.lattice_pick_checkerboard(cb.tag, src.v_obj[i], dst.v_obj[i])
Expand Down
155 changes: 96 additions & 59 deletions lib/gpt/qcd/gauge/smear/local_stout.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@

local_stout_parallel_projector = g.default.get_int("--local-stout-parallel-projector", 1)

plaquette_stencil_cache = {}


def create_adjoint_projector(D, B, generators, nfactors):
ng = len(generators)
Expand Down Expand Up @@ -377,6 +379,25 @@ def __init__(self, stout):
self.stout = stout
self.verbose = stout.verbose

def plaquette_stencil(self, U, rho, mu, nu):
key = f"{U[0].grid.describe()}_{U[0].otype.__name__}_{mu}_{nu}_{rho}"
if key not in self.stout.cache:
code = []
code.append((0, -1, -rho, g.path().f(nu).f(mu).b(nu).b(mu)))
code.append((1, -1, -rho, g.path().f(nu).b(mu)))
code.append((2, -1, -rho, g.path().f(mu).b(nu).b(mu)))
code.append((3, -1, -rho, g.path().f(mu).f(nu).b(mu)))
code.append((4, -1, rho, g.path().f(nu).b(mu).b(nu)))
code.append((5, -1, rho, g.path().f(mu).f(nu).b(mu)))
code.append((6, -1, 1.0, g.path().b(mu)))
code.append((7, -1, 1.0, g.path().b(nu)))
code.append((8, -1, 1.0, g.path().f(nu)))
code.append((9, -1, 1.0, g.path().b(mu).f(nu)))

self.stout.cache[key] = g.parallel_transport_matrix(U, code, 10)

return self.stout.cache[key](U)

def __call__(self, U):
log_det = g.sum(self.stout.log_det_jacobian(U))
return -log_det.real
Expand Down Expand Up @@ -406,20 +427,34 @@ def gradient(self, U, dU):
t = g.timer("action_log_det_jacobian")
t("dJdX")

# dJdX
dJdX = [g(1j * adjoint_generators[b] * one) for b in range(ng)]
# dJdX -> stencil version
ag_fields = [g(1j * adjoint_generators[b] * one) for b in range(ng)]
dJdX = g.copy(ag_fields)
aunit = g.identity(J_ac)

X = g.copy(Z_ac)
t2 = g.copy(X)
t3 = g.lattice(t2)

code = []
_X = 0
_t2 = 1
_t3 = 2
_aunit = 3
_dJdX = [4 + b for b in range(ng)]
_adjoint_generators = [4 + ng + b for b in range(ng)]
for j in reversed(range(2, 13)):
t3 = g(t2 * (1 / (j + 1)) + aunit)
t2 @= X * t3
code.append((_t3, _aunit, 1 / (j + 1), [(_t2, 0, 0)]))
code.append((_t2, -1, 1, [(_X, 0, 0), (_t3, 0, 0)]))
for b in range(ng):
dJdX[b] @= 1j * adjoint_generators[b] * t3 + X * dJdX[b] * (1 / (j + 1))
code.append((_dJdX[b], -1, 1 / (j + 1), [(_X, 0, 0), (_dJdX[b], 0, 0)]))
code.append((_dJdX[b], _dJdX[b], 1, [(_adjoint_generators[b], 0, 0), (_t3, 0, 0)]))

for b in range(ng):
dJdX[b] = g(-dJdX[b])
code.append((_dJdX[b], -1, -1, [(_dJdX[b], 0, 0)]))

dJdX_stencil = g.local_stencil.matrix(X, [tuple([0] * len(U))], code)
dJdX_stencil(X, t2, t3, aunit, *dJdX, *ag_fields)

t("invert M_ab")
inv_M_ab = g.matrix.inv(M_ab)
Expand All @@ -446,124 +481,126 @@ def gradient(self, U, dU):
dJdXe_nMpInv @= dJdXe_nMpInv * fm

mu = self.stout.params["dimension"]
U_mu_masked = g(U[mu] * fm)

# fundamental forces
Fdet1 = [g.lattice(grid, otype) for nu in range(len(U))]
Fdet2 = [g.lattice(grid, otype) for nu in range(len(U))]

t("non-local")
t("non-local cshift")
Nxy = g.lattice(NxxAd)

dJdXe_nMpInv_bmu = g.cshift(dJdXe_nMpInv, mu, -1)
MpInvJx_bmu = g.cshift(MpInvJx, mu, -1)

for nu in range(len(U)):
if nu == mu:
continue

t("non-local stencil")

(
minus_fnu_fmu_bnu_bmu,
minus_fnu_bmu,
minus_fmu_bnu_bmu,
minus_fmu_fnu_bmu,
plus_fnu_bmu_bnu,
plus_fmu_fnu_bmu,
one_bmu,
one_bnu,
one_fnu,
one_bmu_fnu,
) = self.plaquette_stencil(U, rho, mu, nu)

for cb_field in [minus_fnu_fmu_bnu_bmu, minus_fnu_bmu]:
cb_field @= cb_field * fm

for icb_field in [minus_fmu_bnu_bmu, minus_fmu_fnu_bmu, plus_fmu_fnu_bmu, one_bmu]:
icb_field @= icb_field * (one - fm)

dJdXe_nMpInv_bnu = g.cshift(dJdXe_nMpInv, nu, -1)
dJdXe_nMpInv_fnu = g.cshift(dJdXe_nMpInv, nu, 1)
dJdXe_nMpInv_fnu_bmu = g.cshift(dJdXe_nMpInv_bmu, nu, 1)
MpInvJx_fnu = g.cshift(MpInvJx, nu, 1)
MpInvJx_fnu_bmu = g.cshift(MpInvJx_bmu, nu, 1)
MpInvJx_bnu = g.cshift(MpInvJx, nu, -1)

# + nu cw
PlaqL = g.identity(U[0])
PlaqR = g((-rho) * csf(U[nu], nu, csf(U[mu], mu, csb(U[nu], nu, csb(U_mu_masked, mu)))))
PlaqR = minus_fnu_fmu_bnu_bmu

dJdXe_nMpInv_y = dJdXe_nMpInv
t("compute_adj_ab")
compute_adj_ab(PlaqL, PlaqR, Nxy, generators, cache_ab)
t("non-local")
Fdet1_nu = g(g.transpose(Nxy) * dJdXe_nMpInv_y)
Fdet1_nu = g(g.transpose(Nxy) * dJdXe_nMpInv)

PlaqR = g((-1.0) * PlaqR)
t("compute_adj_abc")

compute_adj_abc(PlaqL, PlaqR, MpInvJx, FdetV, generators, cache, cb)
t("non-local")
Fdet2_nu = g.copy(FdetV)

# + nw acw
PlaqR = g(rho * csf(U[nu], nu, csb(U[mu], mu, csb(U[nu], nu))))
PlaqL = g(csb(U_mu_masked, mu))
PlaqR = plus_fnu_bmu_bnu
PlaqL = one_bmu

dJdXe_nMpInv_y = g.cshift(dJdXe_nMpInv, mu, -1)
t("compute_adj_ab")
compute_adj_ab(PlaqL, PlaqR, Nxy, generators, cache_ab)
t("non-local")
Fdet1_nu += g.transpose(Nxy) * dJdXe_nMpInv_y
Fdet1_nu += g.transpose(Nxy) * dJdXe_nMpInv_bmu

MpInvJx_nu = g.cshift(MpInvJx, mu, -1)
t("compute_adj_abc")
compute_adj_abc(PlaqL, PlaqR, MpInvJx_nu, FdetV, generators, cache, cb.inv())
t("non-local")
compute_adj_abc(PlaqL, PlaqR, MpInvJx_bmu, FdetV, generators, cache, cb.inv())
Fdet2_nu += FdetV

# - nu cw
PlaqL = g(rho * csf(U[mu], mu, csf(U[nu], nu, csb(U_mu_masked, mu))))
PlaqR = g(csf(U[nu], nu))
PlaqL = plus_fmu_fnu_bmu
PlaqR = one_fnu

dJdXe_nMpInv_y = g.cshift(dJdXe_nMpInv, nu, 1)
t("compute_adj_ab")
compute_adj_ab(PlaqL, PlaqR, Nxy, generators, cache_ab)
t("non-local")
Fdet1_nu += g.transpose(Nxy) * dJdXe_nMpInv_y
Fdet1_nu += g.transpose(Nxy) * dJdXe_nMpInv_fnu

MpInvJx_nu = g.cshift(MpInvJx, nu, 1)
t("compute_adj_abc")
compute_adj_abc(PlaqL, PlaqR, MpInvJx_nu, FdetV, generators, cache, cb.inv())
t("non-local")
compute_adj_abc(PlaqL, PlaqR, MpInvJx_fnu, FdetV, generators, cache, cb.inv())
Fdet2_nu += FdetV

# -nu acw
PlaqL = g((-rho) * csf(U[nu], nu, csb(U_mu_masked, mu)))
PlaqR = csb(U[mu], mu, csf(U[nu], nu))

dJdXe_nMpInv_y = g.cshift(dJdXe_nMpInv, mu, -1)
dJdXe_nMpInv_y = g.cshift(dJdXe_nMpInv_y, nu, 1)
PlaqL = minus_fnu_bmu
PlaqR = one_bmu_fnu

t("compute_adj_ab")
compute_adj_ab(PlaqL, PlaqR, Nxy, generators, cache_ab)
t("non-local")
Fdet1_nu += g.transpose(Nxy) * dJdXe_nMpInv_y
Fdet1_nu += g.transpose(Nxy) * dJdXe_nMpInv_fnu_bmu

MpInvJx_nu = g.cshift(MpInvJx, mu, -1)
MpInvJx_nu = g.cshift(MpInvJx_nu, nu, 1)
t("compute_adj_abc")
compute_adj_abc(PlaqL, PlaqR, MpInvJx_nu, FdetV, generators, cache, cb)
t("non-local")
compute_adj_abc(PlaqL, PlaqR, MpInvJx_fnu_bmu, FdetV, generators, cache, cb)
Fdet2_nu += FdetV

# force contributions to fundamental representation
t("adj_to_fund")
adjoint_to_fundamental(Fdet1[nu], Fdet1_nu, generators)
adjoint_to_fundamental(Fdet2[nu], Fdet2_nu, generators)
t("non-local")

# mu cw
PlaqL = g((-rho) * csf(U[mu], mu, csb(U[nu], nu, csb(U_mu_masked, mu))))
PlaqR = g(csb(U[nu], nu))
PlaqL = minus_fmu_bnu_bmu
PlaqR = one_bnu

dJdXe_nMpInv_y = g.cshift(dJdXe_nMpInv, nu, -1)
t("compute_adj_ab")
compute_adj_ab(PlaqL, PlaqR, Nxy, generators, cache_ab)
t("non-local")
Fdet1_mu += g.transpose(Nxy) * dJdXe_nMpInv_y
Fdet1_mu += g.transpose(Nxy) * dJdXe_nMpInv_bnu

MpInvJx_nu = g.cshift(MpInvJx, nu, -1)
t("compute_adj_abc")
compute_adj_abc(PlaqL, PlaqR, MpInvJx_nu, FdetV, generators, cache, cb.inv())
t("non-local")
compute_adj_abc(PlaqL, PlaqR, MpInvJx_bnu, FdetV, generators, cache, cb.inv())
Fdet2_mu += FdetV

# mu acw
PlaqL = g((-rho) * csf(U[mu], mu, csf(U[nu], nu, csb(U_mu_masked, mu))))
PlaqR = g(csf(U[nu], nu))

dJdXe_nMpInv_y = g.cshift(dJdXe_nMpInv, nu, 1)
PlaqL = minus_fmu_fnu_bmu
PlaqR = one_fnu

t("compute_adj_ab")
compute_adj_ab(PlaqL, PlaqR, Nxy, generators, cache_ab)
t("non-local")
Fdet1_mu += g.transpose(Nxy) * dJdXe_nMpInv_y

MpInvJx_nu = g.cshift(MpInvJx, nu, 1)
Fdet1_mu += g.transpose(Nxy) * dJdXe_nMpInv_fnu

t("compute_adj_abc")
compute_adj_abc(PlaqL, PlaqR, MpInvJx_nu, FdetV, generators, cache, cb.inv())
t("non-local")
compute_adj_abc(PlaqL, PlaqR, MpInvJx_fnu, FdetV, generators, cache, cb.inv())
Fdet2_mu += FdetV

t("aggregate")
Expand Down
2 changes: 1 addition & 1 deletion scripts/black
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ for f in ${FILES}
do

python3 -m black -t py36 --line-length 100 ${f}
python3 -m flake8 ${f}
python3 -m flake8 --ignore=E401,E226 ${f}
if [[ "$?" != "0" ]];
then
echo "Need to fix $f"
Expand Down

0 comments on commit 8a169ce

Please sign in to comment.