Skip to content

Commit

Permalink
fourier mass term
Browse files Browse the repository at this point in the history
  • Loading branch information
lehner committed May 8, 2024
1 parent 3d931e0 commit 426854f
Show file tree
Hide file tree
Showing 9 changed files with 107 additions and 86 deletions.
2 changes: 1 addition & 1 deletion benchmarks/matrix_multiply.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

# matrix multiply
nbytes = 3.0 * one.global_bytes() * N
n = (one.otype.nfloats // 2)**0.5
n = (one.otype.nfloats // 2) ** 0.5
flops_per_matrix_multiply = n * n * (n * 6 + (n - 1) * 2)
flops = flops_per_matrix_multiply = grid.gsites * N * flops_per_matrix_multiply

Expand Down
104 changes: 56 additions & 48 deletions benchmarks/stencil_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,114 +7,124 @@
m2 = g.mcolor(grid)
m3 = g.mcolor(grid)
rng = g.random("D")
rng.cnormal([m1,m2,m3])
m3ref = g(m1*m2)
rng.cnormal([m1, m2, m3])
m3ref = g(m1 * m2)
code = []
ti = g.stencil.tensor_instructions

for i in range(3):
for j in range(3):
for l in range(3):
dst = 3*i + j
dst = 3 * i + j
code.append(
(0,dst,ti.mov if l == 0 else ti.inc,1.0,[(1,0,3*i + l),(2,0,3*l + j)])
(0, dst, ti.mov if l == 0 else ti.inc, 1.0, [(1, 0, 3 * i + l), (2, 0, 3 * l + j)])
)

segments = [(3, 9)]
ein = g.stencil.tensor(m1, [(0, 0, 0, 0), (1, 0, 0, 0)], code, segments)

ein(m3,m1,m2)
ein(m3, m1, m2)
g.message("m3 = m1 * m2")
g.message(g.norm2(m3 - m3ref))

for osites_per_instruction in [4,16,32,128,256]: #[1,8,16,32,64]:
for osites_per_cache_block in [4096, 2**15, grid.gsites]: #[2**11, 2**13, 2**15, grid.gsites]:
for osites_per_instruction in [4, 16, 32, 128, 256]: # [1,8,16,32,64]:
for osites_per_cache_block in [
4096,
2**15,
grid.gsites,
]: # [2**11, 2**13, 2**15, grid.gsites]:
ein.memory_access_pattern(osites_per_instruction, osites_per_cache_block)

g.message(osites_per_instruction, osites_per_cache_block)
t=g.timer("d")
t = g.timer("d")
t("expr")
for i in range(300):
g.eval(m3,m1*m2)
g.eval(m3, m1 * m2)
t("stencil")
for i in range(300):
ein(m3,m1,m2)
ein(m3, m1, m2)
t()
g.message(t)
eps2 = g.norm2(m3 - m3ref) / g.norm2(m3)
assert eps2 < 1e-25
g.message(eps2)



# next
m4ref = g(m1*m2*m2)
m4ref = g(m1 * m2 * m2)
code = []
if True:
for i in range(3):
for j in range(3):
for l in range(3):
dst = 3*i + j
dst = 3 * i + j
code.append(
(-1,dst,ti.mov if l == 0 else ti.inc,1.0,[(3,0,3*i + l),(3,0,3*l + j)])
(
-1,
dst,
ti.mov if l == 0 else ti.inc,
1.0,
[(3, 0, 3 * i + l), (3, 0, 3 * l + j)],
)
)
for i in range(3):
for j in range(3):
for l in range(3):
dst = 3*i + j
dst = 3 * i + j
code.append(
(0,dst,ti.mov if l == 0 else ti.inc,1.0,[(2,0,3*i + l),(-1,0,3*l + j)])
(
0,
dst,
ti.mov if l == 0 else ti.inc,
1.0,
[(2, 0, 3 * i + l), (-1, 0, 3 * l + j)],
)
)
segments = [(3, 9), (3, 9)]
#segments = [(27*2, 1)]
# segments = [(27*2, 1)]
else:
for i in range(3):
for j in range(3):
dst = 3*i + j
code.append(
(-1,dst,ti.add,1.0,[(3,0,dst),(4,0,dst)])
)
code.append(
(0,dst,ti.add,1.0,[(2,0,dst),(-1,0,dst)])
)
m4ref = g(m1+m2+m3)
dst = 3 * i + j
code.append((-1, dst, ti.add, 1.0, [(3, 0, dst), (4, 0, dst)]))
code.append((0, dst, ti.add, 1.0, [(2, 0, dst), (-1, 0, dst)]))
m4ref = g(m1 + m2 + m3)
segments = [(9, 2)]


ein = g.stencil.tensor(m1, [(0, 0, 0, 0), (1, 0, 0, 0)], code, segments)

tmp = g.lattice(m1)
m4 = g.lattice(m1)
ein(m4,tmp,m1,m2,m3)
ein(m4, tmp, m1, m2, m3)
g.message("m4 = m1 * m2 * m3")
g.message(g.norm2(m4 - m4ref))

for osites_per_instruction in [16,32,64,256]:
for osites_per_cache_block in [16*16*16, 16*16*16*32, grid.gsites]:
for osites_per_instruction in [16, 32, 64, 256]:
for osites_per_cache_block in [16 * 16 * 16, 16 * 16 * 16 * 32, grid.gsites]:
ein.memory_access_pattern(osites_per_instruction, osites_per_cache_block)

g.message(osites_per_instruction, osites_per_cache_block)
t=g.timer("d")
t = g.timer("d")
t("expr")
for i in range(300):
g.eval(m4,m1*m2*m2)
g.eval(m4, m1 * m2 * m2)
t("stencil")
for i in range(300):
ein(m4,tmp,m1,m2,m3)
ein(m4, tmp, m1, m2, m3)
t()
g.message(t)
eps2 = g.norm2(m4 - m4ref) / g.norm2(m4)
assert eps2 < 1e-25
g.message(eps2)



g.message("Diquark")

# D_{a2,a1} = epsilon_{a1,b1,c1}*epsilon_{a2,b2,c2}*spin_transpose(Q1_{b1,b2})*Q2_{c1,c2}
Q1 = g.mspincolor(grid)
Q2 = g.mspincolor(grid)
rng.cnormal([Q1,Q2])
rng.cnormal([Q1, Q2])
eps = g.epsilon(Q1.otype.shape[2])
code = []
acc = {}
Expand All @@ -123,20 +133,18 @@
for l in range(4):
for i1, sign1 in eps:
for i2, sign2 in eps:
dst = (i*4 + j)*9 + i2[0]*3 + i1[0]
aa = (4*i + l)*9 + i1[1]*3 + i2[1]
bb = (4*j + l)*9 + i1[2]*3 + i2[2]
dst = (i * 4 + j) * 9 + i2[0] * 3 + i1[0]
aa = (4 * i + l) * 9 + i1[1] * 3 + i2[1]
bb = (4 * j + l) * 9 + i1[2] * 3 + i2[2]
if dst not in acc:
acc[dst] = True
mode = ti.mov if sign1 * sign2 > 0 else ti.mov_neg
else:
mode = ti.inc if sign1 * sign2 > 0 else ti.dec
assert dst >= 0 and dst < 12*12
assert aa >= 0 and aa < 12*12
assert bb >= 0 and bb < 12*12
code.append(
(0,dst,mode,1.0,[(1,0,aa),(2,0,bb)])
)
assert dst >= 0 and dst < 12 * 12
assert aa >= 0 and aa < 12 * 12
assert bb >= 0 and bb < 12 * 12
code.append((0, dst, mode, 1.0, [(1, 0, aa), (2, 0, bb)]))

g.message(len(code))
segments = [(len(code) // 16, 16)]
Expand All @@ -146,23 +154,23 @@
R[:] = 0
ein(R, Q1, Q2)

R2 = g.qcd.baryon.diquark(Q1,Q2)
R2 = g.qcd.baryon.diquark(Q1, Q2)

g.message(g.norm2(R - R2) / g.norm2(R))
#
# D[i2[0], i1[0]] += sign1 * sign2 * Q1[i1[1], i2[1]] * g.transpose(Q2[i1[2], i2[2]])
for osites_per_instruction in [1,2,4,16,32,64,256]:
for osites_per_cache_block in [ grid.gsites]:
for osites_per_instruction in [1, 2, 4, 16, 32, 64, 256]:
for osites_per_cache_block in [grid.gsites]:
ein.memory_access_pattern(osites_per_instruction, osites_per_cache_block)

g.message(osites_per_instruction, osites_per_cache_block)
t=g.timer("d")
t = g.timer("d")
t("diquark")
for i in range(30):
g.qcd.baryon.diquark(Q1,Q2)
g.qcd.baryon.diquark(Q1, Q2)
t("stencil")
for i in range(30):
ein(R, Q1, Q2)
t()
g.message(t)
g.message(g.norm2(R - R2) / g.norm2(R))
g.message(g.norm2(R - R2) / g.norm2(R))
1 change: 1 addition & 0 deletions lib/gpt/algorithms/optimize/non_linear_cg.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from gpt.algorithms import base_iterative
from gpt.algorithms.optimize import line_search_quadratic


def fletcher_reeves(d, d_last):
ip_dd = g.group.inner_product(d, d)
ip_ll = g.group.inner_product(d_last, d_last)
Expand Down
2 changes: 1 addition & 1 deletion lib/gpt/core/io/gpt_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(self, root, write, params):

# escape paths
if self.params["paths"] is not None:
replace = str.maketrans({ "[" : "[[]", "]" : "[]]" })
replace = str.maketrans({"[": "[[]", "]": "[]]"})
self.params["paths"] = [p.translate(replace) for p in self.params["paths"]]

if gpt.rank() == 0:
Expand Down
6 changes: 3 additions & 3 deletions lib/gpt/core/local_stencil/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ def __init__(self, lat, points, code, segments, local=1):
self.osites_per_cache_block = lat.grid.gsites

def __call__(self, *fields):
cgpt.stencil_tensor_execute(self.obj, list(fields),
self.osites_per_instruction,
self.osites_per_cache_block)
cgpt.stencil_tensor_execute(
self.obj, list(fields), self.osites_per_instruction, self.osites_per_cache_block
)

def __del__(self):
cgpt.stencil_tensor_delete(self.obj)
Expand Down
4 changes: 3 additions & 1 deletion lib/gpt/ml/layer/nearest_neighbor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@


class nearest_neighbor(cshift):
def __init__(self, grid, ot_input=g.ot_singlet(), ot_weights=g.ot_singlet(), activation=sigmoid):
def __init__(
self, grid, ot_input=g.ot_singlet(), ot_weights=g.ot_singlet(), activation=sigmoid
):
nd = grid.nd
super().__init__(
grid,
Expand Down
32 changes: 8 additions & 24 deletions lib/gpt/qcd/fermion/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,40 +11,24 @@ def register(reg, op):
reg.Mdiag = lambda dst, src: op.apply_unary_operator(2009, dst, src)
reg.Dminus = lambda dst, src: op.apply_unary_operator(2010, dst, src)
reg.DminusDag = lambda dst, src: op.apply_unary_operator(2011, dst, src)
reg.ImportPhysicalFermionSource = lambda dst, src: op.apply_unary_operator(
2012, dst, src
)
reg.ImportUnphysicalFermion = lambda dst, src: op.apply_unary_operator(
2013, dst, src
)
reg.ExportPhysicalFermionSolution = lambda dst, src: op.apply_unary_operator(
2014, dst, src
)
reg.ExportPhysicalFermionSource = lambda dst, src: op.apply_unary_operator(
2015, dst, src
)
reg.ImportPhysicalFermionSource = lambda dst, src: op.apply_unary_operator(2012, dst, src)
reg.ImportUnphysicalFermion = lambda dst, src: op.apply_unary_operator(2013, dst, src)
reg.ExportPhysicalFermionSolution = lambda dst, src: op.apply_unary_operator(2014, dst, src)
reg.ExportPhysicalFermionSource = lambda dst, src: op.apply_unary_operator(2015, dst, src)
reg.Dhop = lambda dst, src: op.apply_unary_operator(3001, dst, src)
reg.DhopDag = lambda dst, src: op.apply_unary_operator(4001, dst, src)
reg.DhopEO = lambda dst, src: op.apply_unary_operator(3002, dst, src)
reg.DhopEODag = lambda dst, src: op.apply_unary_operator(4002, dst, src)
reg.Mdir = lambda dst, src, dir, disp: op.apply_dirdisp_operator(
5001, dst, src, dir, disp
)
reg.Mdir = lambda dst, src, dir, disp: op.apply_dirdisp_operator(5001, dst, src, dir, disp)
reg.MDeriv = lambda mat, dst, src: op.apply_deriv_operator(6001, mat, dst, src)
reg.MDerivDag = lambda mat, dst, src: op.apply_deriv_operator(7001, mat, dst, src)
reg.MoeDeriv = lambda mat, dst, src: op.apply_deriv_operator(6002, mat, dst, src)
reg.MoeDerivDag = lambda mat, dst, src: op.apply_deriv_operator(7002, mat, dst, src)
reg.MeoDeriv = lambda mat, dst, src: op.apply_deriv_operator(6003, mat, dst, src)
reg.MeoDerivDag = lambda mat, dst, src: op.apply_deriv_operator(7003, mat, dst, src)
reg.DhopDeriv = lambda mat, dst, src: op.apply_deriv_operator(6004, mat, dst, src)
reg.DhopDerivDag = lambda mat, dst, src: op.apply_deriv_operator(
7004, mat, dst, src
)
reg.DhopDerivDag = lambda mat, dst, src: op.apply_deriv_operator(7004, mat, dst, src)
reg.DhopDerivEO = lambda mat, dst, src: op.apply_deriv_operator(6005, mat, dst, src)
reg.DhopDerivEODag = lambda mat, dst, src: op.apply_deriv_operator(
7005, mat, dst, src
)
reg.DhopDerivEODag = lambda mat, dst, src: op.apply_deriv_operator(7005, mat, dst, src)
reg.DhopDerivOE = lambda mat, dst, src: op.apply_deriv_operator(6006, mat, dst, src)
reg.DhopDerivOEDag = lambda mat, dst, src: op.apply_deriv_operator(
7006, mat, dst, src
)
reg.DhopDerivOEDag = lambda mat, dst, src: op.apply_deriv_operator(7006, mat, dst, src)
15 changes: 7 additions & 8 deletions lib/gpt/qcd/scalar/action/mass_term.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __call__(self, pi):
return g.group.inner_product(pi, pi) * self.m * 0.5

def draw(self, pi, rng):
rng.normal_element(pi, scale = self.m**0.5)
rng.normal_element(pi, scale=self.m**0.5)
return self.__call__(pi)

@differentiable_functional.multi_field_gradient
Expand All @@ -46,11 +46,10 @@ def gradient(self, pi, dpi):
class fourier_mass_term(differentiable_functional):
def __init__(self, fourier_sqrt_mass_field):
self.n = len(fourier_sqrt_mass_field)

self.fourier_sqrt_mass_field = fourier_sqrt_mass_field
self.fourier_mass_field = [
[g.lattice(fourier_sqrt_mass_field[0][0]) for i in range(self.n)]
for j in range(self.n)
[g.lattice(fourier_sqrt_mass_field[0][0]) for i in range(self.n)] for j in range(self.n)
]
for i in range(self.n):
for j in range(self.n):
Expand Down Expand Up @@ -84,7 +83,7 @@ def draw(self, pi, rng):
pi[mu] @= g.inv(self.fft) * pi[mu]
pi[mu] @= g(0.5 * (pi[mu] + g.adj(pi[mu])))
return self.__call__(pi)

@differentiable_functional.multi_field_gradient
def gradient(self, pi, dpi):
dS = []
Expand All @@ -93,13 +92,13 @@ def gradient(self, pi, dpi):
mu = pi.index(_pi)
ret = g.lattice(pi[mu])
ret[:] = 0

for nu in range(self.n):
ret += self.fourier_mass_field[mu][nu] * fft_pi[nu]

ret @= g.inv(self.fft) * ret

ret = g(0.5 * (ret + g.adj(ret)))
dS.append(ret)

return dS
27 changes: 27 additions & 0 deletions tests/qcd/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,30 @@
for a in actions:
g.message(a.__name__)
a.assert_gradient_error(rng, phi, phi, 1e-5, 1e-7)

# fourier mass term
U_mom = [g.lattice(grid, g.ot_matrix_su_n_fundamental_algebra(3)) for i in range(4)]
rng.element(U_mom)
A0 = g.qcd.scalar.action.mass_term(0.612)
sqrt_mass = [[g.complex(grid) for i in range(4)] for j in range(4)]
for mu in range(4):
for nu in range(4):
sqrt_mass[mu][nu][:] = 0.612**0.5 if mu == nu else 0.0

A1 = g.qcd.scalar.action.fourier_mass_term(sqrt_mass)

eps = abs(A1(U_mom) / A0(U_mom) - 1.0)
g.message(f"Regress fourier mass term: {eps}")
assert eps < 1e-10

# now test with general Hermitian mass matrix
for mu in range(4):
for nu in range(4):
rng.cnormal(sqrt_mass[mu][nu])
tmp = [g.copy(x) for x in sqrt_mass]
for mu in range(4):
for nu in range(4):
sqrt_mass[mu][nu] @= g.adj(tmp[nu][mu]) + tmp[mu][nu]

A1 = g.qcd.scalar.action.fourier_mass_term(sqrt_mass)
A1.assert_gradient_error(rng, U_mom, U_mom, 1e-3, 1e-8)

0 comments on commit 426854f

Please sign in to comment.