first version of einsum
lehner committed Oct 26, 2023
1 parent 9df419d commit cabcfaa
Showing 5 changed files with 245 additions and 22 deletions.
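In brief, this commit adds g.einsum, which compiles an explicit einsum contraction string into a precomputed tensor stencil. A minimal usage sketch, following the tests added in this commit (the grid and random-field setup here is illustrative, not part of the commit):

import gpt as g

grid = g.grid([8, 8, 8, 16], g.double)  # assumed lattice size/precision
rng = g.random("example")
Q1 = g.mspincolor(grid)
Q2 = g.mspincolor(grid)
rng.cnormal([Q1, Q2])

# compile once: the tensor arguments after the contraction string fix the
# types of the sources ("ABab", "BCbc") and of the destination ("ACac")
mm = g.einsum("ABab,BCbc->ACac", Q1, Q1, Q1)
xx = mm(Q1, Q2)  # same result as g(Q1 * Q2)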
15 changes: 10 additions & 5 deletions lib/cgpt/lib/stencil/tensor.h
@@ -111,7 +111,7 @@ class cgpt_stencil_tensor : public cgpt_stencil_tensor_base {
code[i].element = _code[i].element;
code[i].instruction = _code[i].instruction;
code[i].weight = _code[i].weight;
- code[i].size = (int)_code[i].factor.size();
+ code[i].size = (uint32_t)_code[i].factor.size();
if (code[i].size == 0)
ERR("Cannot create empty factor");
code[i].factor = &factors[nfactors];
@@ -217,10 +217,15 @@ class cgpt_stencil_tensor : public cgpt_stencil_tensor_base {
}

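// Note: a two-factor instruction (_p->size == 2) combines e_a and e_b via "op";
// a single-factor instruction applies only the functor to e_a.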
#define KERNEL_BIN(signature, op, functor, NN) { \
-  auto bNN = _f1->stride; \
-  element_t* __restrict__ e_b = ((element_t*)_f1->base_ptr) + bNN * NN * MAP_INDEX(_f1,ss) + lane; \
-  for (int ff=0;ff<NN;ff++) \
-    e_c[cNN * ff] signature functor(e_a[aNN * ff]) op e_b[bNN * ff]; \
+  if (_p->size == 2) { \
+    auto bNN = _f1->stride; \
+    element_t* __restrict__ e_b = ((element_t*)_f1->base_ptr) + bNN * NN * MAP_INDEX(_f1,ss) + lane; \
+    for (int ff=0;ff<NN;ff++) \
+      e_c[cNN * ff] signature functor(e_a[aNN * ff]) op e_b[bNN * ff]; \
+  } else { \
+    for (int ff=0;ff<NN;ff++) \
+      e_c[cNN * ff] signature functor(e_a[aNN * ff]); \
+  } \
}


1 change: 1 addition & 0 deletions lib/gpt/core/__init__.py
@@ -107,3 +107,4 @@
import gpt.core.local_stencil
from gpt.core.padding import padded_local_fields
import gpt.core.stencil
from gpt.core.einsum import einsum
178 changes: 178 additions & 0 deletions lib/gpt/core/einsum.py
@@ -0,0 +1,178 @@
#
# GPT - Grid Python Toolkit
# Copyright (C) 2023 Christoph Lehner ([email protected], https://github.com/lehner/gpt)
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
import gpt as g


def einsum(contraction, *tensors):
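# "contraction" uses numpy.einsum's explicit notation, e.g. "ABab,BCbc->ACac";
# "tensors" lists the source tensors followed by the destination tensor(s);
# g.epsilon may be passed as a source to stand for a Levi-Civita tensor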
contraction = contraction.split("->")
if len(contraction) != 2:
raise Exception(f"{contraction} needs to be explicit, i.e., of the form ...->...")
source, destination = contraction
source = [[x for x in s] for s in source.split(",")]
destination = [[x for x in s] for s in destination.split(",")]
if len(tensors) != len(source) + len(destination):
raise Exception(f"Need {len(source)} source and {len(destination)} destination tensors")
tensors_source = tensors[0 : len(source)]
tensors_destination = tensors[len(source) :]

# now infer and verify index dimensions
index_dimension = {}
epsilon_indices = {}
epsilon_tensors = []
source_indices = {}
destination_indices = {}
for indices, tensors, all_indices in [
(source, tensors_source, source_indices),
(destination, tensors_destination, destination_indices),
]:
for i in range(len(indices)):
if tensors[i] is g.epsilon:
dim = len(indices[i])
epsilon_tensors.append(indices[i])
for s in indices[i]:
all_indices[s] = True
epsilon_indices[s] = True
if s in index_dimension:
if index_dimension[s] != dim:
raise Exception(f"Index {s} already defined to have dimension {dim}")
else:
index_dimension[s] = dim
else:
shape = tensors[i].otype.shape
if shape == (1,):
shape = tuple()
if len(shape) != len(indices[i]):
raise Exception(
f"Tensor {i} is expected to have {len(shape)} indices instead of {len(indices[i])}"
)
for j in range(len(shape)):
dim = shape[j]
s = indices[i][j]
all_indices[s] = True
if s in index_dimension:
if index_dimension[s] != dim:
raise Exception(f"Index {s} already defined to have dimension {dim}")
else:
index_dimension[s] = dim
# print(index_dimension)
# now go through all indices
indices = list(index_dimension.keys())
full_indices = [i for i in destination_indices if i not in epsilon_indices]
nsegment = 1
for i in full_indices:
nsegment *= index_dimension[i]
for i in source_indices:
if i not in epsilon_indices and i not in full_indices:
full_indices.append(i)
index_value = [0] * len(full_indices)

code = []

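# flatten the values of a tensor's indices into its row-major element offset,
# e.g. for indices "ab" with dimensions 3 and 3 the offset is value(a)*3 + value(b)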
def get_element(indices, names, values):
element = 0
for i in indices:
element = element * index_dimension[i] + values[names.index(i)]
return element

acc = {}
ti = g.stencil.tensor_instructions

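# emit one stencil instruction for a fixed assignment of all index values;
# the first write to a destination element uses mov/mov_neg, later writes
# inc/dec; sign carries the accumulated Levi-Civita signs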
def process(names, values, sign):
# for now limited to a single destination tensor and one or two source tensor lattices
assert len(destination) == 1
c = destination[0]
sidx = []
for i in range(len(source)):
if tensors_source[i] is not g.epsilon:
sidx.append(source[i])

# get destination index
c_element = get_element(c, names, values)

if len(sidx) == 2:
a_element = get_element(sidx[0], names, values)
b_element = get_element(sidx[1], names, values)

if c_element not in acc:
acc[c_element] = True
mode = ti.mov if sign > 0 else ti.mov_neg
else:
mode = ti.inc if sign > 0 else ti.dec
code.append((0, c_element, mode, 1.0, [(1, 0, a_element), (2, 0, b_element)]))

elif len(sidx) == 1:
a_element = get_element(sidx[0], names, values)
if c_element not in acc:
acc[c_element] = True
mode = ti.mov if sign > 0 else ti.mov_neg
else:
mode = ti.inc if sign > 0 else ti.dec
code.append((0, c_element, mode, 1.0, [(1, 0, a_element)]))

else:
raise Exception(
"General einsum case not yet implemented; limited to contraction of one or two tensors"
)

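# peel off one Levi-Civita tensor at a time: only its nonzero (permutation)
# entries contribute, each multiplying the running sign by +1 or -1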
def process_indices(names, values, epsilon_tensors, sign0):
if len(epsilon_tensors) == 0:
process(names, values, sign0)
else:
n = len(epsilon_tensors[0])
eps = g.epsilon(n)
for i, sign in eps:
keep = True
for j in range(n):
idx = epsilon_tensors[0][j]
if idx in names and values[names.index(idx)] != i[j]:
keep = False
break
if keep:
names_next = [n for n in names]
values_next = [v for v in values]
for j in range(n):
idx = epsilon_tensors[0][j]
if idx not in names:
names_next.append(idx)
values_next.append(i[j])
process_indices(names_next, values_next, epsilon_tensors[1:], sign * sign0)

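# odometer-style loop over all assignments of the free indices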
active = True
while active:
process_indices(full_indices, index_value, epsilon_tensors, 1)
for j in range(len(index_value)):
if index_value[j] + 1 < index_dimension[full_indices[j]]:
index_value[j] += 1
break
elif j == len(index_value) - 1:
active = False
else:
index_value[j] = 0

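# the stencil executes the instruction stream as nsegment segments of equal length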
assert len(code) % nsegment == 0
segments = [(len(code) // nsegment, nsegment)]

ein = g.stencil.tensor(tensors_destination[0], [(0, 0, 0, 0)], code, segments)

def exec(*src):
c = g.lattice(tensors_destination[0])
ein(c, *src)
return c

return exec
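The returned closure allocates the destination lattice and applies the compiled stencil. For example, mirroring the spin-trace test added below:

spin_trace = g.einsum("AAab->ab", Q1, g.mcolor(Q1.grid))
xx = spin_trace(Q1)  # same result as g(g.spin_trace(Q1))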
16 changes: 8 additions & 8 deletions lib/gpt/qcd/baryon.py
@@ -20,9 +20,11 @@


default_cache = {}


def diquark(Q1, Q2, cache=default_cache):
R = g.lattice(Q1)
- # D_{a2,a1} = epsilon_{a1,b1,c1}*epsilon_{a2,b2,c2}*spin_transpose(Q1_{b1,b2})*Q2_{c1,c2}
+ # D_{a2,a1} = epsilon_{a1,b1,c1}*epsilon_{a2,b2,c2}*Q1_{b1,b2}*spin_transpose(Q2_{c1,c2})
cache_key = f"{Q1.otype.__name__}_{Q1.checkerboard().__name__}_{Q1.grid.describe()}"
if cache_key not in cache:
Nc = Q1.otype.shape[2]
@@ -36,19 +38,17 @@ def diquark(Q1, Q2, cache=default_cache):
for l in range(Ns):
for i1, sign1 in eps:
for i2, sign2 in eps:
- dst = (i*Ns + j)*Nc*Nc + i2[0]*Nc + i1[0]
- aa = (Ns*i + l)*Nc*Nc + i1[1]*Nc + i2[1]
- bb = (Ns*j + l)*Nc*Nc + i1[2]*Nc + i2[2]
+ dst = (i * Ns + j) * Nc * Nc + i2[0] * Nc + i1[0]
+ aa = (Ns * i + l) * Nc * Nc + i1[1] * Nc + i2[1]
+ bb = (Ns * j + l) * Nc * Nc + i1[2] * Nc + i2[2]
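# dst/aa/bb flatten (spin, spin, color, color) as ((s*Ns + s')*Nc + c)*Nc + c'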
if dst not in acc:
acc[dst] = True
mode = ti.mov if sign1 * sign2 > 0 else ti.mov_neg
else:
mode = ti.inc if sign1 * sign2 > 0 else ti.dec
- code.append(
-     (0,dst,mode,1.0,[(1,0,aa),(2,0,bb)])
- )
+ code.append((0, dst, mode, 1.0, [(1, 0, aa), (2, 0, bb)]))

- segments = [(len(code) // (Ns*Ns), Ns*Ns)]
+ segments = [(len(code) // (Ns * Ns), Ns * Ns)]
cache[cache_key] = g.stencil.tensor(Q1, [(0, 0, 0, 0)], code, segments)

cache[cache_key](R, Q1, Q2)
57 changes: 48 additions & 9 deletions tests/core/stencil.py
@@ -222,12 +222,13 @@ def lap(dst, src):
g.message(f"Stencil covariant laplace versus cshift version: {eps2}")
assert eps2 < 1e-25


# tensor stencil test for the diquark case
def serial_diquark(Q1, Q2):
eps = g.epsilon(Q1.otype.shape[2])
R = g.lattice(Q1)

- # D_{a2,a1} = epsilon_{a1,b1,c1}*epsilon_{a2,b2,c2}*spin_transpose(Q1_{b1,b2})*Q2_{c1,c2}
+ # D_{a2,a1} = epsilon_{a1,b1,c1}*epsilon_{a2,b2,c2}*Q1_{b1,b2}*spin_transpose(Q2_{c1,c2})
Q1 = g.separate_color(Q1)
Q2 = g.separate_color(Q2)

@@ -242,6 +243,7 @@ def serial_diquark(Q1, Q2):
g.merge_color(R, D)
return R


def stencil_diquark(Q1, Q2):
Nc = Q1.otype.shape[2]
Ns = Q1.otype.shape[0]
@@ -255,26 +257,26 @@ def stencil_diquark(Q1, Q2):
for l in range(Ns):
for i1, sign1 in eps:
for i2, sign2 in eps:
- dst = (i*Ns + j)*Nc*Nc + i2[0]*Nc + i1[0]
- aa = (Ns*i + l)*Nc*Nc + i1[1]*Nc + i2[1]
- bb = (Ns*j + l)*Nc*Nc + i1[2]*Nc + i2[2]
+ dst = (i * Ns + j) * Nc * Nc + i2[0] * Nc + i1[0]
+ aa = (Ns * i + l) * Nc * Nc + i1[1] * Nc + i2[1]
+ bb = (Ns * j + l) * Nc * Nc + i1[2] * Nc + i2[2]
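# dst/aa/bb flatten (spin, spin, color, color) as ((s*Ns + s')*Nc + c)*Nc + c'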
if dst not in acc:
acc[dst] = True
mode = ti.mov if sign1 * sign2 > 0 else ti.mov_neg
else:
mode = ti.inc if sign1 * sign2 > 0 else ti.dec
- code.append(
-     (0,dst,mode,1.0,[(1,0,aa),(2,0,bb)])
- )
+ code.append((0, dst, mode, 1.0, [(1, 0, aa), (2, 0, bb)]))

- segments = [(len(code) // (Ns*Ns), Ns*Ns)]
+ segments = [(len(code) // (Ns * Ns), Ns * Ns)]
ein = g.stencil.tensor(Q1, [(0, 0, 0, 0)], code, segments)
ein(R, Q1, Q2)
return R


Q1 = g.mspincolor(grid)
Q2 = g.mspincolor(grid)
- rng.cnormal([Q1,Q2])
+ rng.cnormal([Q1, Q2])
st_di = stencil_diquark(Q1, Q2)
se_di = serial_diquark(Q1, Q2)
std_di = g.qcd.baryon.diquark(Q1, Q2)
@@ -285,3 +287,40 @@ def stencil_diquark(Q1, Q2):
eps2 = g.norm2(st_di - std_di) / g.norm2(std_di)
g.message(f"Diquark stencil test (stencil <> g.qcd.gauge.diquark): {eps2}")
assert eps2 < 1e-25

# and use this to test einsum
# D_{a2,a1} = epsilon_{a1,b1,c1}*epsilon_{a2,b2,c2}*Q1_{b1,b2}*spin_transpose(Q2_{c1,c2})
einsum_di = g.einsum("acd,bef,ACce,BCdf->ABba", g.epsilon, g.epsilon, Q1, Q2, Q1)
es_di = einsum_di(Q1, Q2)

eps2 = g.norm2(st_di - es_di) / g.norm2(st_di)
g.message(f"Diquark stencil test (stencil <> einsum): {eps2}")
assert eps2 < 1e-25

einsum_trace = g.einsum("AAaa->", Q1, g.complex(Q1.grid))
xx = einsum_trace(Q1)
yy = g(g.trace(Q1))
eps2 = g.norm2(xx - yy) / g.norm2(yy)
g.message(f"Einsum trace test: {eps2}")
assert eps2 < 1e-25

einsum_spintrace = g.einsum("AAab->ab", Q1, g.mcolor(Q1.grid))
xx = einsum_spintrace(Q1)
yy = g(g.spin_trace(Q1))
eps2 = g.norm2(xx - yy) / g.norm2(yy)
g.message(f"Einsum spintrace test: {eps2}")
assert eps2 < 1e-25

einsum_transpose = g.einsum("ABab->BAba", Q1, Q1)
xx = einsum_transpose(Q1)
yy = g(g.transpose(Q1))
eps2 = g.norm2(xx - yy) / g.norm2(yy)
g.message(f"Einsum transpose test: {eps2}")
assert eps2 < 1e-25

einsum_mm = g.einsum("ABab,BCbc->ACac", Q1, Q1, Q1)
xx = einsum_mm(Q1, Q2)
yy = g(Q1 * Q2)
eps2 = g.norm2(xx - yy) / g.norm2(yy)
g.message(f"Einsum mm test: {eps2}")
assert eps2 < 1e-25
