eager: Change various examples to use tf.keras.Model instead of tfe.Network.
author: Asim Shankar <ashankar@google.com>
Fri, 23 Feb 2018 23:45:02 +0000 (15:45 -0800)
committer: TensorFlower Gardener <gardener@tensorflow.org>
Fri, 23 Feb 2018 23:51:03 +0000 (15:51 -0800)
PiperOrigin-RevId: 186834891

tensorflow/contrib/eager/python/examples/gan/mnist.py
tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py
tensorflow/contrib/eager/python/examples/resnet50/resnet50.py
tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py
tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py

index b9ac79f..5f51d52 100644 (file)
@@ -35,7 +35,7 @@ from tensorflow.examples.tutorials.mnist import input_data
 FLAGS = None
 
 
-class Discriminator(tfe.Network):
+class Discriminator(tf.keras.Model):
   """GAN Discriminator.
 
   A network to differentiate between generated and real handwritten digits.
@@ -56,19 +56,15 @@ class Discriminator(tfe.Network):
     else:
       assert data_format == 'channels_last'
       self._input_shape = [-1, 28, 28, 1]
-    self.conv1 = self.track_layer(tf.layers.Conv2D(64, 5, padding='SAME',
-                                                   data_format=data_format,
-                                                   activation=tf.tanh))
-    self.pool1 = self.track_layer(
-        tf.layers.AveragePooling2D(2, 2, data_format=data_format))
-    self.conv2 = self.track_layer(tf.layers.Conv2D(128, 5,
-                                                   data_format=data_format,
-                                                   activation=tf.tanh))
-    self.pool2 = self.track_layer(
-        tf.layers.AveragePooling2D(2, 2, data_format=data_format))
-    self.flatten = self.track_layer(tf.layers.Flatten())
-    self.fc1 = self.track_layer(tf.layers.Dense(1024, activation=tf.tanh))
-    self.fc2 = self.track_layer(tf.layers.Dense(1, activation=None))
+    self.conv1 = tf.layers.Conv2D(
+        64, 5, padding='SAME', data_format=data_format, activation=tf.tanh)
+    self.pool1 = tf.layers.AveragePooling2D(2, 2, data_format=data_format)
+    self.conv2 = tf.layers.Conv2D(
+        128, 5, data_format=data_format, activation=tf.tanh)
+    self.pool2 = tf.layers.AveragePooling2D(2, 2, data_format=data_format)
+    self.flatten = tf.layers.Flatten()
+    self.fc1 = tf.layers.Dense(1024, activation=tf.tanh)
+    self.fc2 = tf.layers.Dense(1, activation=None)
 
   def call(self, inputs):
     """Return two logits per image estimating input authenticity.
@@ -95,7 +91,7 @@ class Discriminator(tfe.Network):
     return x
 
 
-class Generator(tfe.Network):
+class Generator(tf.keras.Model):
   """Generator of handwritten digits similar to the ones in the MNIST dataset.
   """
 
@@ -116,18 +112,17 @@ class Generator(tfe.Network):
     else:
       assert data_format == 'channels_last'
       self._pre_conv_shape = [-1, 6, 6, 128]
-    self.fc1 = self.track_layer(tf.layers.Dense(6 * 6 * 128,
-                                                activation=tf.tanh))
+    self.fc1 = tf.layers.Dense(6 * 6 * 128, activation=tf.tanh)
 
     # In call(), we reshape the output of fc1 to _pre_conv_shape
 
     # Deconvolution layer. Resulting image shape: (batch, 14, 14, 64)
-    self.conv1 = self.track_layer(tf.layers.Conv2DTranspose(
-        64, 4, strides=2, activation=None, data_format=data_format))
+    self.conv1 = tf.layers.Conv2DTranspose(
+        64, 4, strides=2, activation=None, data_format=data_format)
 
     # Deconvolution layer. Resulting image shape: (batch, 28, 28, 1)
-    self.conv2 = self.track_layer(tf.layers.Conv2DTranspose(
-        1, 2, strides=2, activation=tf.nn.sigmoid, data_format=data_format))
+    self.conv2 = tf.layers.Conv2DTranspose(
+        1, 2, strides=2, activation=tf.nn.sigmoid, data_format=data_format)
 
   def call(self, inputs):
     """Return a batch of generated images.
@@ -168,7 +163,8 @@ def discriminator_loss(discriminator_real_outputs, discriminator_gen_outputs):
   """
 
   loss_on_real = tf.losses.sigmoid_cross_entropy(
-      tf.ones_like(discriminator_real_outputs), discriminator_real_outputs,
+      tf.ones_like(discriminator_real_outputs),
+      discriminator_real_outputs,
       label_smoothing=0.25)
   loss_on_generated = tf.losses.sigmoid_cross_entropy(
       tf.zeros_like(discriminator_gen_outputs), discriminator_gen_outputs)
@@ -198,9 +194,8 @@ def generator_loss(discriminator_gen_outputs):
   return loss
 
 
-def train_one_epoch(generator, discriminator,
-                    generator_optimizer, discriminator_optimizer,
-                    dataset, log_interval, noise_dim):
+def train_one_epoch(generator, discriminator, generator_optimizer,
+                    discriminator_optimizer, dataset, log_interval, noise_dim):
   """Trains `generator` and `discriminator` models on `dataset`.
 
   Args:
@@ -222,14 +217,18 @@ def train_one_epoch(generator, discriminator,
 
     with tf.contrib.summary.record_summaries_every_n_global_steps(log_interval):
       current_batch_size = images.shape[0]
-      noise = tf.random_uniform(shape=[current_batch_size, noise_dim],
-                                minval=-1., maxval=1., seed=batch_index)
+      noise = tf.random_uniform(
+          shape=[current_batch_size, noise_dim],
+          minval=-1.,
+          maxval=1.,
+          seed=batch_index)
 
       with tfe.GradientTape(persistent=True) as g:
         generated_images = generator(noise)
-        tf.contrib.summary.image('generated_images',
-                                 tf.reshape(generated_images, [-1, 28, 28, 1]),
-                                 max_images=10)
+        tf.contrib.summary.image(
+            'generated_images',
+            tf.reshape(generated_images, [-1, 28, 28, 1]),
+            max_images=10)
 
         discriminator_gen_outputs = discriminator(generated_images)
         discriminator_real_outputs = discriminator(images)
@@ -245,17 +244,17 @@ def train_one_epoch(generator, discriminator,
                                       discriminator.variables)
 
       with tf.variable_scope('generator'):
-        generator_optimizer.apply_gradients(zip(generator_grad,
-                                                generator.variables))
+        generator_optimizer.apply_gradients(
+            zip(generator_grad, generator.variables))
       with tf.variable_scope('discriminator'):
-        discriminator_optimizer.apply_gradients(zip(discriminator_grad,
-                                                    discriminator.variables))
+        discriminator_optimizer.apply_gradients(
+            zip(discriminator_grad, discriminator.variables))
 
       if log_interval and batch_index > 0 and batch_index % log_interval == 0:
         print('Batch #%d\tAverage Generator Loss: %.6f\t'
-              'Average Discriminator Loss: %.6f' % (
-                  batch_index, total_generator_loss/batch_index,
-                  total_discriminator_loss/batch_index))
+              'Average Discriminator Loss: %.6f' %
+              (batch_index, total_generator_loss / batch_index,
+               total_discriminator_loss / batch_index))
 
 
 def main(_):
@@ -266,10 +265,9 @@ def main(_):
 
   # Load the datasets
   data = input_data.read_data_sets(FLAGS.data_dir)
-  dataset = (tf.data.Dataset
-             .from_tensor_slices(data.train.images)
-             .shuffle(60000)
-             .batch(FLAGS.batch_size))
+  dataset = (
+      tf.data.Dataset.from_tensor_slices(data.train.images).shuffle(60000)
+      .batch(FLAGS.batch_size))
 
   # Create the models and optimizers
   generator = Generator(data_format)
@@ -294,20 +292,17 @@ def main(_):
         start = time.time()
         with summary_writer.as_default():
           train_one_epoch(generator, discriminator, generator_optimizer,
-                          discriminator_optimizer,
-                          dataset, FLAGS.log_interval, FLAGS.noise)
+                          discriminator_optimizer, dataset, FLAGS.log_interval,
+                          FLAGS.noise)
         end = time.time()
-        print('\nTrain time for epoch #%d (global step %d): %f' % (
-            epoch, global_step.numpy(), end - start))
+        print('\nTrain time for epoch #%d (global step %d): %f' %
+              (epoch, global_step.numpy(), end - start))
 
       all_variables = (
-          generator.variables
-          + discriminator.variables
-          + generator_optimizer.variables()
-          + discriminator_optimizer.variables()
-          + [global_step])
-      tfe.Saver(all_variables).save(
-          checkpoint_prefix, global_step=global_step)
+          generator.variables + discriminator.variables +
+          generator_optimizer.variables() +
+          discriminator_optimizer.variables() + [global_step])
+      tfe.Saver(all_variables).save(checkpoint_prefix, global_step=global_step)
 
 
 if __name__ == '__main__':
index 6ce4de6..157a636 100644 (file)
@@ -33,23 +33,13 @@ import tensorflow as tf
 import tensorflow.contrib.eager as tfe
 
 
-class LinearModel(tfe.Network):
-  """A TensorFlow linear regression model.
-
-  Uses TensorFlow's eager execution.
-
-  For those familiar with TensorFlow graphs, notice the absence of
-  `tf.Session`. The `forward()` method here immediately executes and
-  returns output values. The `loss()` method immediately compares the
-  output of `forward()` with the target and returns the MSE loss value.
-  The `fit()` performs gradient-descent training on the model's weights
-  and bias.
-  """
+class LinearModel(tf.keras.Model):
+  """A TensorFlow linear regression model."""
 
   def __init__(self):
     """Constructs a LinearModel object."""
     super(LinearModel, self).__init__()
-    self._hidden_layer = self.track_layer(tf.layers.Dense(1))
+    self._hidden_layer = tf.layers.Dense(1)
 
   def call(self, xs):
     """Invoke the linear model.
index 9982fdb..6b59413 100644 (file)
@@ -27,10 +27,9 @@ from __future__ import print_function
 import functools
 
 import tensorflow as tf
-import tensorflow.contrib.eager as tfe
 
 
-class _IdentityBlock(tfe.Network):
+class _IdentityBlock(tf.keras.Model):
   """_IdentityBlock is the block that has no conv layer at shortcut.
 
   Args:
@@ -50,31 +49,24 @@ class _IdentityBlock(tfe.Network):
     bn_name_base = 'bn' + str(stage) + block + '_branch'
     bn_axis = 1 if data_format == 'channels_first' else 3
 
-    self.conv2a = self.track_layer(
-        tf.layers.Conv2D(
-            filters1, (1, 1),
-            name=conv_name_base + '2a',
-            data_format=data_format))
-    self.bn2a = self.track_layer(
-        tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2a'))
-
-    self.conv2b = self.track_layer(
-        tf.layers.Conv2D(
-            filters2,
-            kernel_size,
-            padding='same',
-            data_format=data_format,
-            name=conv_name_base + '2b'))
-    self.bn2b = self.track_layer(
-        tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2b'))
-
-    self.conv2c = self.track_layer(
-        tf.layers.Conv2D(
-            filters3, (1, 1),
-            name=conv_name_base + '2c',
-            data_format=data_format))
-    self.bn2c = self.track_layer(
-        tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2c'))
+    self.conv2a = tf.layers.Conv2D(
+        filters1, (1, 1), name=conv_name_base + '2a', data_format=data_format)
+    self.bn2a = tf.layers.BatchNormalization(
+        axis=bn_axis, name=bn_name_base + '2a')
+
+    self.conv2b = tf.layers.Conv2D(
+        filters2,
+        kernel_size,
+        padding='same',
+        data_format=data_format,
+        name=conv_name_base + '2b')
+    self.bn2b = tf.layers.BatchNormalization(
+        axis=bn_axis, name=bn_name_base + '2b')
+
+    self.conv2c = tf.layers.Conv2D(
+        filters3, (1, 1), name=conv_name_base + '2c', data_format=data_format)
+    self.bn2c = tf.layers.BatchNormalization(
+        axis=bn_axis, name=bn_name_base + '2c')
 
   def call(self, input_tensor, training=False):
     x = self.conv2a(input_tensor)
@@ -92,7 +84,7 @@ class _IdentityBlock(tfe.Network):
     return tf.nn.relu(x)
 
 
-class _ConvBlock(tfe.Network):
+class _ConvBlock(tf.keras.Model):
   """_ConvBlock is the block that has a conv layer at shortcut.
 
   Args:
@@ -121,41 +113,35 @@ class _ConvBlock(tfe.Network):
     bn_name_base = 'bn' + str(stage) + block + '_branch'
     bn_axis = 1 if data_format == 'channels_first' else 3
 
-    self.conv2a = self.track_layer(
-        tf.layers.Conv2D(
-            filters1, (1, 1),
-            strides=strides,
-            name=conv_name_base + '2a',
-            data_format=data_format))
-    self.bn2a = self.track_layer(
-        tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2a'))
-
-    self.conv2b = self.track_layer(
-        tf.layers.Conv2D(
-            filters2,
-            kernel_size,
-            padding='same',
-            name=conv_name_base + '2b',
-            data_format=data_format))
-    self.bn2b = self.track_layer(
-        tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2b'))
-
-    self.conv2c = self.track_layer(
-        tf.layers.Conv2D(
-            filters3, (1, 1),
-            name=conv_name_base + '2c',
-            data_format=data_format))
-    self.bn2c = self.track_layer(
-        tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2c'))
-
-    self.conv_shortcut = self.track_layer(
-        tf.layers.Conv2D(
-            filters3, (1, 1),
-            strides=strides,
-            name=conv_name_base + '1',
-            data_format=data_format))
-    self.bn_shortcut = self.track_layer(
-        tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '1'))
+    self.conv2a = tf.layers.Conv2D(
+        filters1, (1, 1),
+        strides=strides,
+        name=conv_name_base + '2a',
+        data_format=data_format)
+    self.bn2a = tf.layers.BatchNormalization(
+        axis=bn_axis, name=bn_name_base + '2a')
+
+    self.conv2b = tf.layers.Conv2D(
+        filters2,
+        kernel_size,
+        padding='same',
+        name=conv_name_base + '2b',
+        data_format=data_format)
+    self.bn2b = tf.layers.BatchNormalization(
+        axis=bn_axis, name=bn_name_base + '2b')
+
+    self.conv2c = tf.layers.Conv2D(
+        filters3, (1, 1), name=conv_name_base + '2c', data_format=data_format)
+    self.bn2c = tf.layers.BatchNormalization(
+        axis=bn_axis, name=bn_name_base + '2c')
+
+    self.conv_shortcut = tf.layers.Conv2D(
+        filters3, (1, 1),
+        strides=strides,
+        name=conv_name_base + '1',
+        data_format=data_format)
+    self.bn_shortcut = tf.layers.BatchNormalization(
+        axis=bn_axis, name=bn_name_base + '1')
 
   def call(self, input_tensor, training=False):
     x = self.conv2a(input_tensor)
@@ -176,7 +162,8 @@ class _ConvBlock(tfe.Network):
     return tf.nn.relu(x)
 
 
-class ResNet50(tfe.Network):
+# pylint: disable=not-callable
+class ResNet50(tf.keras.Model):
   """Instantiates the ResNet50 architecture.
 
   Args:
@@ -220,32 +207,28 @@ class ResNet50(tfe.Network):
     self.include_top = include_top
 
     def conv_block(filters, stage, block, strides=(2, 2)):
-      l = _ConvBlock(
+      return _ConvBlock(
           3,
           filters,
           stage=stage,
           block=block,
           data_format=data_format,
           strides=strides)
-      return self.track_layer(l)
 
     def id_block(filters, stage, block):
-      l = _IdentityBlock(
+      return _IdentityBlock(
           3, filters, stage=stage, block=block, data_format=data_format)
-      return self.track_layer(l)
-
-    self.conv1 = self.track_layer(
-        tf.layers.Conv2D(
-            64, (7, 7),
-            strides=(2, 2),
-            data_format=data_format,
-            padding='same',
-            name='conv1'))
+
+    self.conv1 = tf.layers.Conv2D(
+        64, (7, 7),
+        strides=(2, 2),
+        data_format=data_format,
+        padding='same',
+        name='conv1')
     bn_axis = 1 if data_format == 'channels_first' else 3
-    self.bn_conv1 = self.track_layer(
-        tf.layers.BatchNormalization(axis=bn_axis, name='bn_conv1'))
-    self.max_pool = self.track_layer(
-        tf.layers.MaxPooling2D((3, 3), strides=(2, 2), data_format=data_format))
+    self.bn_conv1 = tf.layers.BatchNormalization(axis=bn_axis, name='bn_conv1')
+    self.max_pool = tf.layers.MaxPooling2D(
+        (3, 3), strides=(2, 2), data_format=data_format)
 
     self.l2a = conv_block([64, 64, 256], stage=2, block='a', strides=(1, 1))
     self.l2b = id_block([64, 64, 256], stage=2, block='b')
@@ -267,13 +250,11 @@ class ResNet50(tfe.Network):
     self.l5b = id_block([512, 512, 2048], stage=5, block='b')
     self.l5c = id_block([512, 512, 2048], stage=5, block='c')
 
-    self.avg_pool = self.track_layer(
-        tf.layers.AveragePooling2D(
-            (7, 7), strides=(7, 7), data_format=data_format))
+    self.avg_pool = tf.layers.AveragePooling2D(
+        (7, 7), strides=(7, 7), data_format=data_format)
 
     if self.include_top:
-      self.fc1000 = self.track_layer(
-          tf.layers.Dense(classes, name='fc1000'))
+      self.fc1000 = tf.layers.Dense(classes, name='fc1000')
     else:
       reduction_indices = [1, 2] if data_format == 'channels_last' else [2, 3]
       reduction_indices = tf.constant(reduction_indices)
@@ -288,7 +269,7 @@ class ResNet50(tfe.Network):
       else:
         self.global_pooling = None
 
-  def call(self, input_tensor, training=False):
+  def call(self, input_tensor, training):
     x = self.conv1(input_tensor)
     x = self.bn_conv1(x, training=training)
     x = tf.nn.relu(x)
index 2331788..551c76b 100644 (file)
@@ -55,7 +55,7 @@ class ResNet50GraphTest(tf.test.TestCase):
     with tf.Graph().as_default():
       images = tf.placeholder(tf.float32, image_shape(None))
       model = resnet50.ResNet50(data_format())
-      predictions = model(images)
+      predictions = model(images, training=False)
 
       init = tf.global_variables_initializer()
 
@@ -114,7 +114,7 @@ class ResNet50Benchmarks(tf.test.Benchmark):
     with tf.Graph().as_default():
       images = tf.placeholder(tf.float32, image_shape(None))
       model = resnet50.ResNet50(data_format())
-      predictions = model(images)
+      predictions = model(images, training=False)
 
       init = tf.global_variables_initializer()
 
index 0ff8746..c106ab0 100644 (file)
@@ -71,7 +71,7 @@ class ResNet50Test(tf.test.TestCase):
       model.call = tfe.defun(model.call)
     with tf.device(device):
       images, _ = random_batch(2)
-      output = model(images)
+      output = model(images, training=False)
     self.assertEqual((2, 1000), output.shape)
 
   def test_apply(self):
@@ -85,7 +85,7 @@ class ResNet50Test(tf.test.TestCase):
     model = resnet50.ResNet50(data_format, include_top=False)
     with tf.device(device):
       images, _ = random_batch(2)
-      output = model(images)
+      output = model(images, training=False)
     output_shape = ((2, 2048, 1, 1)
                     if data_format == 'channels_first' else (2, 1, 1, 2048))
     self.assertEqual(output_shape, output.shape)
@@ -95,7 +95,7 @@ class ResNet50Test(tf.test.TestCase):
     model = resnet50.ResNet50(data_format, include_top=False, pooling='avg')
     with tf.device(device):
       images, _ = random_batch(2)
-      output = model(images)
+      output = model(images, training=False)
     self.assertEqual((2, 2048), output.shape)
 
   def test_train(self):