Add test case for input shape of tf.roll
authorYong Tang <yong.tang.github@outlook.com>
Tue, 17 Apr 2018 00:55:37 +0000 (00:55 +0000)
committerYong Tang <yong.tang.github@outlook.com>
Tue, 17 Apr 2018 01:17:51 +0000 (01:17 +0000)
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
tensorflow/python/kernel_tests/manip_ops_test.py

index 7948a47..0ef02ea 100644 (file)
@@ -20,8 +20,10 @@ from __future__ import print_function
 import numpy as np
 
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gradient_checker
 from tensorflow.python.ops import manip_ops
 from tensorflow.python.platform import test as test_lib
@@ -98,14 +100,20 @@ class RollTest(test_util.TensorFlowTestCase):
         manip_ops.roll(np.random.randint(-100, 100, (4, 4)).astype(np.int32),
                        3, -10).eval()
 
+  def testInvalidInputShape(self):
+    # The input should be 1-D or higher, checked in shape function.
+    with self.assertRaisesRegexp(ValueError, "Shape must be at least rank 1 but is rank 0"):
+      roll = manip_ops.roll(7, 1, 0)
+
   def testRollInputMustVectorHigherRaises(self):
-    tensor = 7
+    # The input should be 1-D or higher, checked is done in kernel.
+    tensor = array_ops.placeholder(dtype=dtypes.int32)
     shift = 1
     axis = 0
     with self.test_session():
       with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                    "input must be 1-D or higher"):
-        manip_ops.roll(tensor, shift, axis).eval()
+        manip_ops.roll(tensor, shift, axis).eval(feed_dict={tensor: 7})
 
   def testRollAxisMustBeScalarOrVectorRaises(self):
     tensor = [[1, 2], [3, 4]]