Skip to content

Commit

Permalink
update hmc
Browse files Browse the repository at this point in the history
  • Loading branch information
lehner committed Oct 25, 2024
1 parent 9808e32 commit 2407a09
Showing 1 changed file with 60 additions and 121 deletions.
181 changes: 60 additions & 121 deletions applications/hmc/dwf/ensembleK.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@

# cold start
# target 128x128x128x288
#U = g.qcd.gauge.unit(g.grid([128, 128, 128, 288], g.double)) TODO

U = g.qcd.gauge.random(g.grid([4,4,4,4], g.double), rng)
U = g.qcd.gauge.unit(g.grid([128, 128, 128, 288], g.double))
#U = g.qcd.gauge.unit(g.grid([64,64,64,96], g.double))

latest_it = None
it0 = 0
Expand Down Expand Up @@ -43,8 +42,6 @@

eps2 = g.norm2(g.matrix.det(U[mu]) - g.identity(g.complex(U[0].grid))) / U[0].grid.gsites
g.message("Determinant defect for mu=",mu,"is",eps2)
else:
latest_it = -1

if False:
Lold = U[0].grid.gdimensions
Expand Down Expand Up @@ -498,17 +495,17 @@ def log_det_force_sp():
)


no_accept_reject = True
#no_accept_reject = True
no_accept_reject = False


nsteps = 40
tau = 8.0
nsteps = 40

tau = 12.0
nsteps = 4 # TODO: modify

def refresh_momentum():
global ckp
def hmc(tau):
global ff_iterator, ckp
ff_iterator = 0
accrej = metro(U)
g.message("After metro")
params = U_mom + [x[-1] for x in fields] + [0.0, 0.0]
if not ckp.load(params):
h0, s0 = hamiltonian(True)
Expand All @@ -520,140 +517,79 @@ def refresh_momentum():
sys.exit(0)
else:
h0, s0 = params[-2:]
return h0, s0

def evolve(tau_local, nsteps_local):
global ckp

if not ckp.load(U_mom + U):
its0 = nsteps_local - 1
while its0 >= 0:
g.message(f"Try to load fields after iteration {its0}")
if load_cfields(f"{its0}", U_mom + U):
break
its0 -= 1
its0 += 1
for its in range(its0, nsteps_local):
g.message(f"tau-iteration: {its} -> {tau/nsteps_local*its}")
mdint(tau_local / nsteps_local)
store_css()
store_cfields(f"{its}", U_mom + U)

g.message("After H(true)",h0,s0)
its0 = nsteps - 1
while its0 >= 0:
g.message(f"Try to load fields after iteration {its0}")
if load_cfields(f"{its0}", params + U):
break
its0 -= 1
its0 += 1
nrun = 0
for its in range(its0, nsteps):
g.message(f"tau-iteration: {its} -> {tau/nsteps*its}")
mdint(tau / nsteps)
store_cfields(f"{its}", params + U)
store_css()
nrun += 1
if nrun >= 1:
if its % (nsteps//2) == 0:
h1, s1 = hamiltonian(False)
g.message(f"dH = {h1-h0}")
else:
h1 = None
if g.rank() == 0:
flog = open(f"{dst}/current.log.{its}","wt")
if h1 is not None:
flog.write(f"dH_{its} = {h1} - {h0} = {h1-h0}\n")
for x in log.grad:
flog.write(f"{x} force norm2/sites = {np.mean(log.get(x))} +- {np.std(log.get(x))}\n")
flog.write(f"Timing:\n{log.time}\n")
flog.close()

if its < nsteps_local - 1:
g.barrier()
sys.exit(0)

ckp.save(U_mom + U)

g.barrier()

# reset checkpoint
g.message("Reset evolve checkpoints")
if g.rank() == 0:
for it in range(nsteps_local):
if os.path.exists(f"{dst}/checkpoint.{it}"):
shutil.rmtree(f"{dst}/checkpoint.{it}")

g.barrier()


def metropolis_evolve(tau_local, nsteps_local, h0, s0, status):
global ckp

accrej = metro(U_mom + U)

evolve(tau_local, nsteps_local)
g.barrier()
sys.exit(0)

params = [0.0, 0.0]
if not ckp.load(params):
h1, s1 = hamiltonian(False)
g.message(f"dH = {h1-h0}")
ckp.save([h1,s1])
sys.exit(0)
else:
h1, s1 = params

g.message(f"dH = {h1-h0}")
status.append(h1-h0)
status.append(s1-s0)

if not no_accept_reject:
if not accrej(h1, h0):
# restore to zero state
g.message("Reject update with dH=",h1-h0)
# flip momenta
for um in U_mom:
um *= -1
h1, s1 = h0, s0
acc = False
else:
g.message("Accept update with dH=",h1-h0)
acc = True
else:
acc = True

status.append(acc)

g.message(f"Plaquette after evolution with accept/reject is {g.qcd.gauge.plaquette(U)}")

return h1, s1, acc


def hmc(tau):
global ff_iterator, ckp
ff_iterator = 0
g.message("After metro")

h0, s0 = refresh_momentum()

g.message("After H(true)",h0,s0)

status = []
h1, s1, acc1 = metropolis_evolve(tau/2, nsteps//2, h0, s0, status)

g.message("After first evolution")
h2, s2, acc2 = metropolis_evolve(tau/2, nsteps//2, h1, s1, status)

g.message("After second evolution")

g.message("After mdint(tau)")
h1, s1 = hamiltonian(False)
g.message("After H(false)")
store_css()
if no_accept_reject:
return [True, s1 - s0, h1 - h0]
else:
return [accrej(h1, h0), s1 - s0, h1 - h0]

return status

accept, total = 0, 0
for it in range(it0, N):
pure_gauge = it < 100
no_accept_reject = it < 1 # TODO modify
pure_gauge = it < 10
no_accept_reject = it < 1000
g.message(pure_gauge, no_accept_reject)

dH_1, dS_1, acc_1, dH_2, dS_2, acc_2 = hmc(tau)
a, dS, dH = hmc(tau)
accept += a
total += 1


Uft = U
for s in reversed(sm):
Uft = s(Uft)

plaq = g.qcd.gauge.plaquette(U)
plaqft = g.qcd.gauge.plaquette(Uft)
g.message(f"HMC {it} has P = {plaqft}, Pft = {plaq}, dS_1 = {dS_1}, dS_2 = {dS_2}, dH_1 = {dH_1}, dH_2 = {dH_2}, acceptance = {acc_1} {acc_2}")
#for x in log.grad:
# g.message(f"{x} force norm2/sites =", np.mean(log.get(x)), "+-", np.std(log.get(x)))
#g.message(f"Timing:\n{log.time}")
g.message(f"HMC {it} has P = {plaqft}, Pft = {plaq}, dS = {dS}, dH = {dH}, acceptance = {accept/total}")
for x in log.grad:
g.message(f"{x} force norm2/sites =", np.mean(log.get(x)), "+-", np.std(log.get(x)))
g.message(f"Timing:\n{log.time}")

if g.rank() == 0:
flog = open(f"{dst}/ckpoint_lat.{it}.log","wt")
flog.write(f"dH_1 {dH_1}\n")
flog.write(f"dH_2 {dH_2}\n")
flog.write(f"accept_1 {acc_1}\n")
flog.write(f"accept_2 {acc_2}\n")
flog.write(f"dH {dH}\n")
flog.write(f"P {plaqft}\n")
flog.write(f"Pft {plaq}\n")
for x in log.grad:
flog.write(f"{x} force norm2/sites = {np.mean(log.get(x))} +- {np.std(log.get(x))}\n")
flog.write(f"Timing:\n{log.time}\n")

flog.close()

if it % 10 == 0:
Expand All @@ -672,6 +608,9 @@ def hmc(tau):
# reset checkpoint
if g.rank() == 0:
shutil.rmtree(f"{dst}/checkpoint2")
for it in range(nsteps):
if os.path.exists(f"{dst}/checkpoint.{it}"):
shutil.rmtree(f"{dst}/checkpoint.{it}")

#rng = g.random(f"new{dst}-{it}", "vectorized_ranlux24_24_64")

Expand Down

0 comments on commit 2407a09

Please sign in to comment.