inference_sentences=("( foo ( bar . ) )", "( bar ( foo . ) )"))
logits = spinn.train_or_infer_spinn(
embed, word2index, None, None, None, config)
- self.assertEqual(np.float32, logits.dtype)
+ self.assertEqual(tf.float32, logits.dtype)
self.assertEqual((3,), logits.shape)
def testInferSpinnThrowsErrorIfOnlyOneSentenceIsSpecified(self):
should all be separated by spaces. For instance,
```bash
- pythons spinn.py --data_root /tmp/spinn-data --logdir /tmp/spinn-logs \
+ python spinn.py --data_root /tmp/spinn-data --logdir /tmp/spinn-logs \
--inference_premise '( ( The dog ) ( ( is running ) . ) )' \
--inference_hypothesis '( ( The dog ) ( moves . ) )'
```
By contrast, the following sentence pair:
```bash
- pythons spinn.py --data_root /tmp/spinn-data --logdir /tmp/spinn-logs \
+ python spinn.py --data_root /tmp/spinn-data --logdir /tmp/spinn-logs \
--inference_premise '( ( The dog ) ( ( is running ) . ) )' \
--inference_hypothesis '( ( The dog ) ( rests . ) )'
```
import sys
import time
-import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
Returns:
If `config.inference_premise ` and `config.inference_hypothesis` are not
`None`, i.e., inference mode: the logits for the possible labels of the
- SNLI data set, as numpy array of three floats.
+ SNLI data set, as a `Tensor` of three floats.
else:
The trainer object.
Raises:
inference_logits = model( # pylint: disable=not-callable
tf.constant(prem), tf.constant(prem_trans),
tf.constant(hypo), tf.constant(hypo_trans), training=False)
- inference_logits = np.array(inference_logits[0][1:])
- max_index = np.argmax(inference_logits)
+ inference_logits = inference_logits[0][1:]
+ max_index = tf.argmax(inference_logits)
print("\nInference logits:")
for i, (label, logit) in enumerate(
zip(data.POSSIBLE_LABELS, inference_logits)):