Convert the eager SPINN example to use tf.keras.Model and object-based checkpointing.
authorAllen Lavoie <allenl@google.com>
Fri, 23 Mar 2018 22:12:21 +0000 (15:12 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sun, 25 Mar 2018 11:13:12 +0000 (04:13 -0700)
Uses a more recursive/functional tracking style which avoids numbering layers. Maybe this is too magical and we should adapt tf.keras.Sequential first? Let me know what you think.

PiperOrigin-RevId: 190282346

tensorflow/contrib/eager/python/examples/spinn/spinn_test.py
third_party/examples/eager/spinn/spinn.py

index 081b0af..591d99e 100644 (file)
@@ -33,6 +33,7 @@ import tensorflow as tf
 import tensorflow.contrib.eager as tfe
 from tensorflow.contrib.eager.python.examples.spinn import data
 from third_party.examples.eager.spinn import spinn
+from tensorflow.contrib.eager.proto import checkpointable_object_graph_pb2
 from tensorflow.contrib.summary import summary_test_util
 from tensorflow.python.eager import test
 from tensorflow.python.framework import test_util
@@ -172,7 +173,7 @@ class SpinnTest(test_util.TensorFlowTestCase):
         right_in.append(tf.random_normal((1, size * 2)))
         tracking.append(tf.random_normal((1, tracker_size * 2)))
 
-      out = reducer(left_in, right_in, tracking=tracking)
+      out = reducer(left_in, right_in=right_in, tracking=tracking)
       self.assertEqual(batch_size, len(out))
       self.assertEqual(tf.float32, out[0].dtype)
       self.assertEqual((1, size * 2), out[0].shape)
@@ -226,7 +227,7 @@ class SpinnTest(test_util.TensorFlowTestCase):
       self.assertEqual((batch_size, size * 2), stacks[0][0].shape)
 
       for _ in range(2):
-        out1, out2 = tracker(bufs, stacks)
+        out1, out2 = tracker(bufs, stacks=stacks)
         self.assertIsNone(out2)
         self.assertEqual(batch_size, len(out1))
         self.assertEqual(tf.float32, out1[0].dtype)
@@ -259,7 +260,7 @@ class SpinnTest(test_util.TensorFlowTestCase):
       self.assertEqual(tf.int64, transitions.dtype)
       self.assertEqual((num_transitions, 1), transitions.shape)
 
-      out = s(buffers, transitions, training=True)
+      out = s(buffers, transitions=transitions, training=True)
       self.assertEqual(tf.float32, out.dtype)
       self.assertEqual((1, embedding_dims), out.shape)
 
@@ -285,12 +286,15 @@ class SpinnTest(test_util.TensorFlowTestCase):
                                                          vocab_size)
 
       # Invoke model under non-training mode.
-      logits = model(prem, prem_trans, hypo, hypo_trans, training=False)
+      logits = model(
+          prem, premise_transition=prem_trans, hypothesis=hypo,
+          hypothesis_transition=hypo_trans, training=False)
       self.assertEqual(tf.float32, logits.dtype)
       self.assertEqual((batch_size, d_out), logits.shape)
 
       # Invoke model under training model.
-      logits = model(prem, prem_trans, hypo, hypo_trans, training=True)
+      logits = model(prem, premise_transition=prem_trans, hypothesis=hypo,
+                     hypothesis_transition=hypo_trans, training=True)
       self.assertEqual(tf.float32, logits.dtype)
       self.assertEqual((batch_size, d_out), logits.shape)
 
@@ -421,8 +425,14 @@ class SpinnTest(test_util.TensorFlowTestCase):
 
     # 5. Verify that checkpoints exist and contains all the expected variables.
     self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*")))
-    ckpt_variable_names = [
-        item[0] for item in checkpoint_utils.list_variables(config.logdir)]
+    object_graph_string = checkpoint_utils.load_variable(
+        config.logdir, name="_CHECKPOINTABLE_OBJECT_GRAPH")
+    object_graph = checkpointable_object_graph_pb2.CheckpointableObjectGraph()
+    object_graph.ParseFromString(object_graph_string)
+    ckpt_variable_names = set()
+    for node in object_graph.nodes:
+      for attribute in node.attributes:
+        ckpt_variable_names.add(attribute.full_name)
     self.assertIn("global_step", ckpt_variable_names)
     for v in trainer.variables:
       variable_name = v.name[:v.name.index(":")] if ":" in v.name else v.name
index 8a1c7db..f8fb6ec 100644 (file)
@@ -51,6 +51,9 @@ import tensorflow.contrib.eager as tfe
 from tensorflow.contrib.eager.python.examples.spinn import data
 
 
+layers = tf.keras.layers
+
+
 def _bundle(lstm_iter):
   """Concatenate a list of Tensors along 1st axis and split result into two.
 
@@ -78,17 +81,16 @@ def _unbundle(state):
   return tf.split(tf.concat(state, 1), state[0].shape[0], axis=0)
 
 
-class Reducer(tfe.Network):
+# pylint: disable=not-callable
+class Reducer(tf.keras.Model):
   """A module that applies reduce operation on left and right vectors."""
 
   def __init__(self, size, tracker_size=None):
     super(Reducer, self).__init__()
-    self.left = self.track_layer(tf.layers.Dense(5 * size, activation=None))
-    self.right = self.track_layer(
-        tf.layers.Dense(5 * size, activation=None, use_bias=False))
+    self.left = layers.Dense(5 * size, activation=None)
+    self.right = layers.Dense(5 * size, activation=None, use_bias=False)
     if tracker_size is not None:
-      self.track = self.track_layer(
-          tf.layers.Dense(5 * size, activation=None, use_bias=False))
+      self.track = layers.Dense(5 * size, activation=None, use_bias=False)
     else:
       self.track = None
 
@@ -123,7 +125,7 @@ class Reducer(tfe.Network):
     return h, c
 
 
-class Tracker(tfe.Network):
+class Tracker(tf.keras.Model):
   """A module that tracks the history of the sentence with an LSTM."""
 
   def __init__(self, tracker_size, predict):
@@ -134,10 +136,10 @@ class Tracker(tfe.Network):
       predict: (`bool`) Whether prediction mode is enabled.
     """
     super(Tracker, self).__init__()
-    self._rnn = self.track_layer(tf.nn.rnn_cell.LSTMCell(tracker_size))
+    self._rnn = tf.nn.rnn_cell.LSTMCell(tracker_size)
     self._state_size = tracker_size
     if predict:
-      self._transition = self.track_layer(tf.layers.Dense(4))
+      self._transition = layers.Dense(4)
     else:
       self._transition = None
 
@@ -182,7 +184,7 @@ class Tracker(tfe.Network):
       return unbundled, None
 
 
-class SPINN(tfe.Network):
+class SPINN(tf.keras.Model):
   """Stack-augmented Parser-Interpreter Neural Network.
 
   See https://arxiv.org/abs/1603.06021 for more details.
@@ -204,9 +206,9 @@ class SPINN(tfe.Network):
     """
     super(SPINN, self).__init__()
     self.config = config
-    self.reducer = self.track_layer(Reducer(config.d_hidden, config.d_tracker))
+    self.reducer = Reducer(config.d_hidden, config.d_tracker)
     if config.d_tracker is not None:
-      self.tracker = self.track_layer(Tracker(config.d_tracker, config.predict))
+      self.tracker = Tracker(config.d_tracker, config.predict)
     else:
       self.tracker = None
 
@@ -248,7 +250,7 @@ class SPINN(tfe.Network):
       trans = transitions[i]
       if self.tracker:
         # Invoke tracker to obtain the current tracker states for the sentences.
-        tracker_states, trans_hypothesis = self.tracker(buffers, stacks)
+        tracker_states, trans_hypothesis = self.tracker(buffers, stacks=stacks)
         if trans_hypothesis:
           trans = tf.argmax(trans_hypothesis, axis=-1)
       else:
@@ -264,7 +266,8 @@ class SPINN(tfe.Network):
           trackings.append(tracking)
 
       if rights:
-        reducer_output = self.reducer(lefts, rights, trackings)
+        reducer_output = self.reducer(
+            lefts, right_in=rights, tracking=trackings)
         reduced = iter(reducer_output)
 
         for transition, stack in zip(trans, stacks):
@@ -273,7 +276,27 @@ class SPINN(tfe.Network):
     return _bundle([stack.pop() for stack in stacks])[0]
 
 
-class SNLIClassifier(tfe.Network):
+class Perceptron(tf.keras.Model):
+  """One layer of the SNLIClassifier multi-layer perceptron."""
+
+  def __init__(self, dimension, dropout_rate, previous_layer):
+    """Configure the Perceptron."""
+    super(Perceptron, self).__init__()
+    self.dense = tf.keras.layers.Dense(dimension, activation=tf.nn.elu)
+    self.batchnorm = layers.BatchNormalization()
+    self.dropout = layers.Dropout(rate=dropout_rate)
+    self.previous_layer = previous_layer
+
+  def call(self, x, training):
+    """Run previous Perceptron layers, then this one."""
+    x = self.previous_layer(x, training=training)
+    x = self.dense(x)
+    x = self.batchnorm(x, training=training)
+    x = self.dropout(x, training=training)
+    return x
+
+
+class SNLIClassifier(tf.keras.Model):
   """SNLI Classifier Model.
 
   A model aimed at solving the SNLI (Standford Natural Language Inference)
@@ -304,29 +327,24 @@ class SNLIClassifier(tfe.Network):
     self.config = config
     self.embed = tf.constant(embed)
 
-    self.projection = self.track_layer(tf.layers.Dense(config.d_proj))
-    self.embed_bn = self.track_layer(tf.layers.BatchNormalization())
-    self.embed_dropout = self.track_layer(
-        tf.layers.Dropout(rate=config.embed_dropout))
-    self.encoder = self.track_layer(SPINN(config))
-
-    self.feature_bn = self.track_layer(tf.layers.BatchNormalization())
-    self.feature_dropout = self.track_layer(
-        tf.layers.Dropout(rate=config.mlp_dropout))
-
-    self.mlp_dense = []
-    self.mlp_bn = []
-    self.mlp_dropout = []
-    for _ in xrange(config.n_mlp_layers):
-      self.mlp_dense.append(self.track_layer(tf.layers.Dense(config.d_mlp)))
-      self.mlp_bn.append(
-          self.track_layer(tf.layers.BatchNormalization()))
-      self.mlp_dropout.append(
-          self.track_layer(tf.layers.Dropout(rate=config.mlp_dropout)))
-    self.mlp_output = self.track_layer(tf.layers.Dense(
+    self.projection = layers.Dense(config.d_proj)
+    self.embed_bn = layers.BatchNormalization()
+    self.embed_dropout = layers.Dropout(rate=config.embed_dropout)
+    self.encoder = SPINN(config)
+
+    self.feature_bn = layers.BatchNormalization()
+    self.feature_dropout = layers.Dropout(rate=config.mlp_dropout)
+
+    current_mlp = lambda result, training: result
+    for _ in range(config.n_mlp_layers):
+      current_mlp = Perceptron(dimension=config.d_mlp,
+                               dropout_rate=config.mlp_dropout,
+                               previous_layer=current_mlp)
+    self.mlp = current_mlp
+    self.mlp_output = layers.Dense(
         config.d_out,
         kernel_initializer=tf.random_uniform_initializer(minval=-5e-3,
-                                                         maxval=5e-3)))
+                                                         maxval=5e-3))
 
   def call(self,
            premise,
@@ -370,10 +388,10 @@ class SNLIClassifier(tfe.Network):
 
     # Run the batch-normalized and dropout-processed word vectors through the
     # SPINN encoder.
-    premise = self.encoder(premise_embed, premise_transition,
-                           training=training)
-    hypothesis = self.encoder(hypothesis_embed, hypothesis_transition,
-                              training=training)
+    premise = self.encoder(
+        premise_embed, transitions=premise_transition, training=training)
+    hypothesis = self.encoder(
+        hypothesis_embed, transitions=hypothesis_transition, training=training)
 
     # Combine encoder outputs for premises and hypotheses into logits.
     # Then apply batch normalization and dropuout on the logits.
@@ -383,15 +401,12 @@ class SNLIClassifier(tfe.Network):
         self.feature_bn(logits, training=training), training=training)
 
     # Apply the multi-layer perceptron on the logits.
-    for dense, bn, dropout in zip(
-        self.mlp_dense, self.mlp_bn, self.mlp_dropout):
-      logits = tf.nn.elu(dense(logits))
-      logits = dropout(bn(logits, training=training), training=training)
+    logits = self.mlp(logits, training=training)
     logits = self.mlp_output(logits)
     return logits
 
 
-class SNLIClassifierTrainer(object):
+class SNLIClassifierTrainer(tfe.Checkpointable):
   """A class that coordinates the training of an SNLIClassifier."""
 
   def __init__(self, snli_classifier, lr):
@@ -450,10 +465,11 @@ class SNLIClassifierTrainer(object):
     """
     with tfe.GradientTape() as tape:
       tape.watch(self._model.variables)
+      # TODO(allenl): Allow passing Layer inputs as position arguments.
       logits = self._model(premise,
-                           premise_transition,
-                           hypothesis,
-                           hypothesis_transition,
+                           premise_transition=premise_transition,
+                           hypothesis=hypothesis,
+                           hypothesis_transition=hypothesis_transition,
                            training=True)
       loss = self.loss(labels, logits)
     gradients = tape.gradient(loss, self._model.variables)
@@ -517,7 +533,9 @@ def _evaluate_on_dataset(snli_data, batch_size, trainer, use_gpu):
       snli_data, batch_size):
     if use_gpu:
       label, prem, hypo = label.gpu(), prem.gpu(), hypo.gpu()
-    logits = trainer.model(prem, prem_trans, hypo, hypo_trans, training=False)
+    logits = trainer.model(
+        prem, premise_transition=prem_trans, hypothesis=hypo,
+        hypothesis_transition=hypo_trans, training=False)
     loss_val = trainer.loss(label, logits)
     batch_size = tf.shape(label)[0]
     mean_loss(loss_val, weights=batch_size.gpu() if use_gpu else batch_size)
@@ -609,29 +627,30 @@ def train_or_infer_spinn(embed,
   with tf.device(device), \
        summary_writer.as_default(), \
        tf.contrib.summary.always_record_summaries():
-    with tfe.restore_variables_on_create(
-        tf.train.latest_checkpoint(config.logdir)):
-      model = SNLIClassifier(config, embed)
-      global_step = tf.train.get_or_create_global_step()
-      trainer = SNLIClassifierTrainer(model, config.lr)
+    model = SNLIClassifier(config, embed)
+    global_step = tf.train.get_or_create_global_step()
+    trainer = SNLIClassifierTrainer(model, config.lr)
+    checkpoint = tfe.Checkpoint(trainer=trainer, global_step=global_step)
+    checkpoint.restore(tf.train.latest_checkpoint(config.logdir))
 
     if inference_sentence_pair:
       # Inference mode.
-      with tfe.restore_variables_on_create(
-          tf.train.latest_checkpoint(config.logdir)):
-        prem, prem_trans = inference_sentence_pair[0]
-        hypo, hypo_trans = inference_sentence_pair[1]
-        hypo_trans = inference_sentence_pair[1][1]
-        inference_logits = model(  # pylint: disable=not-callable
-            tf.constant(prem), tf.constant(prem_trans),
-            tf.constant(hypo), tf.constant(hypo_trans), training=False)
-        inference_logits = inference_logits[0][1:]
-        max_index = tf.argmax(inference_logits)
-        print("\nInference logits:")
-        for i, (label, logit) in enumerate(
-            zip(data.POSSIBLE_LABELS, inference_logits)):
-          winner_tag = " (winner)" if max_index == i else ""
-          print("  {0:<16}{1:.6f}{2}".format(label + ":", logit, winner_tag))
+      prem, prem_trans = inference_sentence_pair[0]
+      hypo, hypo_trans = inference_sentence_pair[1]
+      hypo_trans = inference_sentence_pair[1][1]
+      inference_logits = model(
+          tf.constant(prem),
+          premise_transition=tf.constant(prem_trans),
+          hypothesis=tf.constant(hypo),
+          hypothesis_transition=tf.constant(hypo_trans),
+          training=False)
+      inference_logits = inference_logits[0][1:]
+      max_index = tf.argmax(inference_logits)
+      print("\nInference logits:")
+      for i, (label, logit) in enumerate(
+          zip(data.POSSIBLE_LABELS, inference_logits)):
+        winner_tag = " (winner)" if max_index == i else ""
+        print("  {0:<16}{1:.6f}{2}".format(label + ":", logit, winner_tag))
       return inference_logits
 
     train_len = train_data.num_batches(config.batch_size)
@@ -650,20 +669,15 @@ def train_or_infer_spinn(embed,
           # remain on CPU. Same in _evaluate_on_dataset().
 
         iterations += 1
-        with tfe.restore_variables_on_create(
-            tf.train.latest_checkpoint(config.logdir)):
-          batch_train_loss, batch_train_logits = trainer.train_batch(
-              label, prem, prem_trans, hypo, hypo_trans)
+        batch_train_loss, batch_train_logits = trainer.train_batch(
+            label, prem, prem_trans, hypo, hypo_trans)
         batch_size = tf.shape(label)[0]
         mean_loss(batch_train_loss.numpy(),
                   weights=batch_size.gpu() if use_gpu else batch_size)
         accuracy(tf.argmax(batch_train_logits, axis=1), label)
 
         if iterations % config.save_every == 0:
-          all_variables = trainer.variables + [global_step]
-          saver = tfe.Saver(all_variables)
-          saver.save(os.path.join(config.logdir, "ckpt"),
-                     global_step=global_step)
+          checkpoint.save(os.path.join(config.logdir, "ckpt"))
 
         if iterations % config.dev_every == 0:
           dev_loss, dev_frac_correct = _evaluate_on_dataset(