Switch the eager GAN MNIST example to object-based checkpointing
author     Allen Lavoie <allenl@google.com>
           Wed, 7 Mar 2018 17:51:14 +0000 (09:51 -0800)
committer  TensorFlower Gardener <gardener@tensorflow.org>
           Wed, 7 Mar 2018 17:59:34 +0000 (09:59 -0800)
- Removes variable_scopes, since they're no longer necessary (duplicate variable names are OK with object-based checkpointing)
- Reworks the counters (global_step -> step_counter, and the epoch counter now lives in the checkpoint); a sketch of the resulting pattern follows below

PiperOrigin-RevId: 188190128
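
For context, a minimal standalone sketch of the object-based checkpointing
pattern this change adopts, assuming TF 1.x with eager execution enabled; the
directory paths and the 0.001 learning rate are illustrative, not from the
change itself:

    import tensorflow as tf
    import tensorflow.contrib.eager as tfe

    tfe.enable_eager_execution()

    # Gather checkpointable objects under explicit names.
    step_counter = tf.train.get_or_create_global_step()
    optimizer = tf.train.AdamOptimizer(0.001)
    checkpoint = tfe.Checkpoint(optimizer=optimizer, step_counter=step_counter)

    # Matching is by object structure, not by variable name, so models with
    # duplicate variable names can share a checkpoint without name scopes.
    checkpoint.restore(tf.train.latest_checkpoint('/tmp/demo'))

    # ... training would go here ...
    checkpoint.save('/tmp/demo/ckpt')  # increments checkpoint.save_counter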

tensorflow/contrib/eager/python/examples/gan/mnist.py
tensorflow/contrib/eager/python/examples/gan/mnist_test.py

diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist.py b/tensorflow/contrib/eager/python/examples/gan/mnist.py
index 5f51d52..2b7e199 100644
@@ -195,7 +195,8 @@ def generator_loss(discriminator_gen_outputs):
 
 
 def train_one_epoch(generator, discriminator, generator_optimizer,
-                    discriminator_optimizer, dataset, log_interval, noise_dim):
+                    discriminator_optimizer, dataset, step_counter,
+                    log_interval, noise_dim):
   """Trains `generator` and `discriminator` models on `dataset`.
 
   Args:
@@ -204,7 +205,8 @@ def train_one_epoch(generator, discriminator, generator_optimizer,
     generator_optimizer: Optimizer to use for generator.
     discriminator_optimizer: Optimizer to use for discriminator.
     dataset: Dataset of images to train on.
-    log_interval: How many global steps to wait between logging and collecting
+    step_counter: An integer variable, used to decide when to record summaries.
+    log_interval: How many steps to wait between logging and collecting
       summaries.
     noise_dim: Dimension of noise vector to use.
   """
@@ -213,9 +215,10 @@ def train_one_epoch(generator, discriminator, generator_optimizer,
   total_discriminator_loss = 0.0
   for (batch_index, images) in enumerate(tfe.Iterator(dataset)):
     with tf.device('/cpu:0'):
-      tf.assign_add(tf.train.get_global_step(), 1)
+      tf.assign_add(step_counter, 1)
 
-    with tf.contrib.summary.record_summaries_every_n_global_steps(log_interval):
+    with tf.contrib.summary.record_summaries_every_n_global_steps(
+        log_interval, global_step=step_counter):
       current_batch_size = images.shape[0]
       noise = tf.random_uniform(
           shape=[current_batch_size, noise_dim],
@@ -243,12 +246,10 @@ def train_one_epoch(generator, discriminator, generator_optimizer,
       discriminator_grad = g.gradient(discriminator_loss_val,
                                       discriminator.variables)
 
-      with tf.variable_scope('generator'):
-        generator_optimizer.apply_gradients(
-            zip(generator_grad, generator.variables))
-      with tf.variable_scope('discriminator'):
-        discriminator_optimizer.apply_gradients(
-            zip(discriminator_grad, discriminator.variables))
+      generator_optimizer.apply_gradients(
+          zip(generator_grad, generator.variables))
+      discriminator_optimizer.apply_gradients(
+          zip(discriminator_grad, discriminator.variables))
 
       if log_interval and batch_index > 0 and batch_index % log_interval == 0:
         print('Batch #%d\tAverage Generator Loss: %.6f\t'
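
With the explicit counter, summary recording keys off step_counter via the
global_step argument instead of the implicit global step. A standalone sketch
of that combination; the log directory, the every-5-steps interval, and the
constant loss value are made up for illustration:

    import tensorflow as tf
    import tensorflow.contrib.eager as tfe

    tfe.enable_eager_execution()

    step_counter = tf.train.get_or_create_global_step()
    writer = tf.contrib.summary.create_summary_file_writer('/tmp/logs')

    with writer.as_default():
      for _ in range(20):
        tf.assign_add(step_counter, 1)
        # Summaries are recorded only when step_counter is a multiple of 5.
        with tf.contrib.summary.record_summaries_every_n_global_steps(
            5, global_step=step_counter):
          tf.contrib.summary.scalar('loss', 0.25)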
@@ -269,13 +270,14 @@ def main(_):
       tf.data.Dataset.from_tensor_slices(data.train.images).shuffle(60000)
       .batch(FLAGS.batch_size))
 
-  # Create the models and optimizers
-  generator = Generator(data_format)
-  discriminator = Discriminator(data_format)
-  with tf.variable_scope('generator'):
-    generator_optimizer = tf.train.AdamOptimizer(FLAGS.lr)
-  with tf.variable_scope('discriminator'):
-    discriminator_optimizer = tf.train.AdamOptimizer(FLAGS.lr)
+  # Create the models and optimizers.
+  model_objects = {
+      'generator': Generator(data_format),
+      'discriminator': Discriminator(data_format),
+      'generator_optimizer': tf.train.AdamOptimizer(FLAGS.lr),
+      'discriminator_optimizer': tf.train.AdamOptimizer(FLAGS.lr),
+      'step_counter': tf.train.get_or_create_global_step(),
+  }
 
   # Prepare summary writer and checkpoint info
   summary_writer = tf.contrib.summary.create_summary_file_writer(
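
One detail of the dict-based setup worth calling out: the same model_objects
dict feeds both tfe.Checkpoint(**model_objects) below and
train_one_epoch(..., **model_objects), which only works because the dict keys
match train_one_epoch's parameter names exactly. A toy version of the pattern;
do_step and the 0.001 learning rate are hypothetical:

    import tensorflow as tf
    import tensorflow.contrib.eager as tfe

    tfe.enable_eager_execution()

    def do_step(optimizer, step_counter):  # parameter names match the keys
      tf.assign_add(step_counter, 1)

    objects = {
        'optimizer': tf.train.AdamOptimizer(0.001),
        'step_counter': tf.train.get_or_create_global_step(),
    }
    checkpoint = tfe.Checkpoint(**objects)  # keys become attribute names
    do_step(**objects)                      # and double as keyword arguments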
@@ -284,25 +286,22 @@ def main(_):
   latest_cpkt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
   if latest_cpkt:
     print('Using latest checkpoint at ' + latest_cpkt)
+  checkpoint = tfe.Checkpoint(**model_objects)
+  # Restore variables on creation if a checkpoint exists.
+  checkpoint.restore(latest_cpkt)
 
   with tf.device(device):
-    for epoch in range(1, 101):
-      with tfe.restore_variables_on_create(latest_cpkt):
-        global_step = tf.train.get_or_create_global_step()
-        start = time.time()
-        with summary_writer.as_default():
-          train_one_epoch(generator, discriminator, generator_optimizer,
-                          discriminator_optimizer, dataset, FLAGS.log_interval,
-                          FLAGS.noise)
-        end = time.time()
-        print('\nTrain time for epoch #%d (global step %d): %f' %
-              (epoch, global_step.numpy(), end - start))
-
-      all_variables = (
-          generator.variables + discriminator.variables +
-          generator_optimizer.variables() +
-          discriminator_optimizer.variables() + [global_step])
-      tfe.Saver(all_variables).save(checkpoint_prefix, global_step=global_step)
+    for _ in range(100):
+      start = time.time()
+      with summary_writer.as_default():
+        train_one_epoch(dataset=dataset, log_interval=FLAGS.log_interval,
+                        noise_dim=FLAGS.noise, **model_objects)
+      end = time.time()
+      checkpoint.save(checkpoint_prefix)
+      print('\nTrain time for epoch #%d (step %d): %f' %
+            (checkpoint.save_counter.numpy(),
+             checkpoint.step_counter.numpy(),
+             end - start))
 
 
 if __name__ == '__main__':
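Two behaviors the new main loop relies on, shown in isolation:
tf.train.latest_checkpoint returns None for an empty directory and
Checkpoint.restore(None) is a no-op, so first runs and resumed runs share one
code path; and save_counter increments on each save(), which happens once per
epoch here, so it doubles as the epoch number in the log line. A standalone
sketch with an illustrative directory:

    import tensorflow as tf
    import tensorflow.contrib.eager as tfe

    tfe.enable_eager_execution()

    counter = tf.train.get_or_create_global_step()
    checkpoint = tfe.Checkpoint(step_counter=counter)

    # No-op on the first run, restores saved state on later runs.
    checkpoint.restore(tf.train.latest_checkpoint('/tmp/resume_demo'))

    for _ in range(3):
      tf.assign_add(counter, 1)  # stand-in for one epoch of training
      checkpoint.save('/tmp/resume_demo/ckpt')
      print('epoch %d ended at step %d' % (checkpoint.save_counter.numpy(),
                                           counter.numpy()))
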
diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist_test.py b/tensorflow/contrib/eager/python/examples/gan/mnist_test.py
index 4a3ca8d..bd35e50 100644
@@ -62,7 +62,7 @@ class MnistEagerGanBenchmark(tf.test.Benchmark):
                         for _ in range(measure_batches)]
       measure_dataset = tf.data.Dataset.from_tensor_slices(measure_images)
 
-      tf.train.get_or_create_global_step()
+      step_counter = tf.train.get_or_create_global_step()
       with tf.device(device()):
         # Create the models and optimizers
         generator = mnist.Generator(data_format())
@@ -78,13 +78,15 @@ class MnistEagerGanBenchmark(tf.test.Benchmark):
           # warm up
           mnist.train_one_epoch(generator, discriminator, generator_optimizer,
                                 discriminator_optimizer,
-                                burn_dataset, log_interval=SUMMARY_INTERVAL,
+                                burn_dataset, step_counter,
+                                log_interval=SUMMARY_INTERVAL,
                                 noise_dim=NOISE_DIM)
           # measure
           start = time.time()
           mnist.train_one_epoch(generator, discriminator, generator_optimizer,
                                 discriminator_optimizer,
-                                measure_dataset, log_interval=SUMMARY_INTERVAL,
+                                measure_dataset, step_counter,
+                                log_interval=SUMMARY_INTERVAL,
                                 noise_dim=NOISE_DIM)
           self._report('train', start, measure_batches, batch_size)