Skip to content

Commit

Permalink
speed
Browse files Browse the repository at this point in the history
  • Loading branch information
lehner committed Jun 3, 2024
1 parent 9d8aad0 commit 1acdf33
Showing 1 changed file with 101 additions and 75 deletions.
176 changes: 101 additions & 75 deletions lib/gpt/qcd/gauge/smear/local_stout.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@
from gpt.core.group import local_diffeomorphism, differentiable_functional


def compute_adj_ab(A, B, C, generators):
# factors checked
ng = len(generators)
tmp = {}
for b in range(ng):
N_b = 2 * g.qcd.gauge.project.traceless_anti_hermitian(g.adj(A) * 1j * generators[b] * B)
for c in range(ng):
tmp[c, b] = g(-g.trace(1j * generators[c] * N_b))
g.merge_color(C, tmp)
# def compute_adj_ab_slow(A, B, C, generators, cache):
# # factors checked
# ng = len(generators)
# tmp = {}
# for b in range(ng):
# N_b = 2 * g.qcd.gauge.project.traceless_anti_hermitian(g.adj(A) * 1j * generators[b] * B)
# for c in range(ng):
# tmp[c, b] = g(-g.trace(1j * generators[c] * N_b))
# g.merge_color(C, tmp)


# def adjoint_from_right(D, UtaU, generators):
Expand All @@ -45,80 +45,104 @@ def compute_adj_ab(A, B, C, generators):
# g.merge_color(D, tmp)


def adjoint_from_right_fast(D, UtaU, generators, cache):
# is a factor of 10 faster than the above
def create_adjoint_projector(D, B, generators, nfactors):
ng = len(generators)
if "stencil" not in cache:
code = []
idst = 0
iUtaU = 1
igen = 2
itmp1 = -(igen + ng)
itmp2 = -(igen + ng + 1)
ndim = UtaU.otype.shape[0]
ti = g.stencil.tensor_instructions
for c in range(ng):
# itmp1 = 2j * generators[c] * UtaU
code = []
idst = 0
if nfactors == 1:
iB = 1
else:
iA = 1
iB = 2
igen = iB + 1
itmp1 = -(igen + ng)
itmp2 = -(igen + ng + 1)
ndim = B.otype.shape[0]
ti = g.stencil.tensor_instructions
for c in range(ng):
# itmp1 = 2j * generators[c] * B
imm = itmp1 if nfactors == 1 else itmp2
for ia in range(ndim):
for ib in range(ndim):
dst = ia * ndim + ib
for ic in range(ndim):
aa = ia * ndim + ic
bb = ic * ndim + ib
mode = ti.mov if ic == 0 else ti.inc
code.append((imm, dst, mode, 1.0, [(igen + c, 0, aa), (iB, 0, bb)]))
code.append((imm, dst, ti.mul, 2j, [(imm, 0, dst)]))
if nfactors == 2:
for ia in range(ndim):
for ib in range(ndim):
dst = ia * ndim + ib
for ic in range(ndim):
aa = ia * ndim + ic
bb = ic * ndim + ib
mode = ti.mov if ic == 0 else ti.inc
code.append((itmp1, dst, mode, 1.0, [(igen + c, 0, aa), (iUtaU, 0, bb)]))
code.append((itmp1, dst, ti.mul, 2j, [(itmp1, 0, dst)]))
# itmp2 = 0.5 * itmp1 - 0.5 * adj(itmp1)
code.append((itmp1, dst, mode, 1.0, [(iA, 0, aa), (itmp2, 0, bb)]))
# itmp2 = 0.5 * itmp1 - 0.5 * adj(itmp1)
for ia in range(ndim):
for ib in range(ndim):
dst = ia * ndim + ib
dst_adj = ib * ndim + ia
code.append((itmp2, dst, ti.mov, 1.0, [(itmp1, 0, dst)]))
code.append((itmp2, dst, ti.dec_cc, 1.0, [(itmp1, 0, dst_adj)]))
code.append((itmp2, dst, ti.mul, 0.5, [(itmp2, 0, dst)]))
# itmp1[0,0] = g.trace(itmp2) / 3.0
for ia in range(ndim):
mode = ti.mov if ia == 0 else ti.inc
src = ia * ndim + ia
code.append((itmp1, 0, mode, 1.0, [(itmp2, 0, src)]))
code.append((itmp1, 0, ti.mul, 1.0 / 3.0, [(itmp1, 0, 0)]))

# itmp2[i,i] -= itmp1[0,0]
for ia in range(ndim):
src = ia * ndim + ia
code.append((itmp2, src, ti.dec, 1.0, [(itmp1, 0, 0)]))

# itmp2 = g.qcd.gauge.project.traceless_anti_hermitian(2j * generators[c] * B) at this point

# now Dprime[d, c] = g(-g.trace(1j * generators[d] * itmp2))
for d in range(ng):
mode = ti.mov
dst = d * ng + c
for ia in range(ndim):
for ib in range(ndim):
dst = ia * ndim + ib
dst_adj = ib * ndim + ia
code.append((itmp2, dst, ti.mov, 1.0, [(itmp1, 0, dst)]))
code.append((itmp2, dst, ti.dec_cc, 1.0, [(itmp1, 0, dst_adj)]))
code.append((itmp2, dst, ti.mul, 0.5, [(itmp2, 0, dst)]))
# itmp1[0,0] = g.trace(itmp2) / 3.0
for ia in range(ndim):
mode = ti.mov if ia == 0 else ti.inc
src = ia * ndim + ia
code.append((itmp1, 0, mode, 1.0, [(itmp2, 0, src)]))
code.append((itmp1, 0, ti.mul, 1.0 / 3.0, [(itmp1, 0, 0)]))
aa = ia * ndim + ib
bb = ib * ndim + ia
code.append((idst, dst, mode, 1.0, [(igen + d, 0, aa), (itmp2, 0, bb)]))
mode = ti.inc
code.append((idst, dst, ti.mul, -1j, [(idst, 0, dst)]))

# itmp2[i,i] -= itmp1[0,0]
for ia in range(ndim):
src = ia * ndim + ia
code.append((itmp2, src, ti.dec, 1.0, [(itmp1, 0, 0)]))

# itmp2 = g.qcd.gauge.project.traceless_anti_hermitian(2j * generators[c] * UtaU) at this point

# now Dprime[d, c] = g(-g.trace(1j * generators[d] * itmp2))
for d in range(ng):
mode = ti.mov
dst = d * ng + c
for ia in range(ndim):
for ib in range(ndim):
aa = ia * ndim + ib
bb = ib * ndim + ia
code.append((idst, dst, mode, 1.0, [(igen + d, 0, aa), (itmp2, 0, bb)]))
mode = ti.inc
code.append((idst, dst, ti.mul, -1j, [(idst, 0, dst)]))

segments = [(len(code) // 1, 1)]
ein = g.stencil.tensor(D, [(0, 0, 0, 0)], code, segments)

fgenerators = [g.lattice(UtaU) for d in range(ng + 2)]
for d in range(ng):
fgenerators[d][:] = generators[d]
segments = [(len(code) // 1, 1)]
ein = g.stencil.tensor(D, [(0, 0, 0, 0)], code, segments)

cache["stencil"] = ein, fgenerators
fgenerators = [g.lattice(B) for d in range(ng + 2)]
for d in range(ng):
fgenerators[d][:] = generators[d]

else:
ein, fgenerators = cache["stencil"]
return ein, fgenerators


def adjoint_from_right_fast(D, UtaU, generators, cache):
if "stencil" not in cache:
cache["stencil"] = create_adjoint_projector(D, UtaU, generators, 1)

ein, fgenerators = cache["stencil"]

ein(D, UtaU, *fgenerators)


def compute_adj_ab(A, B, C, generators, cache):
if "stencil_ab" not in cache:
cache["stencil_ab"] = create_adjoint_projector(C, A, generators, 2)

ein, fgenerators = cache["stencil_ab"]

ein(C, g(g.adj(A)), g(B), *fgenerators)


def compute_adj_abc(A, B, C, V, generators, cache):
# this is the bottle neck computationally right now ; TODO: make faster
t = g.timer("compute_adj_abc")
t("other")
ng = len(generators)
Expand Down Expand Up @@ -295,7 +319,7 @@ def jacobian(self, fields, fields_prime, src):

return dst

def jacobian_components(self, fields):
def jacobian_components(self, fields, cache_ab):
C_mu, U, fm = self.get_C(fields)
mu = self.params["dimension"]

Expand All @@ -319,7 +343,7 @@ def jacobian_components(self, fields):
adj_id = g.identity(g.lattice(grid, adjoint_otype))
fund_id = g.identity(g.lattice(grid, otype))

compute_adj_ab(fund_id, M, N_cb, generators)
compute_adj_ab(fund_id, M, N_cb, generators, cache_ab)

Z = g(g.qcd.gauge.project.traceless_anti_hermitian(g.adj(M)))

Expand Down Expand Up @@ -351,7 +375,8 @@ def jacobian_components(self, fields):
return J_ac, N_cb, Z_ac, M, fm, M_ab

def log_det_jacobian(self, fields):
J_ac, N_cb, Z_ac, M, fm, M_ab = self.jacobian_components(fields)
cache_ab = {}
J_ac, N_cb, Z_ac, M, fm, M_ab = self.jacobian_components(fields, cache_ab)
det_M = g.matrix.det(M_ab)
log_det_M = g(g.component.real(g.component.log(det_M)))
log_det = g(fm * log_det_M)
Expand All @@ -373,7 +398,8 @@ def __call__(self, U):
def gradient(self, U, dU):
assert dU == U

J_ac, NxxAd, Z_ac, M, fm, M_ab = self.stout.jacobian_components(U)
cache_ab = {}
J_ac, NxxAd, Z_ac, M, fm, M_ab = self.stout.jacobian_components(U, cache_ab)

grid = J_ac.grid
dtype = grid.precision.complex_dtype
Expand Down Expand Up @@ -448,7 +474,7 @@ def gradient(self, U, dU):
PlaqR = g((-rho) * csf(U[nu], nu, csf(U[mu], mu, csb(U[nu], nu, csb(U_mu_masked, mu)))))

dJdXe_nMpInv_y = dJdXe_nMpInv
compute_adj_ab(PlaqL, PlaqR, Nxy, generators)
compute_adj_ab(PlaqL, PlaqR, Nxy, generators, cache_ab)
Fdet1_nu = g(g.transpose(Nxy) * dJdXe_nMpInv_y)

PlaqR = g((-1.0) * PlaqR)
Expand All @@ -462,7 +488,7 @@ def gradient(self, U, dU):
PlaqL = csb(U_mu_masked, mu)

dJdXe_nMpInv_y = g.cshift(dJdXe_nMpInv, mu, -1)
compute_adj_ab(PlaqL, PlaqR, Nxy, generators)
compute_adj_ab(PlaqL, PlaqR, Nxy, generators, cache_ab)
Fdet1_nu += g.transpose(Nxy) * dJdXe_nMpInv_y

MpInvJx_nu = g.cshift(MpInvJx, mu, -1)
Expand All @@ -476,7 +502,7 @@ def gradient(self, U, dU):
PlaqR = csf(U[nu], nu)

dJdXe_nMpInv_y = g.cshift(dJdXe_nMpInv, nu, 1)
compute_adj_ab(PlaqL, PlaqR, Nxy, generators)
compute_adj_ab(PlaqL, PlaqR, Nxy, generators, cache_ab)
Fdet1_nu += g.transpose(Nxy) * dJdXe_nMpInv_y

MpInvJx_nu = g.cshift(MpInvJx, nu, 1)
Expand All @@ -492,7 +518,7 @@ def gradient(self, U, dU):
dJdXe_nMpInv_y = g.cshift(dJdXe_nMpInv, mu, -1)
dJdXe_nMpInv_y = g.cshift(dJdXe_nMpInv_y, nu, 1)

compute_adj_ab(PlaqL, PlaqR, Nxy, generators)
compute_adj_ab(PlaqL, PlaqR, Nxy, generators, cache_ab)
Fdet1_nu += g.transpose(Nxy) * dJdXe_nMpInv_y

MpInvJx_nu = g.cshift(MpInvJx, mu, -1)
Expand All @@ -511,7 +537,7 @@ def gradient(self, U, dU):
PlaqR = csb(U[nu], nu)

dJdXe_nMpInv_y = g.cshift(dJdXe_nMpInv, nu, -1)
compute_adj_ab(PlaqL, PlaqR, Nxy, generators)
compute_adj_ab(PlaqL, PlaqR, Nxy, generators, cache_ab)
Fdet1_mu += g.transpose(Nxy) * dJdXe_nMpInv_y

MpInvJx_nu = g.cshift(MpInvJx, nu, -1)
Expand All @@ -526,7 +552,7 @@ def gradient(self, U, dU):

dJdXe_nMpInv_y = g.cshift(dJdXe_nMpInv, nu, 1)

compute_adj_ab(PlaqL, PlaqR, Nxy, generators)
compute_adj_ab(PlaqL, PlaqR, Nxy, generators, cache_ab)
Fdet1_mu += g.transpose(Nxy) * dJdXe_nMpInv_y

MpInvJx_nu = g.cshift(MpInvJx, nu, 1)
Expand Down

0 comments on commit 1acdf33

Please sign in to comment.