TFE SPINN example: use tensor instead of numpy array
authorShanqing Cai <cais@google.com>
Fri, 16 Feb 2018 03:12:05 +0000 (19:12 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 16 Feb 2018 03:16:17 +0000 (19:16 -0800)
in inference output.

PiperOrigin-RevId: 185939805

tensorflow/contrib/eager/python/examples/spinn/spinn_test.py
third_party/examples/eager/spinn/README.md
third_party/examples/eager/spinn/spinn.py

index eefc06d90d83b61d07a613643c913d3833a5f2c1..081b0af14fcc983a3f85d2a50e2bb04d2f2493b3 100644 (file)
@@ -369,7 +369,7 @@ class SpinnTest(test_util.TensorFlowTestCase):
         inference_sentences=("( foo ( bar . ) )", "( bar ( foo . ) )"))
     logits = spinn.train_or_infer_spinn(
         embed, word2index, None, None, None, config)
-    self.assertEqual(np.float32, logits.dtype)
+    self.assertEqual(tf.float32, logits.dtype)
     self.assertEqual((3,), logits.shape)
 
   def testInferSpinnThrowsErrorIfOnlyOneSentenceIsSpecified(self):
index 335c0fa3b549f6bc9221c81c5779cd499bd780d7..7f477d19208257474d0481ca04c04679f589b751 100644 (file)
@@ -75,7 +75,7 @@ Other eager execution examples can be found under [tensorflow/contrib/eager/pyth
   should all be separated by spaces. For instance,
 
   ```bash
-  pythons spinn.py --data_root /tmp/spinn-data --logdir /tmp/spinn-logs \
+  python spinn.py --data_root /tmp/spinn-data --logdir /tmp/spinn-logs \
       --inference_premise '( ( The dog ) ( ( is running ) . ) )' \
       --inference_hypothesis '( ( The dog ) ( moves . ) )'
   ```
@@ -93,7 +93,7 @@ Other eager execution examples can be found under [tensorflow/contrib/eager/pyth
   By contrast, the following sentence pair:
 
   ```bash
-  pythons spinn.py --data_root /tmp/spinn-data --logdir /tmp/spinn-logs \
+  python spinn.py --data_root /tmp/spinn-data --logdir /tmp/spinn-logs \
       --inference_premise '( ( The dog ) ( ( is running ) . ) )' \
       --inference_hypothesis '( ( The dog ) ( rests . ) )'
   ```
index 38ba48d5013c7515e7fc78de6125f0bd93fdc90a..8a1c7db2ea14365be53a796a79fce77900e668e1 100644 (file)
@@ -44,7 +44,6 @@ import os
 import sys
 import time
 
-import numpy as np
 from six.moves import xrange  # pylint: disable=redefined-builtin
 import tensorflow as tf
 
@@ -567,7 +566,7 @@ def train_or_infer_spinn(embed,
   Returns:
     If `config.inference_premise ` and `config.inference_hypothesis` are not
       `None`, i.e., inference mode: the logits for the possible labels of the
-      SNLI data set, as numpy array of three floats.
+      SNLI data set, as a `Tensor` of three floats.
     else:
       The trainer object.
   Raises:
@@ -626,8 +625,8 @@ def train_or_infer_spinn(embed,
         inference_logits = model(  # pylint: disable=not-callable
             tf.constant(prem), tf.constant(prem_trans),
             tf.constant(hypo), tf.constant(hypo_trans), training=False)
-        inference_logits = np.array(inference_logits[0][1:])
-        max_index = np.argmax(inference_logits)
+        inference_logits = inference_logits[0][1:]
+        max_index = tf.argmax(inference_logits)
         print("\nInference logits:")
         for i, (label, logit) in enumerate(
             zip(data.POSSIBLE_LABELS, inference_logits)):