tfe SPINN example: Add inference; fix serialization
author  Shanqing Cai <cais@google.com>
Wed, 14 Feb 2018 05:49:28 +0000 (21:49 -0800)
committer  TensorFlower Gardener <gardener@tensorflow.org>
Wed, 14 Feb 2018 05:52:54 +0000 (21:52 -0800)
* Also de-flake a test.

PiperOrigin-RevId: 185637742

tensorflow/contrib/eager/python/examples/spinn/BUILD
tensorflow/contrib/eager/python/examples/spinn/data.py
tensorflow/contrib/eager/python/examples/spinn/data_test.py
tensorflow/contrib/eager/python/examples/spinn/spinn_test.py
third_party/examples/eager/spinn/README.md
third_party/examples/eager/spinn/spinn.py

index 21055cf..a1f8a75 100644 (file)
@@ -38,9 +38,5 @@ cuda_py_test(
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_test_lib",
     ],
-    tags = [
-        "manual",
-        "no_gpu",
-        "no_pip",  # because spinn.py is under third_party/.
-    ],
+    tags = ["no_pip"],  # because spinn.py is under third_party/.
 )
index fcaae0a..3bc3bb4 100644 (file)
@@ -227,6 +227,29 @@ def calculate_bins(length2count, min_bin_size):
   return bounds
 
 
+def encode_sentence(sentence, word2index):
+  """Encode a single sentence as word indices and shift-reduce code.
+
+  Args:
+    sentence: The sentence with added binary parse information, represented as
+      a string, with all the word items and parentheses separated by spaces.
+      E.g., '( ( The dog ) ( ( is ( playing toys ) ) . ) )'.
+    word2index: A `dict` mapping words to their word indices.
+
+  Returns:
+     1. Word indices as a numpy array, with shape `(sequence_len, 1)`.
+     2. Shift-reduce sequence as a numpy array, with shape
+       `(sequence_len * 2 - 3, 1)`.
+  """
+  items = [w for w in sentence.split(" ") if w]
+  words = get_non_parenthesis_words(items)
+  shift_reduce = get_shift_reduce(items)
+  word_indices = pad_and_reverse_word_ids(
+      [[word2index.get(word, UNK_CODE) for word in words]]).T
+  return (word_indices,
+          np.expand_dims(np.array(shift_reduce, dtype=np.int64), -1))
+
+
 class SnliData(object):
   """A split of SNLI data."""
 
index e4f0b37..54fef2c 100644 (file)
@@ -22,6 +22,7 @@ import os
 import shutil
 import tempfile
 
+import numpy as np
 import tensorflow as tf
 
 from tensorflow.contrib.eager.python.examples.spinn import data
@@ -173,14 +174,9 @@ class DataTest(tf.test.TestCase):
         ValueError, "Cannot find GloVe embedding file at"):
       data.load_word_vectors(self._temp_data_dir, vocab)
 
-  def testSnliData(self):
-    """Unit test for SnliData objects."""
-    snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
-    fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt")
-    os.makedirs(snli_1_0_dir)
-
+  def _createFakeSnliData(self, fake_snli_file):
     # Four sentences in total.
-    with open(fake_train_file, "wt") as f:
+    with open(fake_snli_file, "wt") as f:
       f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t"
               "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t"
               "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n")
@@ -205,10 +201,7 @@ class DataTest(tf.test.TestCase):
               "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
               "neutral\tentailment\tneutral\tneutral\tneutral\n")
 
-    glove_dir = os.path.join(self._temp_data_dir, "glove")
-    os.makedirs(glove_dir)
-    glove_file = os.path.join(glove_dir, "glove.42B.300d.txt")
-
+  def _createFakeGloveData(self, glove_file):
     words = [".", "foo", "bar", "baz", "quux", "quuz", "grault", "garply"]
     with open(glove_file, "wt") as f:
       for i, word in enumerate(words):
@@ -220,6 +213,40 @@ class DataTest(tf.test.TestCase):
           else:
             f.write("\n")
 
+  def testEncodeSingleSentence(self):
+    snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
+    fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt")
+    os.makedirs(snli_1_0_dir)
+    self._createFakeSnliData(fake_train_file)
+    vocab = data.load_vocabulary(self._temp_data_dir)
+    glove_dir = os.path.join(self._temp_data_dir, "glove")
+    os.makedirs(glove_dir)
+    glove_file = os.path.join(glove_dir, "glove.42B.300d.txt")
+    self._createFakeGloveData(glove_file)
+    word2index, _ = data.load_word_vectors(self._temp_data_dir, vocab)
+
+    sentence_variants = [
+        "( Foo ( ( bar baz ) . ) )",
+        " ( Foo ( ( bar baz ) . ) ) ",
+        "( Foo ( ( bar baz ) . )  )"]
+    for sentence in sentence_variants:
+      word_indices, shift_reduce = data.encode_sentence(sentence, word2index)
+      self.assertEqual(np.int64, word_indices.dtype)
+      self.assertEqual((5, 1), word_indices.shape)
+      self.assertAllClose(
+          np.array([[3, 3, 3, 2, 3, 2, 2]], dtype=np.int64).T, shift_reduce)
+
+  def testSnliData(self):
+    snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
+    fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt")
+    os.makedirs(snli_1_0_dir)
+    self._createFakeSnliData(fake_train_file)
+
+    glove_dir = os.path.join(self._temp_data_dir, "glove")
+    os.makedirs(glove_dir)
+    glove_file = os.path.join(glove_dir, "glove.42B.300d.txt")
+    self._createFakeGloveData(glove_file)
+
     vocab = data.load_vocabulary(self._temp_data_dir)
     word2index, _ = data.load_word_vectors(self._temp_data_dir, vocab)
 
@@ -230,7 +257,7 @@ class DataTest(tf.test.TestCase):
     self.assertEqual(1, train_data.num_batches(4))
 
     generator = train_data.get_generator(2)()
-    for i in range(2):
+    for _ in range(2):
       label, prem, prem_trans, hypo, hypo_trans = next(generator)
       self.assertEqual(2, len(label))
       self.assertEqual((4, 2), prem.shape)
index 7b2f09c..eefc06d 100644 (file)
@@ -36,6 +36,7 @@ from third_party.examples.eager.spinn import spinn
 from tensorflow.contrib.summary import summary_test_util
 from tensorflow.python.eager import test
 from tensorflow.python.framework import test_util
+from tensorflow.python.training import checkpoint_utils
 # pylint: enable=g-bad-import-order
 
 
@@ -66,13 +67,30 @@ def _generate_synthetic_snli_data_batch(sequence_length,
   return labels, prem, prem_trans, hypo, hypo_trans
 
 
-def _test_spinn_config(d_embed, d_out, logdir=None):
+def _test_spinn_config(d_embed, d_out, logdir=None, inference_sentences=None):
+  """Generate a config tuple for testing.
+
+  Args:
+    d_embed: Embedding dimensions.
+    d_out: Model output dimensions.
+    logdir: Optional logdir.
+    inference_sentences: A 2-tuple of strings representing the premise and
+      hypothesis sentences (with binary parse information), e.g.,
+      ("( ( The dog ) ( ( is running ) . ) )", "( ( The dog ) ( moves . ) )").
+
+  Returns:
+    A config tuple.
+  """
   config_tuple = collections.namedtuple(
       "Config", ["d_hidden", "d_proj", "d_tracker", "predict",
                  "embed_dropout", "mlp_dropout", "n_mlp_layers", "d_mlp",
                  "d_out", "projection", "lr", "batch_size", "epochs",
                  "force_cpu", "logdir", "log_every", "dev_every", "save_every",
-                 "lr_decay_every", "lr_decay_by"])
+                 "lr_decay_every", "lr_decay_by", "inference_premise",
+                 "inference_hypothesis"])
+
+  inference_premise = inference_sentences[0] if inference_sentences else None
+  inference_hypothesis = inference_sentences[1] if inference_sentences else None
   return config_tuple(
       d_hidden=d_embed,
       d_proj=d_embed * 2,
@@ -86,14 +104,16 @@ def _test_spinn_config(d_embed, d_out, logdir=None):
       projection=True,
       lr=2e-2,
       batch_size=2,
-      epochs=10,
+      epochs=20,
       force_cpu=False,
       logdir=logdir,
       log_every=1,
       dev_every=2,
       save_every=2,
       lr_decay_every=1,
-      lr_decay_by=0.75)
+      lr_decay_by=0.75,
+      inference_premise=inference_premise,
+      inference_hypothesis=inference_hypothesis)
 
 
 class SpinnTest(test_util.TensorFlowTestCase):
@@ -288,11 +308,7 @@ class SpinnTest(test_util.TensorFlowTestCase):
       # Training on the batch should have led to a change in the loss value.
       self.assertNotEqual(loss1.numpy(), loss2.numpy())
 
-  def testTrainSpinn(self):
-    """Test with fake toy SNLI data and GloVe vectors."""
-
-    # 1. Create and load a fake SNLI data file and a fake GloVe embedding file.
-    snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
+  def _create_test_data(self, snli_1_0_dir):
     fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt")
     os.makedirs(snli_1_0_dir)
 
@@ -337,13 +353,52 @@ class SpinnTest(test_util.TensorFlowTestCase):
           else:
             f.write("\n")
 
+    return fake_train_file
+
+  def testInferSpinnWorks(self):
+    """Test inference with the spinn model."""
+    snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
+    self._create_test_data(snli_1_0_dir)
+
+    vocab = data.load_vocabulary(self._temp_data_dir)
+    word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab)
+
+    config = _test_spinn_config(
+        data.WORD_VECTOR_LEN, 4,
+        logdir=os.path.join(self._temp_data_dir, "logdir"),
+        inference_sentences=("( foo ( bar . ) )", "( bar ( foo . ) )"))
+    logits = spinn.train_or_infer_spinn(
+        embed, word2index, None, None, None, config)
+    self.assertEqual(np.float32, logits.dtype)
+    self.assertEqual((3,), logits.shape)
+
+  def testInferSpinnThrowsErrorIfOnlyOneSentenceIsSpecified(self):
+    snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
+    self._create_test_data(snli_1_0_dir)
+
+    vocab = data.load_vocabulary(self._temp_data_dir)
+    word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab)
+
+    config = _test_spinn_config(
+        data.WORD_VECTOR_LEN, 4,
+        logdir=os.path.join(self._temp_data_dir, "logdir"),
+        inference_sentences=("( foo ( bar . ) )", None))
+    with self.assertRaises(ValueError):
+      spinn.train_or_infer_spinn(embed, word2index, None, None, None, config)
+
+  def testTrainSpinn(self):
+    """Test with fake toy SNLI data and GloVe vectors."""
+
+    # 1. Create and load a fake SNLI data file and a fake GloVe embedding file.
+    snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
+    fake_train_file = self._create_test_data(snli_1_0_dir)
+
     vocab = data.load_vocabulary(self._temp_data_dir)
     word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab)
 
     train_data = data.SnliData(fake_train_file, word2index)
     dev_data = data.SnliData(fake_train_file, word2index)
     test_data = data.SnliData(fake_train_file, word2index)
-    print(embed)
 
     # 2. Create a fake config.
     config = _test_spinn_config(
@@ -351,7 +406,8 @@ class SpinnTest(test_util.TensorFlowTestCase):
         logdir=os.path.join(self._temp_data_dir, "logdir"))
 
     # 3. Test training of a SPINN model.
-    spinn.train_spinn(embed, train_data, dev_data, test_data, config)
+    trainer = spinn.train_or_infer_spinn(
+        embed, word2index, train_data, dev_data, test_data, config)
 
     # 4. Load train loss values from the summary files and verify that they
     #    decrease with training.
@@ -363,6 +419,15 @@ class SpinnTest(test_util.TensorFlowTestCase):
     self.assertEqual(config.epochs, len(train_losses))
     self.assertLess(train_losses[-1], train_losses[0])
 
+    # 5. Verify that checkpoints exist and contain all the expected variables.
+    self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*")))
+    ckpt_variable_names = [
+        item[0] for item in checkpoint_utils.list_variables(config.logdir)]
+    self.assertIn("global_step", ckpt_variable_names)
+    for v in trainer.variables:
+      variable_name = v.name[:v.name.index(":")] if ":" in v.name else v.name
+      self.assertIn(variable_name, ckpt_variable_names)
+
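
As a side note, the same checkpoint inspection the test performs can be done by hand; a minimal sketch, assuming a checkpoint has already been written under the README's default logdir:

```python
from tensorflow.python.training import checkpoint_utils

logdir = "/tmp/spinn-logs"  # assumed; matches the README's default --logdir

# list_variables returns (name, shape) pairs for the latest checkpoint
# found under logdir, including the saved global_step.
for name, shape in checkpoint_utils.list_variables(logdir):
  print(name, shape)
```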
 
 class EagerSpinnSNLIClassifierBenchmark(test.Benchmark):
 
index 6bd3d53..335c0fa 100644 (file)
@@ -66,3 +66,44 @@ Other eager execution examples can be found under [tensorflow/contrib/eager/pyth
   ```bash
   tensorboard --logdir /tmp/spinn-logs
   ```
+
+- After training, you may use the model to perform inference on input data in
+  the SNLI data format. The premise and hypothesis sentences are specified with
+  the command-line flags `--inference_premise` and `--inference_hypothesis`,
+  respectively. Each sentence should include the words, as well as parentheses
+  representing a binary parse of the sentence. The words and parentheses
+  should all be separated by spaces. For instance,
+
+  ```bash
+  python spinn.py --data_root /tmp/spinn-data --logdir /tmp/spinn-logs \
+      --inference_premise '( ( The dog ) ( ( is running ) . ) )' \
+      --inference_hypothesis '( ( The dog ) ( moves . ) )'
+  ```
+
+  which will generate output like the following, reflecting the semantic
+  consistency (entailment) of the two sentences.
+
+  ```none
+  Inference logits:
+    entailment:     1.101249 (winner)
+    contradiction:  -2.374171
+    neutral:        -0.296733
+  ```
+
+  By contrast, the following sentence pair:
+
+  ```bash
+  python spinn.py --data_root /tmp/spinn-data --logdir /tmp/spinn-logs \
+      --inference_premise '( ( The dog ) ( ( is running ) . ) )' \
+      --inference_hypothesis '( ( The dog ) ( rests . ) )'
+  ```
+
+  will give you output like the following, reflecting the semantic
+  contradiction between the two sentences.
+
+  ```none
+  Inference logits:
+    entailment:     -1.070098
+    contradiction:  2.798695 (winner)
+    neutral:        -1.402287
+  ```
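
The "(winner)" tag in the outputs above is simply the argmax over the three logits; a small sketch of that mapping using `data.POSSIBLE_LABELS` (the logit values below are illustrative only):

```python
import numpy as np

from tensorflow.contrib.eager.python.examples.spinn import data

# Illustrative logits, in the order the script prints them
# (entailment, contradiction, neutral).
logits = np.array([1.101249, -2.374171, -0.296733], dtype=np.float32)

winner = int(np.argmax(logits))
for i, (label, logit) in enumerate(zip(data.POSSIBLE_LABELS, logits)):
  tag = " (winner)" if i == winner else ""
  print("  {0:<16}{1:.6f}{2}".format(label + ":", logit, tag))
```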
index a2fa18e..38ba48d 100644 (file)
@@ -44,6 +44,7 @@ import os
 import sys
 import time
 
+import numpy as np
 from six.moves import xrange  # pylint: disable=redefined-builtin
 import tensorflow as tf
 
@@ -471,6 +472,15 @@ class SNLIClassifierTrainer(object):
   def learning_rate(self):
     return self._learning_rate
 
+  @property
+  def model(self):
+    return self._model
+
+  @property
+  def variables(self):
+    return (self._model.variables + [self.learning_rate] +
+            self._optimizer.variables())
+
 
 def _batch_n_correct(logits, label):
   """Calculate number of correct predictions in a batch.
@@ -488,13 +498,12 @@ def _batch_n_correct(logits, label):
           tf.argmax(logits, axis=1), label)), tf.float32)).numpy()
 
 
-def _evaluate_on_dataset(snli_data, batch_size, model, trainer, use_gpu):
+def _evaluate_on_dataset(snli_data, batch_size, trainer, use_gpu):
   """Run evaluation on a dataset.
 
   Args:
     snli_data: The `data.SnliData` to use in this evaluation.
     batch_size: The batch size to use during this evaluation.
-    model: An instance of `SNLIClassifier` to evaluate.
     trainer: An instance of `SNLIClassifierTrainer` to use for this
       evaluation.
     use_gpu: Whether GPU is being used.
@@ -509,7 +518,7 @@ def _evaluate_on_dataset(snli_data, batch_size, model, trainer, use_gpu):
       snli_data, batch_size):
     if use_gpu:
       label, prem, hypo = label.gpu(), prem.gpu(), hypo.gpu()
-    logits = model(prem, prem_trans, hypo, hypo_trans, training=False)
+    logits = trainer.model(prem, prem_trans, hypo, hypo_trans, training=False)
     loss_val = trainer.loss(label, logits)
     batch_size = tf.shape(label)[0]
     mean_loss(loss_val, weights=batch_size.gpu() if use_gpu else batch_size)
@@ -536,13 +545,19 @@ def _get_dataset_iterator(snli_data, batch_size):
     return tfe.Iterator(dataset)
 
 
-def train_spinn(embed, train_data, dev_data, test_data, config):
-  """Train a SPINN model.
+def train_or_infer_spinn(embed,
+                         word2index,
+                         train_data,
+                         dev_data,
+                         test_data,
+                         config):
+  """Perform Training or Inference on a SPINN model.
 
   Args:
     embed: The embedding matrix as a float32 numpy array with shape
       [vocabulary_size, word_vector_len]. word_vector_len is the length of a
       word embedding vector.
+    word2index: A `dict` mapping word to word index.
     train_data: An instance of `data.SnliData`, for the train split.
     dev_data: Same as above, for the dev split.
     test_data: Same as above, for the test split.
@@ -550,13 +565,35 @@ def train_spinn(embed, train_data, dev_data, test_data, config):
       details.
 
   Returns:
-    1. Final loss value on the test split.
-    2. Final fraction of correct classifications on the test split.
+    If `config.inference_premise` and `config.inference_hypothesis` are not
+      `None` (i.e., inference mode): the logits for the possible labels of the
+      SNLI data set, as a numpy array of three floats.
+    Otherwise (training mode): the `SNLIClassifierTrainer` object.
+  Raises:
+    ValueError: If only one of `config.inference_premise` and
+      `config.inference_hypothesis` is specified.
   """
+  # TODO(cais): Refactor this function into separate functions for training
+  #   and inference.
   use_gpu = tfe.num_gpus() > 0 and not config.force_cpu
   device = "gpu:0" if use_gpu else "cpu:0"
   print("Using device: %s" % device)
 
+  if ((config.inference_premise and not config.inference_hypothesis) or
+      (not config.inference_premise and config.inference_hypothesis)):
+    raise ValueError(
+        "--inference_premise and --inference_hypothesis must be both "
+        "specified or both unspecified, but only one is specified.")
+
+  if config.inference_premise:
+    # Inference mode.
+    inference_sentence_pair = [
+        data.encode_sentence(config.inference_premise, word2index),
+        data.encode_sentence(config.inference_hypothesis, word2index)]
+  else:
+    inference_sentence_pair = None
+
   log_header = (
       "  Time Epoch Iteration Progress    (%Epoch)   Loss   Dev/Loss"
       "     Accuracy  Dev/Accuracy")
@@ -569,16 +606,36 @@ def train_spinn(embed, train_data, dev_data, test_data, config):
 
   summary_writer = tf.contrib.summary.create_file_writer(
       config.logdir, flush_millis=10000)
-  train_len = train_data.num_batches(config.batch_size)
+
   with tf.device(device), \
-       tfe.restore_variables_on_create(
-           tf.train.latest_checkpoint(config.logdir)), \
        summary_writer.as_default(), \
        tf.contrib.summary.always_record_summaries():
-    model = SNLIClassifier(config, embed)
-    global_step = tf.train.get_or_create_global_step()
-    trainer = SNLIClassifierTrainer(model, config.lr)
-
+    with tfe.restore_variables_on_create(
+        tf.train.latest_checkpoint(config.logdir)):
+      model = SNLIClassifier(config, embed)
+      global_step = tf.train.get_or_create_global_step()
+      trainer = SNLIClassifierTrainer(model, config.lr)
+
+    if inference_sentence_pair:
+      # Inference mode.
+      with tfe.restore_variables_on_create(
+          tf.train.latest_checkpoint(config.logdir)):
+        prem, prem_trans = inference_sentence_pair[0]
+        hypo, hypo_trans = inference_sentence_pair[1]
+        inference_logits = model(  # pylint: disable=not-callable
+            tf.constant(prem), tf.constant(prem_trans),
+            tf.constant(hypo), tf.constant(hypo_trans), training=False)
+        inference_logits = np.array(inference_logits[0][1:])
+        max_index = np.argmax(inference_logits)
+        print("\nInference logits:")
+        for i, (label, logit) in enumerate(
+            zip(data.POSSIBLE_LABELS, inference_logits)):
+          winner_tag = " (winner)" if max_index == i else ""
+          print("  {0:<16}{1:.6f}{2}".format(label + ":", logit, winner_tag))
+      return inference_logits
+
+    train_len = train_data.num_batches(config.batch_size)
     start = time.time()
     iterations = 0
     mean_loss = tfe.metrics.Mean()
@@ -594,23 +651,24 @@ def train_spinn(embed, train_data, dev_data, test_data, config):
           # remain on CPU. Same in _evaluate_on_dataset().
 
         iterations += 1
-        batch_train_loss, batch_train_logits = trainer.train_batch(
-            label, prem, prem_trans, hypo, hypo_trans)
+        with tfe.restore_variables_on_create(
+            tf.train.latest_checkpoint(config.logdir)):
+          batch_train_loss, batch_train_logits = trainer.train_batch(
+              label, prem, prem_trans, hypo, hypo_trans)
         batch_size = tf.shape(label)[0]
         mean_loss(batch_train_loss.numpy(),
                   weights=batch_size.gpu() if use_gpu else batch_size)
         accuracy(tf.argmax(batch_train_logits, axis=1), label)
 
         if iterations % config.save_every == 0:
-          all_variables = (
-              model.variables + [trainer.learning_rate] + [global_step])
+          all_variables = trainer.variables + [global_step]
           saver = tfe.Saver(all_variables)
           saver.save(os.path.join(config.logdir, "ckpt"),
                      global_step=global_step)
 
         if iterations % config.dev_every == 0:
           dev_loss, dev_frac_correct = _evaluate_on_dataset(
-              dev_data, config.batch_size, model, trainer, use_gpu)
+              dev_data, config.batch_size, trainer, use_gpu)
           print(dev_log_template.format(
               time.time() - start,
               epoch, iterations, 1 + batch_idx, train_len,
@@ -638,10 +696,12 @@ def train_spinn(embed, train_data, dev_data, test_data, config):
         trainer.decay_learning_rate(config.lr_decay_by)
 
     test_loss, test_frac_correct = _evaluate_on_dataset(
-        test_data, config.batch_size, model, trainer, use_gpu)
+        test_data, config.batch_size, trainer, use_gpu)
     print("Final test loss: %g; accuracy: %g%%" %
           (test_loss, test_frac_correct * 100.0))
 
+  return trainer
+
 
 def main(_):
   config = FLAGS
@@ -650,18 +710,24 @@ def main(_):
   vocab = data.load_vocabulary(FLAGS.data_root)
   word2index, embed = data.load_word_vectors(FLAGS.data_root, vocab)
 
-  print("Loading train, dev and test data...")
-  train_data = data.SnliData(
-      os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_train.txt"),
-      word2index, sentence_len_limit=FLAGS.sentence_len_limit)
-  dev_data = data.SnliData(
-      os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_dev.txt"),
-      word2index, sentence_len_limit=FLAGS.sentence_len_limit)
-  test_data = data.SnliData(
-      os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_test.txt"),
-      word2index, sentence_len_limit=FLAGS.sentence_len_limit)
-
-  train_spinn(embed, train_data, dev_data, test_data, config)
+  if not (config.inference_premise or config.inference_hypothesis):
+    print("Loading train, dev and test data...")
+    train_data = data.SnliData(
+        os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_train.txt"),
+        word2index, sentence_len_limit=FLAGS.sentence_len_limit)
+    dev_data = data.SnliData(
+        os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_dev.txt"),
+        word2index, sentence_len_limit=FLAGS.sentence_len_limit)
+    test_data = data.SnliData(
+        os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_test.txt"),
+        word2index, sentence_len_limit=FLAGS.sentence_len_limit)
+  else:
+    train_data = None
+    dev_data = None
+    test_data = None
+
+  train_or_infer_spinn(
+      embed, word2index, train_data, dev_data, test_data, config)
 
 
 if __name__ == "__main__":
@@ -678,6 +744,15 @@ if __name__ == "__main__":
   parser.add_argument("--logdir", type=str, default="/tmp/spinn-logs",
                       help="Directory in which summaries will be written for "
                       "TensorBoard.")
+  parser.add_argument("--inference_premise", type=str, default=None,
+                      help="Premise sentence for inference. Must be "
+                      "accompanied by --inference_hypothesis. If specified, "
+                      "will override all training parameters and perform "
+                      "inference.")
+  parser.add_argument("--inference_hypothesis", type=str, default=None,
+                      help="Hypothesis sentence for inference. Must be "
+                      "accompanied by --inference_premise. If specified, will "
+                      "override all training parameters and perform inference.")
   parser.add_argument("--epochs", type=int, default=50,
                       help="Number of epochs to train.")
   parser.add_argument("--batch_size", type=int, default=128,