From: Shanqing Cai Date: Wed, 14 Feb 2018 05:49:28 +0000 (-0800) Subject: tfe SPINN example: Add inference; fix serialization X-Git-Tag: upstream/v1.7.0~31^2~713 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=14b6365a36c8982092ed2010e1e90f66f663deeb;p=platform%2Fupstream%2Ftensorflow.git tfe SPINN example: Add inference; fix serialization * Also de-flake a test. PiperOrigin-RevId: 185637742 --- diff --git a/tensorflow/contrib/eager/python/examples/spinn/BUILD b/tensorflow/contrib/eager/python/examples/spinn/BUILD index 21055cf..a1f8a75 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/BUILD +++ b/tensorflow/contrib/eager/python/examples/spinn/BUILD @@ -38,9 +38,5 @@ cuda_py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:framework_test_lib", ], - tags = [ - "manual", - "no_gpu", - "no_pip", # because spinn.py is under third_party/. - ], + tags = ["no_pip"], # because spinn.py is under third_party/. ) diff --git a/tensorflow/contrib/eager/python/examples/spinn/data.py b/tensorflow/contrib/eager/python/examples/spinn/data.py index fcaae0a..3bc3bb4 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/data.py +++ b/tensorflow/contrib/eager/python/examples/spinn/data.py @@ -227,6 +227,29 @@ def calculate_bins(length2count, min_bin_size): return bounds +def encode_sentence(sentence, word2index): + """Encode a single sentence as word indices and shift-reduce code. + + Args: + sentence: The sentence with added binary parse information, represented as + a string, with all the word items and parentheses separated by spaces. + E.g., '( ( The dog ) ( ( is ( playing toys ) ) . ) )'. + word2index: A `dict` mapping words to their word indices. + + Returns: + 1. Word indices as a numpy array, with shape `(sequence_len, 1)`. + 2. Shift-reduce sequence as a numpy array, with shape + `(sequence_len * 2 - 3, 1)`. + """ + items = [w for w in sentence.split(" ") if w] + words = get_non_parenthesis_words(items) + shift_reduce = get_shift_reduce(items) + word_indices = pad_and_reverse_word_ids( + [[word2index.get(word, UNK_CODE) for word in words]]).T + return (word_indices, + np.expand_dims(np.array(shift_reduce, dtype=np.int64), -1)) + + class SnliData(object): """A split of SNLI data.""" diff --git a/tensorflow/contrib/eager/python/examples/spinn/data_test.py b/tensorflow/contrib/eager/python/examples/spinn/data_test.py index e4f0b37..54fef2c 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/data_test.py +++ b/tensorflow/contrib/eager/python/examples/spinn/data_test.py @@ -22,6 +22,7 @@ import os import shutil import tempfile +import numpy as np import tensorflow as tf from tensorflow.contrib.eager.python.examples.spinn import data @@ -173,14 +174,9 @@ class DataTest(tf.test.TestCase): ValueError, "Cannot find GloVe embedding file at"): data.load_word_vectors(self._temp_data_dir, vocab) - def testSnliData(self): - """Unit test for SnliData objects.""" - snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0") - fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt") - os.makedirs(snli_1_0_dir) - + def _createFakeSnliData(self, fake_snli_file): # Four sentences in total. 
- with open(fake_train_file, "wt") as f: + with open(fake_snli_file, "wt") as f: f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t" "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t" "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n") @@ -205,10 +201,7 @@ class DataTest(tf.test.TestCase): "4705552913.jpg#2\t4705552913.jpg#2r1n\t" "neutral\tentailment\tneutral\tneutral\tneutral\n") - glove_dir = os.path.join(self._temp_data_dir, "glove") - os.makedirs(glove_dir) - glove_file = os.path.join(glove_dir, "glove.42B.300d.txt") - + def _createFakeGloveData(self, glove_file): words = [".", "foo", "bar", "baz", "quux", "quuz", "grault", "garply"] with open(glove_file, "wt") as f: for i, word in enumerate(words): @@ -220,6 +213,40 @@ class DataTest(tf.test.TestCase): else: f.write("\n") + def testEncodeSingleSentence(self): + snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0") + fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt") + os.makedirs(snli_1_0_dir) + self._createFakeSnliData(fake_train_file) + vocab = data.load_vocabulary(self._temp_data_dir) + glove_dir = os.path.join(self._temp_data_dir, "glove") + os.makedirs(glove_dir) + glove_file = os.path.join(glove_dir, "glove.42B.300d.txt") + self._createFakeGloveData(glove_file) + word2index, _ = data.load_word_vectors(self._temp_data_dir, vocab) + + sentence_variants = [ + "( Foo ( ( bar baz ) . ) )", + " ( Foo ( ( bar baz ) . ) ) ", + "( Foo ( ( bar baz ) . ) )"] + for sentence in sentence_variants: + word_indices, shift_reduce = data.encode_sentence(sentence, word2index) + self.assertEqual(np.int64, word_indices.dtype) + self.assertEqual((5, 1), word_indices.shape) + self.assertAllClose( + np.array([[3, 3, 3, 2, 3, 2, 2]], dtype=np.int64).T, shift_reduce) + + def testSnliData(self): + snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0") + fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt") + os.makedirs(snli_1_0_dir) + self._createFakeSnliData(fake_train_file) + + glove_dir = os.path.join(self._temp_data_dir, "glove") + os.makedirs(glove_dir) + glove_file = os.path.join(glove_dir, "glove.42B.300d.txt") + self._createFakeGloveData(glove_file) + vocab = data.load_vocabulary(self._temp_data_dir) word2index, _ = data.load_word_vectors(self._temp_data_dir, vocab) @@ -230,7 +257,7 @@ class DataTest(tf.test.TestCase): self.assertEqual(1, train_data.num_batches(4)) generator = train_data.get_generator(2)() - for i in range(2): + for _ in range(2): label, prem, prem_trans, hypo, hypo_trans = next(generator) self.assertEqual(2, len(label)) self.assertEqual((4, 2), prem.shape) diff --git a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py index 7b2f09c..eefc06d 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py +++ b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py @@ -36,6 +36,7 @@ from third_party.examples.eager.spinn import spinn from tensorflow.contrib.summary import summary_test_util from tensorflow.python.eager import test from tensorflow.python.framework import test_util +from tensorflow.python.training import checkpoint_utils # pylint: enable=g-bad-import-order @@ -66,13 +67,30 @@ def _generate_synthetic_snli_data_batch(sequence_length, return labels, prem, prem_trans, hypo, hypo_trans -def _test_spinn_config(d_embed, d_out, logdir=None): +def _test_spinn_config(d_embed, d_out, logdir=None, inference_sentences=None): + """Generate a config tuple 
for testing. + + Args: + d_embed: Embedding dimensions. + d_out: Model output dimensions. + logdir: Optional logdir. + inference_sentences: A 2-tuple of strings representing the sentences (with + binary parsing result), e.g., + ("( ( The dog ) ( ( is running ) . ) )", "( ( The dog ) ( moves . ) )"). + + Returns: + A config tuple. + """ config_tuple = collections.namedtuple( "Config", ["d_hidden", "d_proj", "d_tracker", "predict", "embed_dropout", "mlp_dropout", "n_mlp_layers", "d_mlp", "d_out", "projection", "lr", "batch_size", "epochs", "force_cpu", "logdir", "log_every", "dev_every", "save_every", - "lr_decay_every", "lr_decay_by"]) + "lr_decay_every", "lr_decay_by", "inference_premise", + "inference_hypothesis"]) + + inference_premise = inference_sentences[0] if inference_sentences else None + inference_hypothesis = inference_sentences[1] if inference_sentences else None return config_tuple( d_hidden=d_embed, d_proj=d_embed * 2, @@ -86,14 +104,16 @@ def _test_spinn_config(d_embed, d_out, logdir=None): projection=True, lr=2e-2, batch_size=2, - epochs=10, + epochs=20, force_cpu=False, logdir=logdir, log_every=1, dev_every=2, save_every=2, lr_decay_every=1, - lr_decay_by=0.75) + lr_decay_by=0.75, + inference_premise=inference_premise, + inference_hypothesis=inference_hypothesis) class SpinnTest(test_util.TensorFlowTestCase): @@ -288,11 +308,7 @@ class SpinnTest(test_util.TensorFlowTestCase): # Training on the batch should have led to a change in the loss value. self.assertNotEqual(loss1.numpy(), loss2.numpy()) - def testTrainSpinn(self): - """Test with fake toy SNLI data and GloVe vectors.""" - - # 1. Create and load a fake SNLI data file and a fake GloVe embedding file. - snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0") + def _create_test_data(self, snli_1_0_dir): fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt") os.makedirs(snli_1_0_dir) @@ -337,13 +353,52 @@ class SpinnTest(test_util.TensorFlowTestCase): else: f.write("\n") + return fake_train_file + + def testInferSpinnWorks(self): + """Test inference with the spinn model.""" + snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0") + self._create_test_data(snli_1_0_dir) + + vocab = data.load_vocabulary(self._temp_data_dir) + word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab) + + config = _test_spinn_config( + data.WORD_VECTOR_LEN, 4, + logdir=os.path.join(self._temp_data_dir, "logdir"), + inference_sentences=("( foo ( bar . ) )", "( bar ( foo . ) )")) + logits = spinn.train_or_infer_spinn( + embed, word2index, None, None, None, config) + self.assertEqual(np.float32, logits.dtype) + self.assertEqual((3,), logits.shape) + + def testInferSpinnThrowsErrorIfOnlyOneSentenceIsSpecified(self): + snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0") + self._create_test_data(snli_1_0_dir) + + vocab = data.load_vocabulary(self._temp_data_dir) + word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab) + + config = _test_spinn_config( + data.WORD_VECTOR_LEN, 4, + logdir=os.path.join(self._temp_data_dir, "logdir"), + inference_sentences=("( foo ( bar . ) )", None)) + with self.assertRaises(ValueError): + spinn.train_or_infer_spinn(embed, word2index, None, None, None, config) + + def testTrainSpinn(self): + """Test with fake toy SNLI data and GloVe vectors.""" + + # 1. Create and load a fake SNLI data file and a fake GloVe embedding file. 
+    snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
+    fake_train_file = self._create_test_data(snli_1_0_dir)
+
     vocab = data.load_vocabulary(self._temp_data_dir)
     word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab)
 
     train_data = data.SnliData(fake_train_file, word2index)
     dev_data = data.SnliData(fake_train_file, word2index)
     test_data = data.SnliData(fake_train_file, word2index)
-    print(embed)
 
     # 2. Create a fake config.
     config = _test_spinn_config(
@@ -351,7 +406,8 @@ class SpinnTest(test_util.TensorFlowTestCase):
         logdir=os.path.join(self._temp_data_dir, "logdir"))
 
     # 3. Test training of a SPINN model.
-    spinn.train_spinn(embed, train_data, dev_data, test_data, config)
+    trainer = spinn.train_or_infer_spinn(
+        embed, word2index, train_data, dev_data, test_data, config)
 
     # 4. Load train loss values from the summary files and verify that they
     # decrease with training.
@@ -363,6 +419,15 @@ class SpinnTest(test_util.TensorFlowTestCase):
     self.assertEqual(config.epochs, len(train_losses))
     self.assertLess(train_losses[-1], train_losses[0])
 
+    # 5. Verify that checkpoints exist and contain all the expected variables.
+    self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*")))
+    ckpt_variable_names = [
+        item[0] for item in checkpoint_utils.list_variables(config.logdir)]
+    self.assertIn("global_step", ckpt_variable_names)
+    for v in trainer.variables:
+      variable_name = v.name[:v.name.index(":")] if ":" in v.name else v.name
+      self.assertIn(variable_name, ckpt_variable_names)
+
 
 class EagerSpinnSNLIClassifierBenchmark(test.Benchmark):
 
diff --git a/third_party/examples/eager/spinn/README.md b/third_party/examples/eager/spinn/README.md
index 6bd3d53..335c0fa 100644
--- a/third_party/examples/eager/spinn/README.md
+++ b/third_party/examples/eager/spinn/README.md
@@ -66,3 +66,44 @@ Other eager execution examples can be found under [tensorflow/contrib/eager/pyth
   ```bash
   tensorboard --logdir /tmp/spinn-logs
   ```
+
+- After training, you may use the model to perform inference on input data in
+  the SNLI data format. The premise and hypothesis sentences are specified with
+  the command-line flags `--inference_premise` and `--inference_hypothesis`,
+  respectively. Each sentence should include the words, as well as parentheses
+  representing a binary parse of the sentence. The words and parentheses
+  should all be separated by spaces. For instance,
+
+  ```bash
+  python spinn.py --data_root /tmp/spinn-data --logdir /tmp/spinn-logs \
+    --inference_premise '( ( The dog ) ( ( is running ) . ) )' \
+    --inference_hypothesis '( ( The dog ) ( moves . ) )'
+  ```
+
+  which will generate an output like the following, due to the semantic
+  consistency of the two sentences.
+
+  ```none
+  Inference logits:
+    entailment:     1.101249 (winner)
+    contradiction:  -2.374171
+    neutral:        -0.296733
+  ```
+
+  By contrast, the following sentence pair:
+
+  ```bash
+  python spinn.py --data_root /tmp/spinn-data --logdir /tmp/spinn-logs \
+    --inference_premise '( ( The dog ) ( ( is running ) . ) )' \
+    --inference_hypothesis '( ( The dog ) ( rests . ) )'
+  ```
+
+  will give you an output like the following, due to the semantic
+  contradiction of the two sentences.
+
+  ```none
+  Inference logits:
+    entailment:     -1.070098
+    contradiction:  2.798695 (winner)
+    neutral:        -1.402287
+  ```
diff --git a/third_party/examples/eager/spinn/spinn.py b/third_party/examples/eager/spinn/spinn.py
index a2fa18e..38ba48d 100644
--- a/third_party/examples/eager/spinn/spinn.py
+++ b/third_party/examples/eager/spinn/spinn.py
@@ -44,6 +44,7 @@ import os
 import sys
 import time
 
+import numpy as np
 from six.moves import xrange  # pylint: disable=redefined-builtin
 import tensorflow as tf
 
@@ -471,6 +472,15 @@ class SNLIClassifierTrainer(object):
   def learning_rate(self):
     return self._learning_rate
 
+  @property
+  def model(self):
+    return self._model
+
+  @property
+  def variables(self):
+    return (self._model.variables + [self.learning_rate] +
+            self._optimizer.variables())
+
 
 def _batch_n_correct(logits, label):
   """Calculate number of correct predictions in a batch.
@@ -488,13 +498,12 @@ def _batch_n_correct(logits, label):
           tf.argmax(logits, axis=1), label)), tf.float32)).numpy()
 
 
-def _evaluate_on_dataset(snli_data, batch_size, model, trainer, use_gpu):
+def _evaluate_on_dataset(snli_data, batch_size, trainer, use_gpu):
   """Run evaluation on a dataset.
 
   Args:
     snli_data: The `data.SnliData` to use in this evaluation.
     batch_size: The batch size to use during this evaluation.
-    model: An instance of `SNLIClassifier` to evaluate.
     trainer: An instance of `SNLIClassifierTrainer to use for this evaluation.
     use_gpu: Whether GPU is being used.
 
@@ -509,7 +518,7 @@ def _evaluate_on_dataset(snli_data, batch_size, model, trainer, use_gpu):
       snli_data, batch_size):
     if use_gpu:
       label, prem, hypo = label.gpu(), prem.gpu(), hypo.gpu()
-    logits = model(prem, prem_trans, hypo, hypo_trans, training=False)
+    logits = trainer.model(prem, prem_trans, hypo, hypo_trans, training=False)
     loss_val = trainer.loss(label, logits)
     batch_size = tf.shape(label)[0]
     mean_loss(loss_val, weights=batch_size.gpu() if use_gpu else batch_size)
@@ -536,13 +545,19 @@ def _get_dataset_iterator(snli_data, batch_size):
   return tfe.Iterator(dataset)
 
 
-def train_spinn(embed, train_data, dev_data, test_data, config):
-  """Train a SPINN model.
+def train_or_infer_spinn(embed,
+                         word2index,
+                         train_data,
+                         dev_data,
+                         test_data,
+                         config):
+  """Perform training or inference on a SPINN model.
 
   Args:
     embed: The embedding matrix as a float32 numpy array with shape
       [vocabulary_size, word_vector_len]. word_vector_len is the length of a
       word embedding vector.
+    word2index: A `dict` mapping word to word index.
     train_data: An instance of `data.SnliData`, for the train split.
     dev_data: Same as above, for the dev split.
     test_data: Same as above, for the test split.
@@ -550,13 +565,35 @@ def train_spinn(embed, train_data, dev_data, test_data, config):
       details.
 
   Returns:
-    1. Final loss value on the test split.
-    2. Final fraction of correct classifications on the test split.
+    If `config.inference_premise` and `config.inference_hypothesis` are not
+      `None`, i.e., inference mode: the logits for the possible labels of the
+      SNLI data set, as a numpy array of three floats.
+    else:
+      The trainer object.
+  Raises:
+    ValueError: if only one of config.inference_premise and
+      config.inference_hypothesis is specified.
   """
+  # TODO(cais): Refactor this function into separate ones for training and
+  # inference.
   use_gpu = tfe.num_gpus() > 0 and not config.force_cpu
   device = "gpu:0" if use_gpu else "cpu:0"
   print("Using device: %s" % device)
 
+  if ((config.inference_premise and not config.inference_hypothesis) or
+      (not config.inference_premise and config.inference_hypothesis)):
+    raise ValueError(
+        "--inference_premise and --inference_hypothesis must be both "
+        "specified or both unspecified, but only one is specified.")
+
+  if config.inference_premise:
+    # Inference mode.
+    inference_sentence_pair = [
+        data.encode_sentence(config.inference_premise, word2index),
+        data.encode_sentence(config.inference_hypothesis, word2index)]
+  else:
+    inference_sentence_pair = None
+
   log_header = (
       " Time Epoch Iteration Progress (%Epoch) Loss Dev/Loss"
       " Accuracy Dev/Accuracy")
@@ -569,16 +606,36 @@ def train_spinn(embed, train_data, dev_data, test_data, config):
 
   summary_writer = tf.contrib.summary.create_file_writer(
       config.logdir, flush_millis=10000)
-  train_len = train_data.num_batches(config.batch_size)
+
   with tf.device(device), \
-      tfe.restore_variables_on_create(
-          tf.train.latest_checkpoint(config.logdir)), \
       summary_writer.as_default(), \
       tf.contrib.summary.always_record_summaries():
-    model = SNLIClassifier(config, embed)
-    global_step = tf.train.get_or_create_global_step()
-    trainer = SNLIClassifierTrainer(model, config.lr)
-
+    with tfe.restore_variables_on_create(
+        tf.train.latest_checkpoint(config.logdir)):
+      model = SNLIClassifier(config, embed)
+      global_step = tf.train.get_or_create_global_step()
+      trainer = SNLIClassifierTrainer(model, config.lr)
+
+    if inference_sentence_pair:
+      # Inference mode.
+      with tfe.restore_variables_on_create(
+          tf.train.latest_checkpoint(config.logdir)):
+        prem, prem_trans = inference_sentence_pair[0]
+        hypo, hypo_trans = inference_sentence_pair[1]
+        hypo_trans = inference_sentence_pair[1][1]
+        inference_logits = model(  # pylint: disable=not-callable
+            tf.constant(prem), tf.constant(prem_trans),
+            tf.constant(hypo), tf.constant(hypo_trans), training=False)
+        inference_logits = np.array(inference_logits[0][1:])
+        max_index = np.argmax(inference_logits)
+        print("\nInference logits:")
+        for i, (label, logit) in enumerate(
+            zip(data.POSSIBLE_LABELS, inference_logits)):
+          winner_tag = " (winner)" if max_index == i else ""
+          print("  {0:<16}{1:.6f}{2}".format(label + ":", logit, winner_tag))
+        return inference_logits
+
+    train_len = train_data.num_batches(config.batch_size)
     start = time.time()
     iterations = 0
     mean_loss = tfe.metrics.Mean()
@@ -594,23 +651,24 @@ def train_spinn(embed, train_data, dev_data, test_data, config):
       # remain on CPU. Same in _evaluate_on_dataset().
       iterations += 1
-      batch_train_loss, batch_train_logits = trainer.train_batch(
-          label, prem, prem_trans, hypo, hypo_trans)
+      with tfe.restore_variables_on_create(
+          tf.train.latest_checkpoint(config.logdir)):
+        batch_train_loss, batch_train_logits = trainer.train_batch(
+            label, prem, prem_trans, hypo, hypo_trans)
       batch_size = tf.shape(label)[0]
       mean_loss(batch_train_loss.numpy(),
                 weights=batch_size.gpu() if use_gpu else batch_size)
       accuracy(tf.argmax(batch_train_logits, axis=1), label)
 
       if iterations % config.save_every == 0:
-        all_variables = (
-            model.variables + [trainer.learning_rate] + [global_step])
+        all_variables = trainer.variables + [global_step]
         saver = tfe.Saver(all_variables)
         saver.save(os.path.join(config.logdir, "ckpt"),
                    global_step=global_step)
 
       if iterations % config.dev_every == 0:
         dev_loss, dev_frac_correct = _evaluate_on_dataset(
-            dev_data, config.batch_size, model, trainer, use_gpu)
+            dev_data, config.batch_size, trainer, use_gpu)
         print(dev_log_template.format(
             time.time() - start,
             epoch, iterations, 1 + batch_idx, train_len,
@@ -638,10 +696,12 @@ def train_spinn(embed, train_data, dev_data, test_data, config):
           trainer.decay_learning_rate(config.lr_decay_by)
 
   test_loss, test_frac_correct = _evaluate_on_dataset(
-      test_data, config.batch_size, model, trainer, use_gpu)
+      test_data, config.batch_size, trainer, use_gpu)
   print("Final test loss: %g; accuracy: %g%%" %
         (test_loss, test_frac_correct * 100.0))
 
+  return trainer
+
 
 def main(_):
   config = FLAGS
@@ -650,18 +710,24 @@ def main(_):
   vocab = data.load_vocabulary(FLAGS.data_root)
   word2index, embed = data.load_word_vectors(FLAGS.data_root, vocab)
 
-  print("Loading train, dev and test data...")
-  train_data = data.SnliData(
-      os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_train.txt"),
-      word2index, sentence_len_limit=FLAGS.sentence_len_limit)
-  dev_data = data.SnliData(
-      os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_dev.txt"),
-      word2index, sentence_len_limit=FLAGS.sentence_len_limit)
-  test_data = data.SnliData(
-      os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_test.txt"),
-      word2index, sentence_len_limit=FLAGS.sentence_len_limit)
-
-  train_spinn(embed, train_data, dev_data, test_data, config)
+  if not (config.inference_premise or config.inference_hypothesis):
+    print("Loading train, dev and test data...")
+    train_data = data.SnliData(
+        os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_train.txt"),
+        word2index, sentence_len_limit=FLAGS.sentence_len_limit)
+    dev_data = data.SnliData(
+        os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_dev.txt"),
+        word2index, sentence_len_limit=FLAGS.sentence_len_limit)
+    test_data = data.SnliData(
+        os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_test.txt"),
+        word2index, sentence_len_limit=FLAGS.sentence_len_limit)
+  else:
+    train_data = None
+    dev_data = None
+    test_data = None
+
+  train_or_infer_spinn(
+      embed, word2index, train_data, dev_data, test_data, config)
 
 
 if __name__ == "__main__":
@@ -678,6 +744,15 @@ if __name__ == "__main__":
   parser.add_argument("--logdir", type=str, default="/tmp/spinn-logs",
                       help="Directory in which summaries will be written for "
                       "TensorBoard.")
+  parser.add_argument("--inference_premise", type=str, default=None,
+                      help="Premise sentence for inference. Must be "
+                      "accompanied by --inference_hypothesis. If specified, "
+                      "will override all training parameters and perform "
+                      "inference.")
+  parser.add_argument("--inference_hypothesis", type=str, default=None,
+                      help="Hypothesis sentence for inference. Must be "
+                      "accompanied by --inference_premise. If specified, will "
+                      "override all training parameters and perform inference.")
   parser.add_argument("--epochs", type=int, default=50,
                       help="Number of epochs to train.")
   parser.add_argument("--batch_size", type=int, default=128,