Make batch_sequences_with_states_test.py work with C API enabled.
authorSkye Wanderman-Milne <skyewm@google.com>
Thu, 25 Jan 2018 01:52:19 +0000 (17:52 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 25 Jan 2018 01:57:21 +0000 (17:57 -0800)
PiperOrigin-RevId: 183171572

tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py

index 2a0ef0e6b3750b4f0464f1f4390819e1fc2c7872..04538405e4bde7b89a5840f3486443d780e8b1d4 100644 (file)
@@ -320,6 +320,18 @@ class BatchSequencesWithStatesTest(test.TestCase):
   def testNotAMultiple(self):
     num_unroll = 3  # Not a divisor of value_length -
     # so padding would have been necessary.
+
+    # Use placeholder_with_default in sequences to make sure we get runtime
+    # error instead of shape inference error
+    sequences = {
+        "seq1": array_ops.placeholder_with_default(self.sequences["seq1"],
+                                                   shape=(None, 5)),
+        "seq2": array_ops.placeholder_with_default(self.sequences["seq2"],
+                                                   shape=(None, 4, 2)),
+        "seq3": self.sequences["seq3"],
+        "seq4": self.sequences["seq4"],
+    }
+
     with self.test_session() as sess:
       with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                    ".*should be a multiple of: 3, but saw "
@@ -330,7 +342,7 @@ class BatchSequencesWithStatesTest(test.TestCase):
           with coord.stop_on_exception():
             next_batch = sqss.batch_sequences_with_states(
                 input_key=self.key,
-                input_sequences=self.sequences,
+                input_sequences=sequences,
                 input_context=self.context,
                 input_length=3,
                 initial_states=self.initial_states,
@@ -493,6 +505,18 @@ class BatchSequencesWithStatesTest(test.TestCase):
         expected_seq4_batch2=expected_seq4_batch2)
 
 
+class BatchSequencesWithStatesTestWithCApi(BatchSequencesWithStatesTest):
+
+  def setUp(self):
+    self._prev_value = ops._USE_C_API
+    ops._USE_C_API = True
+    super(BatchSequencesWithStatesTestWithCApi, self).setUp()
+
+  def tearDown(self):
+    super(BatchSequencesWithStatesTestWithCApi, self).tearDown()
+    ops._USE_C_API = self._prev_value
+
+
 class PaddingTest(test.TestCase):
 
   def testPaddingInvalidLengths(self):