Commit

lumi-g tweaks
lehner committed Sep 21, 2024
1 parent 34ee5a7 commit d3563a1
Showing 6 changed files with 75 additions and 26 deletions.
4 changes: 2 additions & 2 deletions lib/cgpt/lib/distribute.h
@@ -245,7 +245,6 @@ class global_memory_transfer : public global_transfer<rank_t> {
} else if (view.type == mt_host) {
acceleratorFreeCpu(view.ptr);
}

}
};

@@ -263,7 +262,8 @@ class global_memory_transfer : public global_transfer<rank_t> {
}

// memory buffers
std::map<rank_t, memory_buffer> send_buffers, recv_buffers;
std::vector<memory_buffer> buffers;
std::map<rank_t, memory_view> send_buffers, recv_buffers;
memory_type comm_buffers_type;
std::map<rank_t, std::map< index_t, blocks_t > > send_blocks, recv_blocks;

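The declaration change above replaces the per-rank memory_buffer map with one backing allocation (buffers) plus lightweight per-rank memory_view entries into it. A rough Python analogue of that layout, with purely illustrative names (MemoryView and backing are not part of cgpt):

from dataclasses import dataclass

@dataclass
class MemoryView:
    buffer: bytearray  # single backing allocation shared by all views
    offset: int        # where this rank's slice starts, in bytes
    size: int          # bytes reserved for this rank

backing = bytearray(3 * 4096)                        # plays the role of buffers[0]
send_buffers = {1: MemoryView(backing, 0, 4096),     # one view per destination rank
                2: MemoryView(backing, 4096, 4096)}
recv_buffers = {1: MemoryView(backing, 8192, 4096)}  # receive views follow the send views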
43 changes: 29 additions & 14 deletions lib/cgpt/lib/distribute/global_memory_transfer.h
@@ -510,7 +510,7 @@ void global_memory_transfer<offset_t,rank_t,index_t>::create(const view_t& _dst,
if (!local_only) {

Timer("create_com_buffers");
// optionally create communication buffers
create_comm_buffers(use_comm_buffers_of_type);
}

@@ -612,17 +612,32 @@ void global_memory_transfer<offset_t,rank_t,index_t>::create_comm_buffers(memory
}

// allocate buffers
#define BUFFER_ALIGN 4096
#define BUFFER_ROUNDUP(a) (((size_t)((a + BUFFER_ALIGN - 1) / BUFFER_ALIGN)) * BUFFER_ALIGN)

size_t sz_total = 0;
for (auto & s : send_size)
sz_total += BUFFER_ROUNDUP(s.second);
for (auto & s : recv_size)
sz_total += BUFFER_ROUNDUP(s.second);

// std::cout << GridLogMessage << "Allocate memory buffer of size " << sz_total << std::endl;
ASSERT(buffers.size() == 0);
buffers.push_back(memory_buffer(sz_total, mt));

sz_total = 0;
char* base = (char*)buffers[0].view.ptr;
for (auto & s : send_size) {
//printf("Rank %d has a send_buffer of size %d for rank %d\n",
// this->rank, (int)s.second, (int)s.first);
send_buffers.insert(std::make_pair(s.first,memory_buffer(s.second, mt)));
memory_view mv = {mt,(void*)(base + sz_total),s.second};
send_buffers.insert(std::make_pair(s.first, mv));
sz_total += BUFFER_ROUNDUP(s.second);
}

for (auto & s : recv_size) {
//printf("Rank %d has a recv_buffer of size %d for rank %d\n",
// this->rank, (int)s.second, (int)s.first);
recv_buffers.insert(std::make_pair(s.first,memory_buffer(s.second, mt)));
memory_view mv = {mt,(void*)(base + sz_total),s.second};
recv_buffers.insert(std::make_pair(s.first, mv));
sz_total += BUFFER_ROUNDUP(s.second);
}

}
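The allocation above rounds every per-rank message size up to BUFFER_ALIGN (4096 bytes), sums the rounded sizes into a single allocation, and then hands each rank an aligned offset into it. A small Python sketch of that arithmetic, with made-up per-rank sizes:

BUFFER_ALIGN = 4096

def buffer_roundup(n):
    # smallest multiple of BUFFER_ALIGN that is >= n
    return ((n + BUFFER_ALIGN - 1) // BUFFER_ALIGN) * BUFFER_ALIGN

send_size = {0: 100, 1: 5000}  # hypothetical byte counts per destination rank
recv_size = {2: 4096}          # hypothetical byte count per source rank

sz_total = sum(buffer_roundup(s) for s in send_size.values()) \
         + sum(buffer_roundup(s) for s in recv_size.values())
# sz_total == 4096 + 8192 + 4096 == 16384, allocated once

offset, views = 0, {}
for rank, s in list(send_size.items()) + list(recv_size.items()):
    views[rank] = (offset, s)       # every view starts on a 4096-byte boundary
    offset += buffer_roundup(s)
# views == {0: (0, 100), 1: (4096, 5000), 2: (12288, 4096)}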

template<typename offset_t, typename rank_t, typename index_t>
@@ -775,7 +790,7 @@ void global_memory_transfer<offset_t,rank_t,index_t>::execute(std::vector<memory
// if there is a buffer, first gather in communication buffer
for (auto & ranks : send_blocks) {
rank_t dst_rank = ranks.first;
auto & dst = send_buffers.at(dst_rank).view;
auto & dst = send_buffers.at(dst_rank);

for (auto & indices : ranks.second) {
index_t src_idx = indices.first;
@@ -787,14 +802,14 @@ void global_memory_transfer<offset_t,rank_t,index_t>::execute(std::vector<memory

// send/recv buffers
for (auto & buf : send_buffers) {
this->isend(buf.first, buf.second.view.ptr, buf.second.view.sz);
this->isend(buf.first, buf.second.ptr, buf.second.sz);
stats_isends += 1;
stats_send_bytes += buf.second.view.sz;
stats_send_bytes += buf.second.sz;
}
for (auto & buf : recv_buffers) {
this->irecv(buf.first, buf.second.view.ptr, buf.second.view.sz);
this->irecv(buf.first, buf.second.ptr, buf.second.sz);
stats_irecvs += 1;
stats_recv_bytes += buf.second.view.sz;
stats_recv_bytes += buf.second.sz;
}
}

@@ -835,7 +850,7 @@ void global_memory_transfer<offset_t,rank_t,index_t>::execute(std::vector<memory
for (auto & ranks : recv_blocks) {

rank_t src_rank = ranks.first;
auto & src = recv_buffers.at(src_rank).view;
auto & src = recv_buffers.at(src_rank);

for (auto & indices : ranks.second) {
index_t dst_idx = indices.first;
4 changes: 2 additions & 2 deletions lib/gpt/core/domain/two_grid_base.py
@@ -31,7 +31,7 @@ def lattice(self, otype):
def project(self, dst, src):
dst = g.util.to_list(dst)
src = g.util.to_list(src)
tag = str([s.otype.__name__ for s in src])
tag = str([(s.otype.__name__, str(s.grid)) for s in src])
if tag not in self.project_plan:
plan = g.copy_plan(dst, src, embed_in_communicator=src[0].grid)
assert len(dst) == len(src)
@@ -44,7 +44,7 @@ def project(self, dst, src):
def promote(self, dst, src):
dst = g.util.to_list(dst)
src = g.util.to_list(src)
tag = str([s.otype.__name__ for s in src])
tag = str([(s.otype.__name__, str(s.grid)) for s in src])
if tag not in self.promote_plan:
plan = g.copy_plan(dst, src, embed_in_communicator=dst[0].grid)
assert len(dst) == len(src)
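In both project and promote above, the plan-cache tag now includes str(s.grid) next to the otype name, so source lists that share an object type but live on different grids no longer collide in the cache. A toy illustration of the collision the old key allowed; Field and the string values are stand-ins, not the real gpt lattice class:

class Field:
    # stand-in exposing just the two attributes the tag uses
    def __init__(self, otype, grid):
        self.otype, self.grid = otype, grid

src_fine = [Field("some_otype", "fine_grid")]
src_coarse = [Field("some_otype", "coarse_grid")]

old_tag = lambda src: str([s.otype for s in src])
new_tag = lambda src: str([(s.otype, str(s.grid)) for s in src])

assert old_tag(src_fine) == old_tag(src_coarse)  # same tag, a plan built for the wrong grid could be reused
assert new_tag(src_fine) != new_tag(src_coarse)  # grid is now part of the key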
9 changes: 7 additions & 2 deletions lib/gpt/core/padding.py
@@ -18,9 +18,11 @@
#
import gpt as g

default_padding_cache = {}


class padded_local_fields:
def __init__(self, fields, margin_top, margin_bottom=None):
def __init__(self, fields, margin_top, margin_bottom=None, cache=default_padding_cache):
fields = g.util.to_list(fields)
self.grid = fields[0].grid
self.otype = fields[0].otype
@@ -29,7 +31,10 @@ def __init__(self, fields, margin_top, margin_bottom=None):
assert all([f.otype.__name__ == self.otype.__name__ for f in fields])
assert all([f.grid.obj == self.grid.obj for f in fields])

self.domain = g.domain.local(self.grid, margin_top, margin_bottom)
tag = f"{self.grid}_{margin_top}_{margin_bottom}"
if tag not in cache:
cache[tag] = g.domain.local(self.grid, margin_top, margin_bottom)
self.domain = cache[tag]

def __call__(self, fields):
return_list = isinstance(fields, list)
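The constructor above now memoises the g.domain.local object in a module-level dictionary keyed by the grid and the two margins, so repeatedly padding fields with the same layout reuses one domain instead of rebuilding it. A minimal sketch of the same memoisation pattern; make_padded_domain and the tuple it stores are stand-ins for the real g.domain.local construction:

default_padding_cache = {}

def make_padded_domain(grid, margin_top, margin_bottom, cache=default_padding_cache):
    tag = f"{grid}_{margin_top}_{margin_bottom}"
    if tag not in cache:
        # expensive step in the real code: g.domain.local(grid, margin_top, margin_bottom)
        cache[tag] = ("domain", grid, tuple(margin_top), tuple(margin_bottom))
    return cache[tag]

d1 = make_padded_domain("grid_A", [1, 1, 1, 1], [1, 1, 1, 1])
d2 = make_padded_domain("grid_A", [1, 1, 1, 1], [1, 1, 1, 1])
assert d1 is d2  # the second construction is a cache hit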
40 changes: 34 additions & 6 deletions lib/gpt/qcd/gauge/smear/local_stout.py
@@ -120,7 +120,9 @@ def compute_adj_abc(_A, _B, _C, _V, generators, cache, parity):
adjoint_from_right_fast(D, UtaU, generators, cache)

t("other")
tmp2[a,] = g(g.trace(C * D))
tmp2[
a,
] = g(g.trace(C * D))
t("merge")
g.merge_color(V, tmp2)
t("checkerboarding")
@@ -149,7 +151,13 @@ def adjoint_to_fundamental(fund, adj, generators):
fund[:] = 0
adj_c = g.separate_color(adj)
for e in range(ng):
fund += 1j * adj_c[e,] * generators[e]
fund += (
1j
* adj_c[
e,
]
* generators[e]
)


class local_stout(local_diffeomorphism):
@@ -188,7 +196,18 @@ def get_C(self, fields):
mask, imask = masks[self.params["checkerboard"]], masks[self.params["checkerboard"].inv()]

fm = g(mask + 1e-15 * imask)
st = g.qcd.gauge.staple_sum(U, mu=self.params["dimension"], rho=rho)[0]
if False:
st = g.qcd.gauge.staple_sum(U, mu=self.params["dimension"], rho=rho)[0]
else:
st = g.lattice(U[0])
st[:] = 0
for nu in range(len(U)):
if nu == self.params["dimension"]:
continue
st += self.params["rho"] * g.qcd.gauge.staple(U, self.params["dimension"], nu)
# stref = g.qcd.gauge.staple_sum(U, mu=self.params["dimension"], rho=rho)[0]
# g.message("TEST", g.norm2(st), g.norm2(st-stref))
# sys.exit(0)
sf = self.params["staple_field"]
if sf is not None:
st = sf(st)
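In get_C above, the fused g.qcd.gauge.staple_sum path is switched off (the if False branch) and the staple field is instead accumulated one transverse direction at a time, weighted by rho; the commented-out lines keep a norm comparison against the old reference. A stripped-down sketch of just that accumulation loop, where staple is a stand-in for g.qcd.gauge.staple:

def staple_sum_by_direction(U, mu, rho, staple):
    # accumulate rho-weighted staples over all directions nu != mu
    st = None
    for nu in range(len(U)):
        if nu == mu:
            continue
        term = rho * staple(U, mu, nu)
        st = term if st is None else st + term
    return st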
@@ -385,6 +404,7 @@ class local_stout_action_log_det_jacobian(differentiable_functional):
def __init__(self, stout):
self.stout = stout
self.verbose = stout.verbose
self.cache = {}

def plaquette_stencil(self, U, rho, mu, nu):
key = f"{U[0].grid.describe()}_{U[0].otype.__name__}_{mu}_{nu}_{rho}"
@@ -412,12 +432,19 @@ def __call__(self, U):
def gradient(self, U, dU):
assert dU == U

cache_key = f"{U[0].grid.describe()}_{U[0].otype.__name__}"

t = g.timer("action_log_det_jacobian")

cb = self.stout.params["checkerboard"]

t("jac_comp")
cache_ab = {}
if cache_key not in self.cache:
self.cache[cache_key] = {"ab": {}, "gen": {}}

cache_ab = self.cache[cache_key]["ab"]
cache = self.cache[cache_key]["gen"]

J_ac, NxxAd, Z_ac, M, fm, M_ab = self.stout.jacobian_components(U, cache_ab)

grid = J_ac.grid
@@ -461,7 +488,6 @@ def gradient(self, U, dU):
PlaqL = g.identity(U[0])
PlaqR = g(M * fm)
FdetV = g.lattice(grid, adjoint_vector_otype)
cache = {}

compute_adj_abc(PlaqL, PlaqR, MpInvJx, FdetV, generators, cache, cb)

@@ -470,7 +496,9 @@

tmp = {}
for e in range(ng):
tmp[e,] = g(g.trace(dJdX[e] * nMpInv))
tmp[
e,
] = g(g.trace(dJdX[e] * nMpInv))
dJdXe_nMpInv = g.lattice(grid, adjoint_vector_otype)
g.merge_color(dJdXe_nMpInv, tmp)
dJdXe_nMpInv @= dJdXe_nMpInv * fm
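The gradient above now keeps its working caches on the instance, in self.cache, keyed by the grid description and otype name and split into "ab" and "gen" sub-dictionaries; previously cache_ab and the generator cache further down were rebuilt on every call. A minimal sketch of that keying scheme; cached_gradient_demo and gradient_step are illustrative stand-ins, not the real classes:

class cached_gradient_demo:
    def __init__(self):
        self.cache = {}  # persists across gradient calls

    def gradient_step(self, grid_describe, otype_name):
        cache_key = f"{grid_describe}_{otype_name}"
        if cache_key not in self.cache:
            self.cache[cache_key] = {"ab": {}, "gen": {}}
        cache_ab = self.cache[cache_key]["ab"]    # scratch passed to jacobian_components
        cache_gen = self.cache[cache_key]["gen"]  # scratch passed to compute_adj_abc
        return cache_ab, cache_gen

demo = cached_gradient_demo()
first = demo.gradient_step("grid_A", "some_otype")
second = demo.gradient_step("grid_A", "some_otype")
assert first[0] is second[0] and first[1] is second[1]  # same dicts reused on the second call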
1 change: 1 addition & 0 deletions lib/gpt/qcd/gauge/stencil/staple.py
@@ -53,6 +53,7 @@ def staple_sum(U, rho, mu=None, cache=default_staple_cache):
nwr[idx] = 1

cache[tag] = g.parallel_transport_matrix(U, code, Ntarget)
# g.message("NEW STAPLE SUM", tag)

T = cache[tag](U)

