def train_one_epoch(generator, discriminator, generator_optimizer,
- discriminator_optimizer, dataset, log_interval, noise_dim):
+ discriminator_optimizer, dataset, step_counter,
+ log_interval, noise_dim):
"""Trains `generator` and `discriminator` models on `dataset`.
Args:
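+ generator: Generator model to train.
+ discriminator: Discriminator model to train.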
generator_optimizer: Optimizer to use for generator.
discriminator_optimizer: Optimizer to use for discriminator.
dataset: Dataset of images to train on.
- log_interval: How many global steps to wait between logging and collecting
+ step_counter: An integer variable, incremented once per training step and used to decide when to record summaries.
+ log_interval: How many steps to wait between logging and collecting
summaries.
noise_dim: Dimension of noise vector to use.
"""
total_discriminator_loss = 0.0
for (batch_index, images) in enumerate(tfe.Iterator(dataset)):
with tf.device('/cpu:0'):
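+ # Count each batch as one training step.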
- tf.assign_add(tf.train.get_global_step(), 1)
+ tf.assign_add(step_counter, 1)
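+ # Summaries are only recorded every log_interval steps, keyed on step_counter.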
- with tf.contrib.summary.record_summaries_every_n_global_steps(log_interval):
+ with tf.contrib.summary.record_summaries_every_n_global_steps(
+ log_interval, global_step=step_counter):
current_batch_size = images.shape[0]
noise = tf.random_uniform(
shape=[current_batch_size, noise_dim],
discriminator_grad = g.gradient(discriminator_loss_val,
discriminator.variables)
- with tf.variable_scope('generator'):
- generator_optimizer.apply_gradients(
- zip(generator_grad, generator.variables))
- with tf.variable_scope('discriminator'):
- discriminator_optimizer.apply_gradients(
- zip(discriminator_grad, discriminator.variables))
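+ # Variable scopes are a graph-mode naming construct and change nothing in eager execution, so the gradients are applied directly.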
+ generator_optimizer.apply_gradients(
+ zip(generator_grad, generator.variables))
+ discriminator_optimizer.apply_gradients(
+ zip(discriminator_grad, discriminator.variables))
if log_interval and batch_index > 0 and batch_index % log_interval == 0:
print('Batch #%d\tAverage Generator Loss: %.6f\t'
tf.data.Dataset.from_tensor_slices(data.train.images).shuffle(60000)
.batch(FLAGS.batch_size))
- # Create the models and optimizers
- generator = Generator(data_format)
- discriminator = Discriminator(data_format)
- with tf.variable_scope('generator'):
- generator_optimizer = tf.train.AdamOptimizer(FLAGS.lr)
- with tf.variable_scope('discriminator'):
- discriminator_optimizer = tf.train.AdamOptimizer(FLAGS.lr)
+ # Create the models and optimizers.
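+ # One dict of everything that needs tracking: it is unpacked into tfe.Checkpoint below and into each train_one_epoch call.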
+ model_objects = {
+ 'generator': Generator(data_format),
+ 'discriminator': Discriminator(data_format),
+ 'generator_optimizer': tf.train.AdamOptimizer(FLAGS.lr),
+ 'discriminator_optimizer': tf.train.AdamOptimizer(FLAGS.lr),
+ 'step_counter': tf.train.get_or_create_global_step(),
+ }
# Prepare summary writer and checkpoint info
summary_writer = tf.contrib.summary.create_summary_file_writer(
latest_cpkt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
if latest_cpkt:
print('Using latest checkpoint at ' + latest_cpkt)
+ checkpoint = tfe.Checkpoint(**model_objects)
+ # Restore variables on creation if a checkpoint exists.
+ checkpoint.restore(latest_cpkt)
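+ # restore(None) restores nothing, so a run without a checkpoint simply keeps the freshly initialized variables.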
with tf.device(device):
- for epoch in range(1, 101):
- with tfe.restore_variables_on_create(latest_cpkt):
- global_step = tf.train.get_or_create_global_step()
- start = time.time()
- with summary_writer.as_default():
- train_one_epoch(generator, discriminator, generator_optimizer,
- discriminator_optimizer, dataset, FLAGS.log_interval,
- FLAGS.noise)
- end = time.time()
- print('\nTrain time for epoch #%d (global step %d): %f' %
- (epoch, global_step.numpy(), end - start))
-
- all_variables = (
- generator.variables + discriminator.variables +
- generator_optimizer.variables() +
- discriminator_optimizer.variables() + [global_step])
- tfe.Saver(all_variables).save(checkpoint_prefix, global_step=global_step)
+ for _ in range(100):
+ start = time.time()
+ with summary_writer.as_default():
+ train_one_epoch(dataset=dataset, log_interval=FLAGS.log_interval,
+ noise_dim=FLAGS.noise, **model_objects)
+ end = time.time()
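+ # Each save increments checkpoint.save_counter, which doubles as the epoch number printed below.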
+ checkpoint.save(checkpoint_prefix)
+ print('\nTrain time for epoch #%d (step %d): %f' %
+ (checkpoint.save_counter.numpy(),
+ checkpoint.step_counter.numpy(),
+ end - start))
if __name__ == '__main__':
for _ in range(measure_batches)]
measure_dataset = tf.data.Dataset.from_tensor_slices(measure_images)
- tf.train.get_or_create_global_step()
+ step_counter = tf.train.get_or_create_global_step()
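+ # Shared by the warm-up and measured epochs below so both write to the same counter.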
with tf.device(device()):
# Create the models and optimizers
generator = mnist.Generator(data_format())
# warm up
mnist.train_one_epoch(generator, discriminator, generator_optimizer,
discriminator_optimizer,
- burn_dataset, log_interval=SUMMARY_INTERVAL,
+ burn_dataset, step_counter,
+ log_interval=SUMMARY_INTERVAL,
noise_dim=NOISE_DIM)
# measure
start = time.time()
mnist.train_one_epoch(generator, discriminator, generator_optimizer,
discriminator_optimizer,
- measure_dataset, log_interval=SUMMARY_INTERVAL,
+ measure_dataset, step_counter,
+ log_interval=SUMMARY_INTERVAL,
noise_dim=NOISE_DIM)
self._report('train', start, measure_batches, batch_size)