
Commit ff7fe73
Merge pull request #8 from QC-UCI/QCAAN
QC-AAN
TDHTTTT authored Feb 26, 2021
2 parents 5a6df7f + a04660a commit ff7fe73
Showing 45 changed files with 82 additions and 11 deletions.
34 files renamed without changes.
15 changes: 15 additions & 0 deletions qcbm_train.py → QC-CaloGAN/models/qcbm.py
@@ -130,6 +130,21 @@ def initialize_weights(layers, num_wires):
a[l].append([np.random.random()*np.pi*2 for _ in range(num_wires-1-i)])
return np.array(a)


+def train_qcbm(exact_prob_dist, weights):
+    """Train the QCBM by SPSA gradient descent on the KL divergence
+    to the target distribution (num_wires is a module-level global)."""
+    exact_prob_dict = {outcome: exact_prob_dist[outcome] for outcome in range(2**num_wires)}
+    def approx_cost_fn(weights):
+        return KL_Loss_dict(exact_prob_dict, qcbm_approx_probs(weights, num_wires))
+
+    for i in range(1000):
+        weights = weights - 0.01 * SPSA_grad(approx_cost_fn, weights)  # cost from sampled (approximate) probabilities
+        # weights = weights - 0.01 * exact_grad_cost(weights)  # cost from exact probabilities
+        if i % 100 == 0:
+            # print("Approx Cost:", KL_Loss_dict(exact_prob_dict, qcbm_approx_probs(weights, num_wires)))
+            print("True Cost:", KL_Loss(exact_prob_dist, qcbm_probs(weights, num_wires)))
+    return weights

#########################
#########################
if __name__ == "__main__":
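For reference, train_qcbm leans on helpers that this commit does not touch: SPSA_grad, KL_Loss, and KL_Loss_dict. A minimal sketch of plausible implementations, assuming SPSA perturbs all parameters at once and the KL helpers clamp zero probabilities; the eps and tiny defaults are illustrative, not taken from this repository:

import numpy as np

def SPSA_grad(cost_fn, weights, eps=0.1):
    # Simultaneous-perturbation (SPSA) gradient estimate: perturb every
    # parameter along a random +/-1 direction at once, so the estimate
    # costs two cost evaluations regardless of the number of parameters.
    delta = np.random.choice([-1.0, 1.0], size=np.shape(weights))
    diff = cost_fn(weights + eps * delta) - cost_fn(weights - eps * delta)
    return diff / (2.0 * eps) * delta  # 1/delta_i == delta_i for +/-1 entries

def KL_Loss(p, q, tiny=1e-12):
    # KL divergence D(p || q) between two dense probability vectors.
    p, q = np.asarray(p, dtype=float), np.asarray(q, dtype=float)
    return float(np.sum(p * np.log((p + tiny) / (q + tiny))))

def KL_Loss_dict(p_dict, q_dict, tiny=1e-12):
    # The same divergence when q is an outcome -> probability dict that can
    # be missing outcomes under finite shots (treated as probability zero).
    return sum(p * np.log((p + tiny) / (q_dict.get(k, 0.0) + tiny))
               for k, p in p_dict.items())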
78 changes: 67 additions & 11 deletions CaloGAN/models/train.py → QC-CaloGAN/models/train.py
@@ -21,6 +21,7 @@
from sklearn.utils import shuffle
import sys
import yaml
+import h5py


if __name__ == '__main__':
@@ -56,7 +57,19 @@ def get_parser():
help='batch size per update')

parser.add_argument('--latent-size', action='store', type=int, default=1024,
-                        help='size of random N(0, 1) latent space to sample')
+                        help='size of classical prior from N(0,1) to sample')

+    parser.add_argument('--nb-qubits', action='store', type=int, default=8,
+                        help='number of qubits used by the QCBM prior')
+
+    parser.add_argument('--qcbm-nb-layer', action='store', type=int, default=7,
+                        help='number of layers in the QCBM ansatz')
+
+    parser.add_argument('--qcbm-nb-shots', action='store', type=int, default=20000,
+                        help='number of measurement shots for the QCBM')
+
+    parser.add_argument('--nb-samples', action='store', type=int, default=-1,
+                        help='number of training samples to load')

parser.add_argument('--disc-lr', action='store', type=float, default=2e-5,
help='Adam learning rate for discriminator')
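Together, the new flags switch the latent prior from N(0, 1) to a QCBM whenever --nb-qubits is positive; --latent-size is then overridden to 2**nb_qubits further down. An illustrative invocation, assuming the parser's existing positional dataset argument:

    python train.py dataset.yaml --nb-qubits 8 --qcbm-nb-layer 7 --qcbm-nb-shots 20000

Note that --nb-samples is applied below as a bare slice [:nb_samples], so the default of -1 loads every event except the last one rather than all of them.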
@@ -98,6 +111,7 @@ def get_parser():
parse_args = parser.parse_args()

# delay the imports so running train.py -h doesn't take 5,234,807 years
+import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.layers import (Activation, AveragePooling2D, Dense, Embedding,
Flatten, Input, Lambda, UpSampling2D)
@@ ... @@
from ops import (minibatch_discriminator, minibatch_output_shape, Dense3D,
calculate_energy, scale, inpainting_attention)

-from architectures import build_generator, build_discriminator
+from exp_architectures import build_generator, build_discriminator
+from exp_qcbm import (qcbm_approx_probs, qcbm_probs, initialize_weights,
+                      train_qcbm, SPSA_grad, KL_Loss)

# batch, latent size, and whether or not to be verbose with a progress bar

@@ ... @@
nb_epochs = parse_args.nb_epochs
batch_size = parse_args.batch_size
latent_size = parse_args.latent_size
+nb_qubits = parse_args.nb_qubits
+qcbm_nb_layers = parse_args.qcbm_nb_layer
+qcbm_nb_shots = parse_args.qcbm_nb_shots
+nb_samples = parse_args.nb_samples
verbose = parse_args.prog_bar
no_attn = parse_args.no_attn

@@ ... @@

yaml_file = parse_args.dataset

+if nb_qubits > 0:
+    latent_size = 2**nb_qubits

logger.debug('parameter configuration:')

logger.debug('number of epochs = {}'.format(nb_epochs))
logger.debug('batch size = {}'.format(batch_size))
logger.debug('latent size = {}'.format(latent_size))
+logger.debug('number of image samples = {}'.format(nb_samples))
logger.debug('progress bar enabled = {}'.format(verbose))
logger.debug('Using attention = {}'.format(no_attn == False))
logger.debug('discriminator learning rate = {}'.format(disc_lr))
Expand All @@ -170,11 +194,12 @@ def _load_data(particle, datafile):
d = h5py.File(datafile, 'r')

# make our calo images channels-last
-    first = np.expand_dims(d['layer_0'][:], -1)
-    second = np.expand_dims(d['layer_1'][:], -1)
-    third = np.expand_dims(d['layer_2'][:], -1)
+    first = np.expand_dims(d['layer_0'][:nb_samples], -1)
+    second = np.expand_dims(d['layer_1'][:nb_samples], -1)
+    third = np.expand_dims(d['layer_2'][:nb_samples], -1)

    # convert to MeV
-    energy = d['energy'][:].reshape(-1, 1) * 1000
+    energy = d['energy'][:nb_samples].reshape(-1, 1) * 1000

sizes = [
first.shape[1], first.shape[2],
@@ -273,7 +298,9 @@ def _load_data(particle, datafile):
mbd_energy
])

-    fake = Dense(1, activation='sigmoid', name='fakereal_output')(p)
+    qcbm_w = Dense(2**nb_qubits, activation='linear', name='qcbm')(p)
+
+    fake = Dense(1, activation='sigmoid', name='fakereal_output')(qcbm_w)
discriminator_outputs = [fake, total_energy]
discriminator_losses = ['binary_crossentropy', 'mae']
# ACGAN case
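The new 'qcbm' head sits between the merged discriminator features and the real/fake output, so the sigmoid layer's kernel gains one row per QCBM basis state. Rough shape bookkeeping, assuming the default nb_qubits = 8:

    # p      : merged minibatch-discrimination + energy features
    # qcbm_w : Dense(2**8) = 256-unit linear layer, name='qcbm'
    # fake   : Dense(1, sigmoid) on top; its kernel has shape (256, 1)
    # The training loop below flattens that (256, 1) kernel once per epoch
    # and hands it to train_qcbm as the QCBM's target distribution.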
@@ ... @@

discriminator = Model(calorimeter + [input_energy], discriminator_outputs)

+tf.keras.utils.plot_model(
+    discriminator,
+    to_file="discriminator.png",
+    show_shapes=True,
+    show_dtype=False,
+    show_layer_names=True,
+    rankdir="TB",
+    expand_nested=False,
+    dpi=96,
+)


discriminator.compile(
optimizer=Adam(lr=disc_lr, beta_1=adam_beta_1),
loss=discriminator_losses
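(One practical note on the plot_model addition: tf.keras.utils.plot_model needs pydot and graphviz available at runtime and raises an ImportError otherwise, and the PNG is written to the current working directory on every run.)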
@@ -370,6 +409,7 @@ def _load_data(particle, datafile):

logger.info('commencing training')

+qcbm_weights = initialize_weights(qcbm_nb_layers, nb_qubits)
for epoch in range(nb_epochs):
logger.info('Epoch {} of {}'.format(epoch + 1, nb_epochs))

@@ ... @@
elif index % 10 == 0:
logger.debug('processed {}/{} batches'.format(index + 1, nb_batches))

-        # generate a new batch of noise
-        noise = np.random.normal(0, 1, (batch_size, latent_size))
+        # sample the prior from the QCBM: its outcome probabilities are
+        # zero-padded to latent_size and tiled across the batch, so every
+        # row of the batch is the same vector
+        if nb_qubits > 0:
+            logger.info('sampling prior from QCBM...')
+            noise = qcbm_approx_probs(qcbm_weights, nb_qubits)
+            noise = np.array(list(noise.values()))
+            noise = np.concatenate((noise, np.zeros(latent_size - noise.size)))
+            logger.info(noise)
+            logger.info(noise.shape)
+            noise = np.tile(noise, (batch_size, 1))
+            logger.info(noise.shape)
+        else:
+            noise = np.random.normal(0, 1, (batch_size, latent_size))

# get a batch of real images
image_batch_1 = first[index * batch_size:(index + 1) * batch_size]
@@ -477,8 +527,14 @@ def _load_data(particle, datafile):
epoch + 1, np.mean(epoch_disc_loss, axis=0)))

# save weights every epoch
-    generator.save_weights('{0}{1:03d}.hdf5'.format(parse_args.g_pfx, epoch),
+    generator.save_weights('./weights/{0}{1:03d}.hdf5'.format(parse_args.g_pfx, epoch),
                            overwrite=True)

-    discriminator.save_weights('{0}{1:03d}.hdf5'.format(parse_args.d_pfx, epoch),
+    discriminator.save_weights('./weights/{0}{1:03d}.hdf5'.format(parse_args.d_pfx, epoch),
                                overwrite=True)

+    # retrain the QCBM on the discriminator's final-layer kernel
+    dis_weights_f = h5py.File('./weights/{0}{1:03d}.hdf5'.format(parse_args.d_pfx, epoch), 'r')
+    qcbm_dis_weights = dis_weights_f['fakereal_output']['fakereal_output']['kernel:0'][:].flatten()
+    logger.info("discriminator qcbm weights ({}): {}".format(qcbm_dis_weights.shape, qcbm_dis_weights))
+    qcbm_weights = train_qcbm(qcbm_dis_weights, qcbm_weights)
+    dis_weights_f.close()
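This closes the loop that gives the PR its name: each epoch, the discriminator's 'fakereal_output' kernel, one weight per QCBM basis state, becomes the target distribution for train_qcbm, and the retrained QCBM supplies the next epoch's latent prior. A condensed sketch of one round trip, assuming the exp_qcbm helpers imported above and noting that the kernel is fed to the KL-based loss without being normalized into a probability distribution:

import h5py
import numpy as np

def qcbm_gan_round_trip(discriminator, weight_path, qcbm_weights,
                        nb_qubits, latent_size, batch_size):
    # 1. Latent prior: QCBM outcome probabilities, zero-padded to
    #    latent_size and tiled so every row of the batch is identical.
    probs = qcbm_approx_probs(qcbm_weights, nb_qubits)   # dict: outcome -> prob
    noise = np.array(list(probs.values()))
    noise = np.concatenate((noise, np.zeros(latent_size - noise.size)))
    noise = np.tile(noise, (batch_size, 1))

    # ... generator/discriminator updates consume `noise` here ...

    # 2. Target for the QCBM: the sigmoid head's kernel, shape
    #    (2**nb_qubits, 1), read back from the checkpoint exactly as in
    #    the loop above.
    discriminator.save_weights(weight_path, overwrite=True)
    with h5py.File(weight_path, 'r') as f:
        target = f['fakereal_output']['fakereal_output']['kernel:0'][:].flatten()

    # 3. Retrain the QCBM toward that (unnormalized) target.
    return noise, train_qcbm(target, qcbm_weights)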
File renamed without changes.
