Make bijectors/reshape_test.py work with C API enabled.
authorSkye Wanderman-Milne <skyewm@google.com>
Tue, 23 Jan 2018 22:54:22 +0000 (14:54 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 23 Jan 2018 22:57:52 +0000 (14:57 -0800)
Some of the error messages change with the C API enabled due to slight
differences in shape inference (in this case, the C API catches shape
errors sooner). This change refactors some of the tests to expect
different error messages in different cases.

PiperOrigin-RevId: 182997183

tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py

index 49451446b56d290f130c5db90c13b94974d92dc9..e216d88cb190dc16fc0056186f80817d6f2d7c67 100644 (file)
@@ -22,12 +22,15 @@ import numpy as np
 
 from tensorflow.contrib.distributions.python.ops.bijectors.reshape import Reshape
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite
 from tensorflow.python.platform import test
 
 
+@test_util.with_c_api
 class _ReshapeBijectorTest(object):
   """Base class for testing the reshape transformation.
 
@@ -136,7 +139,8 @@ class _ReshapeBijectorTest(object):
         sess.run(bijector.forward_event_shape_tensor(shape_in),
                  feed_dict=feed_dict)
 
-  def testInvalidDimensionsOpError(self):
+  # pylint: disable=invalid-name
+  def _testInvalidDimensionsOpError(self, expected_error_message):
 
     with self.test_session() as sess:
 
@@ -146,10 +150,10 @@ class _ReshapeBijectorTest(object):
           event_shape_in=shape_in,
           validate_args=True)
 
-      with self.assertRaisesError(
-          "elements must be either positive integers or `-1`."):
+      with self.assertRaisesError(expected_error_message):
         sess.run(bijector.forward_event_shape_tensor(shape_in),
                  feed_dict=feed_dict)
+  # pylint: enable=invalid-name
 
   def testValidButNonMatchingInputOpError(self):
     x = np.random.randn(4, 3, 2)
@@ -184,7 +188,8 @@ class _ReshapeBijectorTest(object):
         sess.run(bijector.forward(x),
                  feed_dict=feed_dict)
 
-  def testInputOutputMismatchOpError(self):
+  # pylint: disable=invalid-name
+  def _testInputOutputMismatchOpError(self, expected_error_message):
     x1 = np.random.randn(4, 2, 3)
     x2 = np.random.randn(4, 1, 1, 5)
 
@@ -196,13 +201,11 @@ class _ReshapeBijectorTest(object):
           event_shape_in=shape_in,
           validate_args=True)
 
-      # test that *all* methods check basic assertions
-      with self.assertRaisesError(
-          "Input to reshape is a tensor with"):
+      with self.assertRaisesError(expected_error_message):
         sess.run(bijector.forward(x1), feed_dict=fd_mismatched)
-      with self.assertRaisesError(
-          "Input to reshape is a tensor with"):
+      with self.assertRaisesError(expected_error_message):
         sess.run(bijector.inverse(x2), feed_dict=fd_mismatched)
+  # pylint: enable=invalid-name
 
   def testOneShapePartiallySpecified(self):
     expected_x = np.random.randn(4, 6)
@@ -262,6 +265,7 @@ class _ReshapeBijectorTest(object):
     raise NotImplementedError("Subclass failed to implement `build_shapes`.")
 
 
+@test_util.with_c_api
 class ReshapeBijectorTestStatic(test.TestCase, _ReshapeBijectorTest):
 
   def build_shapes(self, shape_in, shape_out):
@@ -299,7 +303,22 @@ class ReshapeBijectorTestStatic(test.TestCase, _ReshapeBijectorTest):
           validate_args=True)
       assert_bijective_and_finite(bijector, x, y, rtol=1e-6, atol=0)
 
+  def testInvalidDimensionsOpError(self):
+    if ops._USE_C_API:
+      error_message = "Invalid value in tensor used for shape: -2"
+    else:
+      error_message = "elements must be either positive integers or `-1`."
+    self._testInvalidDimensionsOpError(error_message)
+
+  def testInputOutputMismatchOpError(self):
+    if ops._USE_C_API:
+      error_message = "Cannot reshape a tensor with"
+    else:
+      error_message = "Input to reshape is a tensor with"
+    self._testInputOutputMismatchOpError(error_message)
+
 
+@test_util.with_c_api
 class ReshapeBijectorTestDynamic(test.TestCase, _ReshapeBijectorTest):
 
   def build_shapes(self, shape_in, shape_out):
@@ -313,7 +332,15 @@ class ReshapeBijectorTestDynamic(test.TestCase, _ReshapeBijectorTest):
   def assertRaisesError(self, msg):
     return self.assertRaisesOpError(msg)
 
+  def testInvalidDimensionsOpError(self):
+    self._testInvalidDimensionsOpError(
+        "elements must be either positive integers or `-1`.")
+
+  def testInputOutputMismatchOpError(self):
+    self._testInputOutputMismatchOpError("Input to reshape is a tensor with")
+
 
+@test_util.with_c_api
 class ReshapeBijectorTestDynamicNdims(test.TestCase, _ReshapeBijectorTest):
 
   def build_shapes(self, shape_in, shape_out):
@@ -325,6 +352,13 @@ class ReshapeBijectorTestDynamicNdims(test.TestCase, _ReshapeBijectorTest):
   def assertRaisesError(self, msg):
     return self.assertRaisesOpError(msg)
 
+  def testInvalidDimensionsOpError(self):
+    self._testInvalidDimensionsOpError(
+        "elements must be either positive integers or `-1`.")
+
+  def testInputOutputMismatchOpError(self):
+    self._testInputOutputMismatchOpError("Input to reshape is a tensor with")
+
 
 if __name__ == "__main__":
   test.main()