from tensorflow.contrib.distributions.python.ops.bijectors.reshape import Reshape
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite
from tensorflow.python.platform import test
+@test_util.with_c_api
class _ReshapeBijectorTest(object):
"""Base class for testing the reshape transformation.
sess.run(bijector.forward_event_shape_tensor(shape_in),
feed_dict=feed_dict)
- def testInvalidDimensionsOpError(self):
+ # pylint: disable=invalid-name
+ def _testInvalidDimensionsOpError(self, expected_error_message):
with self.test_session() as sess:
event_shape_in=shape_in,
validate_args=True)
- with self.assertRaisesError(
- "elements must be either positive integers or `-1`."):
+ with self.assertRaisesError(expected_error_message):
sess.run(bijector.forward_event_shape_tensor(shape_in),
feed_dict=feed_dict)
+ # pylint: enable=invalid-name
def testValidButNonMatchingInputOpError(self):
x = np.random.randn(4, 3, 2)
sess.run(bijector.forward(x),
feed_dict=feed_dict)
- def testInputOutputMismatchOpError(self):
+ # pylint: disable=invalid-name
+ def _testInputOutputMismatchOpError(self, expected_error_message):
x1 = np.random.randn(4, 2, 3)
x2 = np.random.randn(4, 1, 1, 5)
event_shape_in=shape_in,
validate_args=True)
- # test that *all* methods check basic assertions
- with self.assertRaisesError(
- "Input to reshape is a tensor with"):
+ with self.assertRaisesError(expected_error_message):
sess.run(bijector.forward(x1), feed_dict=fd_mismatched)
- with self.assertRaisesError(
- "Input to reshape is a tensor with"):
+ with self.assertRaisesError(expected_error_message):
sess.run(bijector.inverse(x2), feed_dict=fd_mismatched)
+ # pylint: enable=invalid-name
def testOneShapePartiallySpecified(self):
expected_x = np.random.randn(4, 6)
raise NotImplementedError("Subclass failed to implement `build_shapes`.")
+@test_util.with_c_api
class ReshapeBijectorTestStatic(test.TestCase, _ReshapeBijectorTest):
def build_shapes(self, shape_in, shape_out):
validate_args=True)
assert_bijective_and_finite(bijector, x, y, rtol=1e-6, atol=0)
+ def testInvalidDimensionsOpError(self):
+ if ops._USE_C_API:
+ error_message = "Invalid value in tensor used for shape: -2"
+ else:
+ error_message = "elements must be either positive integers or `-1`."
+ self._testInvalidDimensionsOpError(error_message)
+
+ def testInputOutputMismatchOpError(self):
+ if ops._USE_C_API:
+ error_message = "Cannot reshape a tensor with"
+ else:
+ error_message = "Input to reshape is a tensor with"
+ self._testInputOutputMismatchOpError(error_message)
+
+@test_util.with_c_api
class ReshapeBijectorTestDynamic(test.TestCase, _ReshapeBijectorTest):
def build_shapes(self, shape_in, shape_out):
def assertRaisesError(self, msg):
return self.assertRaisesOpError(msg)
+ def testInvalidDimensionsOpError(self):
+ self._testInvalidDimensionsOpError(
+ "elements must be either positive integers or `-1`.")
+
+ def testInputOutputMismatchOpError(self):
+ self._testInputOutputMismatchOpError("Input to reshape is a tensor with")
+
+@test_util.with_c_api
class ReshapeBijectorTestDynamicNdims(test.TestCase, _ReshapeBijectorTest):
def build_shapes(self, shape_in, shape_out):
def assertRaisesError(self, msg):
return self.assertRaisesOpError(msg)
+ def testInvalidDimensionsOpError(self):
+ self._testInvalidDimensionsOpError(
+ "elements must be either positive integers or `-1`.")
+
+ def testInputOutputMismatchOpError(self):
+ self._testInputOutputMismatchOpError("Input to reshape is a tensor with")
+
if __name__ == "__main__":
test.main()