Skip to content

Commit

Permalink
update parallel_transport stencil
Browse files Browse the repository at this point in the history
  • Loading branch information
lehner committed Oct 7, 2023
1 parent 89ef2fe commit db14236
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 81 deletions.
40 changes: 23 additions & 17 deletions benchmarks/stencil_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,12 @@
import gpt as g

g.default.set_verbose("random", False)
rng = g.random(
"benchmark", "vectorized_ranlux24_24_64"
)
rng = g.random("benchmark", "vectorized_ranlux24_24_64")

for precision in [g.single, g.double]:
grid = g.grid(g.default.get_ivec("--grid", [16, 16, 16, 32], 4), precision)
N = g.default.get_int("--N", 1000)

g.message(
f"""
Local Stencil Benchmark with
Expand All @@ -22,18 +20,31 @@
)

U = g.qcd.gauge.random(grid, rng, scale=0.5)
_U = [1,2,3,4]
_U = [1, 2, 3, 4]
_X = 0
_Xp = [1,2,3,4]
_Xp = [1, 2, 3, 4]
V = g.mcolor(grid)
rng.element(V)
#U = g.qcd.gauge.transformed(U, V)
# U = g.qcd.gauge.transformed(U, V)
code = []
for mu in range(4):
for nu in range(0, mu):
code.append({"target": 0, "accumulate": -1 if (mu == 1 and nu == 0) else 0, "weight": 1.0, "factor":
[(_U[mu], _X, 0),(_U[nu], _Xp[mu], 0),(_U[mu], _Xp[nu], 1),(_U[nu], _X, 1)]})
st = g.local_stencil.matrix(U[0], [(0,0,0,0),(1,0,0,0),(0,1,0,0),(0,0,1,0),(0,0,0,1)], code)
code.append(
{
"target": 0,
"accumulate": -1 if (mu == 1 and nu == 0) else 0,
"weight": 1.0,
"factor": [
(_U[mu], _X, 0),
(_U[nu], _Xp[mu], 0),
(_U[mu], _Xp[nu], 1),
(_U[nu], _X, 1),
],
}
)
st = g.local_stencil.matrix(
U[0], [(0, 0, 0, 0), (1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)], code
)
# test plaquette (local stencil, so no comms, only agrees on single rank)
P = g.lattice(U[0])
st(P, *U)
Expand All @@ -43,15 +54,10 @@
# Flops
gauge_otype = U[0].otype
Nc = gauge_otype.shape[0]
flops_per_matrix_multiply = Nc**3 * 6 + (Nc-1)*Nc**2 * 2
flops_per_matrix_multiply = Nc**3 * 6 + (Nc - 1) * Nc**2 * 2
flops_per_site = 3 * flops_per_matrix_multiply * 4 * 3
flops = flops_per_site * P.grid.gsites * N
nbytes = (
(5 * Nc * Nc * 2)
* precision.nbytes
* P.grid.gsites
* N
)
nbytes = (5 * Nc * Nc * 2) * precision.nbytes * P.grid.gsites * N

# Warmup
for n in range(5):
Expand Down
42 changes: 6 additions & 36 deletions lib/gpt/core/parallel_transport/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,47 +102,17 @@ def __init__(self, U, code, n_target):

self.code.append((c[0], c[1], c[2], factors))

# create halo margin
margin = [0] * Nd
for p in points.points:
for i in range(Nd):
x = abs(p[i])
if x > margin[i]:
margin[i] = x

self.margin = margin
self.ncode = len(self.code)

# create local stencil and padding; TODO: wrap this in stencil class
self.padding_U = g.padded_local_fields(U, margin)
self.padding_T = g.padded_local_fields([g.lattice(U[0]) for i in range(Ntarget)], margin)
padded_U = self.padding_U(U)
self.local_stencil = g.local_stencil.matrix(padded_U[0], points.points, self.code)
write_fields = list(range(Ntarget))
read_fields = list(range(Ntarget + Ntemporary, Ntarget + Ntemporary + Nd))

self.stencil = g.stencil.matrix(U[0], points.points, write_fields, read_fields, self.code)

def __call__(self, U):
t = g.timer(
f"parallel_transport_matrix(margin={self.margin}, ncode={self.ncode}, ntarget={self.Ntarget}, ntemp={self.Ntemporary})"
)

# halo exchange
t("halo exchange")
padded_U = self.padding_U(U)
padded_Temp = [g.lattice(padded_U[0]) for i in range(self.Ntemporary)]
padded_T = [g.lattice(padded_U[0]) for i in range(self.Ntarget)]

# stencil computation
t("local stencil computation")
self.local_stencil(*padded_T, *padded_Temp, *padded_U)

# get bulk
t("extract bulk")
T = [g.lattice(U[0]) for i in range(self.Ntarget)]
self.padding_T.extract(T, padded_T)

t()

if self.verbose:
g.message(t)
Temp = [g.lattice(U[0]) for i in range(self.Ntemporary)]
self.stencil(*T, *Temp, *U)

if self.Ntarget == 1:
return T[0]
Expand Down
15 changes: 12 additions & 3 deletions lib/gpt/core/stencil/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,18 @@
# - SIMD in multi-rhs ? maybe add --simd_mask flag to command line ?
# - should do margins automatically seems best
class matrix:
def __init__(self, lat, points, margin_top, write_fields, read_fields, code, code_parallel_block_size=None):
self.padding = g.padded_local_fields(lat, margin_top)
self.local_stencil = g.local_stencil.matrix(self.padding(lat), points, code, code_parallel_block_size)
def __init__(self, lat, points, write_fields, read_fields, code, code_parallel_block_size=None):
margin = [0] * lat.grid.nd
for p in points:
for i in range(lat.grid.nd):
x = abs(p[i])
if x > margin[i]:
margin[i] = x

self.padding = g.padded_local_fields(lat, margin)
self.local_stencil = g.local_stencil.matrix(
self.padding(lat), points, code, code_parallel_block_size
)
self.write_fields = write_fields
self.read_fields = read_fields

Expand Down
32 changes: 8 additions & 24 deletions lib/gpt/qcd/fermion/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,40 +11,24 @@ def register(reg, op):
reg.Mdiag = lambda dst, src: op.apply_unary_operator(2009, dst, src)
reg.Dminus = lambda dst, src: op.apply_unary_operator(2010, dst, src)
reg.DminusDag = lambda dst, src: op.apply_unary_operator(2011, dst, src)
reg.ImportPhysicalFermionSource = lambda dst, src: op.apply_unary_operator(
2012, dst, src
)
reg.ImportUnphysicalFermion = lambda dst, src: op.apply_unary_operator(
2013, dst, src
)
reg.ExportPhysicalFermionSolution = lambda dst, src: op.apply_unary_operator(
2014, dst, src
)
reg.ExportPhysicalFermionSource = lambda dst, src: op.apply_unary_operator(
2015, dst, src
)
reg.ImportPhysicalFermionSource = lambda dst, src: op.apply_unary_operator(2012, dst, src)
reg.ImportUnphysicalFermion = lambda dst, src: op.apply_unary_operator(2013, dst, src)
reg.ExportPhysicalFermionSolution = lambda dst, src: op.apply_unary_operator(2014, dst, src)
reg.ExportPhysicalFermionSource = lambda dst, src: op.apply_unary_operator(2015, dst, src)
reg.Dhop = lambda dst, src: op.apply_unary_operator(3001, dst, src)
reg.DhopDag = lambda dst, src: op.apply_unary_operator(4001, dst, src)
reg.DhopEO = lambda dst, src: op.apply_unary_operator(3002, dst, src)
reg.DhopEODag = lambda dst, src: op.apply_unary_operator(4002, dst, src)
reg.Mdir = lambda dst, src, dir, disp: op.apply_dirdisp_operator(
5001, dst, src, dir, disp
)
reg.Mdir = lambda dst, src, dir, disp: op.apply_dirdisp_operator(5001, dst, src, dir, disp)
reg.MDeriv = lambda mat, dst, src: op.apply_deriv_operator(6001, mat, dst, src)
reg.MDerivDag = lambda mat, dst, src: op.apply_deriv_operator(7001, mat, dst, src)
reg.MoeDeriv = lambda mat, dst, src: op.apply_deriv_operator(6002, mat, dst, src)
reg.MoeDerivDag = lambda mat, dst, src: op.apply_deriv_operator(7002, mat, dst, src)
reg.MeoDeriv = lambda mat, dst, src: op.apply_deriv_operator(6003, mat, dst, src)
reg.MeoDerivDag = lambda mat, dst, src: op.apply_deriv_operator(7003, mat, dst, src)
reg.DhopDeriv = lambda mat, dst, src: op.apply_deriv_operator(6004, mat, dst, src)
reg.DhopDerivDag = lambda mat, dst, src: op.apply_deriv_operator(
7004, mat, dst, src
)
reg.DhopDerivDag = lambda mat, dst, src: op.apply_deriv_operator(7004, mat, dst, src)
reg.DhopDerivEO = lambda mat, dst, src: op.apply_deriv_operator(6005, mat, dst, src)
reg.DhopDerivEODag = lambda mat, dst, src: op.apply_deriv_operator(
7005, mat, dst, src
)
reg.DhopDerivEODag = lambda mat, dst, src: op.apply_deriv_operator(7005, mat, dst, src)
reg.DhopDerivOE = lambda mat, dst, src: op.apply_deriv_operator(6006, mat, dst, src)
reg.DhopDerivOEDag = lambda mat, dst, src: op.apply_deriv_operator(
7006, mat, dst, src
)
reg.DhopDerivOEDag = lambda mat, dst, src: op.apply_deriv_operator(7006, mat, dst, src)
4 changes: 3 additions & 1 deletion tests/core/stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
def stencil_cshift(src, direction):
stencil = g.stencil.matrix(
src,
[direction], [1,2,1,1], [0], [1],
[direction],
[0],
[1],
[{"target": 0, "accumulate": -1, "weight": 1.0, "factor": [(1, 0, 0)]}],
)

Expand Down

0 comments on commit db14236

Please sign in to comment.