QC-AAN #8

Merged: 7 commits, merged Feb 26, 2021

Changes from 1 commit
32 changes: 28 additions & 4 deletions CaloGAN/models/train.py
@@ -56,7 +56,16 @@ def get_parser():
                         help='batch size per update')
 
     parser.add_argument('--latent-size', action='store', type=int, default=1024,
-                        help='size of QCBM prior latent space to sample')
+                        help='size of classical prior from N(0,1) to sample')
+
+    parser.add_argument('--nb-qubits', action='store', type=int, default=8,
+                        help='number of qubits to use for QCBM')
+
+    parser.add_argument('--qcbm-nb-layer', action='store', type=int, default=7,
+                        help='number of layers for QCBM ansatz')
+
+    parser.add_argument('--qcbm-nb-shots', action='store', type=int, default=20000,
+                        help='number of shots for QCBM')
 
     parser.add_argument('--nb-samples', action='store', type=int, default=-1,
                         help='number of samples to train')
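Note: with the defaults above, a QCBM-prior run could be launched roughly as follows (hypothetical invocation; the dataset argument and script path are assumed from the existing parser and repo layout):

    python CaloGAN/models/train.py dataset.yaml --nb-qubits 8 --qcbm-nb-layer 7 --qcbm-nb-shots 20000

Passing a non-positive --nb-qubits keeps the classical N(0,1) prior, per the branch added further down in this diff.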
@@ -115,6 +124,7 @@ def get_parser():
                          calculate_energy, scale, inpainting_attention)
 
     from architectures import build_generator, build_discriminator
+    from qcbm import qcbm_approx_probs, qcbm_probs, initialize_weights
 
     # batch, latent size, and whether or not to be verbose with a progress bar
@@ -133,6 +143,9 @@ def get_parser():
     nb_epochs = parse_args.nb_epochs
     batch_size = parse_args.batch_size
     latent_size = parse_args.latent_size
+    nb_qubits = parse_args.nb_qubits
+    qcbm_nb_layers = parse_args.qcbm_nb_layer
+    qcbm_nb_shots = parse_args.qcbm_nb_shots
     nb_samples = parse_args.nb_samples
     verbose = parse_args.prog_bar
     no_attn = parse_args.no_attn
@@ -143,12 +156,15 @@
 
     yaml_file = parse_args.dataset
 
+    if nb_qubits > 0:
+        latent_size = 2**nb_qubits
+
     logger.debug('parameter configuration:')
 
     logger.debug('number of epochs = {}'.format(nb_epochs))
     logger.debug('batch size = {}'.format(batch_size))
     logger.debug('latent size = {}'.format(latent_size))
-    logger.debug('number of samples = {}'.format(latent_size))
+    logger.debug('number of image samples = {}'.format(nb_samples))
     logger.debug('progress bar enabled = {}'.format(verbose))
     logger.debug('Using attention = {}'.format(no_attn == False))
     logger.debug('discriminator learning rate = {}'.format(disc_lr))
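Note: the latent_size override ties the GAN's latent dimension to the QCBM output: an n-qubit circuit measured in the computational basis yields a distribution over 2**n bitstrings, so the prior vector gets one entry per basis state. With the defaults shown above:

    nb_qubits = 8                 # default from --nb-qubits
    latent_size = 2**nb_qubits    # 256 basis states, |00000000> ... |11111111>
    # this overrides --latent-size (default 1024) whenever nb_qubits > 0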
@@ -376,6 +392,7 @@ def _load_data(particle, datafile):
 
     logger.info('commencing training')
 
+    qcbm_weights = initialize_weights(qcbm_nb_layers, nb_qubits)
     for epoch in range(nb_epochs):
         logger.info('Epoch {} of {}'.format(epoch + 1, nb_epochs))
 
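Note: qcbm.py itself is not part of this diff, so the weight shapes are not visible here. A minimal sketch of what initialize_weights could look like for a layered ansatz, assuming one rotation angle per qubit per layer (the real module may parameterize differently):

    import numpy as np

    def initialize_weights(nb_layers, nb_qubits, seed=None):
        """Random starting angles for a layered QCBM ansatz.

        Assumed layout: one single-qubit rotation angle per qubit per
        layer, drawn uniformly from [0, 2*pi). The actual qcbm.py may
        use more parameters per layer (e.g. separate RX/RZ angles).
        """
        rng = np.random.default_rng(seed)
        return rng.uniform(0.0, 2.0 * np.pi, size=(nb_layers, nb_qubits))

As far as this hunk shows, the weights are drawn once before the epoch loop; any update to them would have to happen elsewhere in the training step.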
@@ -395,8 +412,15 @@ def _load_data(particle, datafile):
             elif index % 10 == 0:
                 logger.debug('processed {}/{} batches'.format(index + 1, nb_batches))
 
-            # generate a new batch of noise
-            noise = np.random.normal(0, 1, (batch_size, latent_size))
+            # sample from QCBM
+            if nb_qubits > 0:
Contributor Author (inline comment on the line above):
mainly this part, where it samples from QCBM

+                logger.info('sampling prior from QCBM...')
+                noise = qcbm_approx_probs(qcbm_weights, nb_qubits)
+                noise = np.array([i for i in noise.values()])
+                noise = np.concatenate((noise,np.zeros(latent_size-noise.size)))
+                noise = np.tile(noise, (batch_size, 1))
+            else:
+                noise = np.random.normal(0, 1, (batch_size, latent_size))
 
             # get a batch of real images
             image_batch_1 = first[index * batch_size:(index + 1) * batch_size]
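Note: to make the shape handling in the new branch concrete, here is a toy trace with made-up numbers; it assumes, as the dict-values conversion implies, that qcbm_approx_probs returns a mapping from measured bitstrings to estimated probabilities:

    import numpy as np

    # stand-in for qcbm_approx_probs(qcbm_weights, nb_qubits) on 2 qubits:
    # only 3 of the 4 basis states were observed in the shots
    probs = {'00': 0.50, '01': 0.25, '10': 0.25}
    latent_size, batch_size = 4, 3

    noise = np.array([i for i in probs.values()])                        # shape (3,)
    noise = np.concatenate((noise, np.zeros(latent_size - noise.size)))  # pad -> (4,)
    noise = np.tile(noise, (batch_size, 1))                              # -> (3, 4)

Two things worth flagging from this trace: if the dict omits unobserved bitstrings, the zero-padding appends at the end rather than aligning probabilities to fixed basis-state indices; and np.tile makes every row identical, so unlike the np.random.normal branch, all generator inputs within a batch share the same latent vector.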