tf.Dimension raises TypeError for tf.DType (#17086)
authorYan Facai (颜发才) <facai.yan@gmail.com>
Sat, 7 Apr 2018 03:05:10 +0000 (11:05 +0800)
committerdrpngx <drpngx@users.noreply.github.com>
Sat, 7 Apr 2018 03:05:10 +0000 (20:05 -0700)
* BUG: raise error for Dtype

* TST: add test case

tensorflow/python/BUILD
tensorflow/python/framework/tensor_shape.py
tensorflow/python/framework/tensor_shape_test.py

index a8f1318..753be82 100644 (file)
@@ -835,6 +835,7 @@ py_library(
     srcs = ["framework/tensor_shape.py"],
     srcs_version = "PY2AND3",
     deps = [
+        ":dtypes",
         ":util",
         "//tensorflow/core:protos_all_py",
     ],
index af2a5b1..26069d9 100644 (file)
@@ -18,6 +18,7 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.core.framework import tensor_shape_pb2
+from tensorflow.python.framework import dtypes
 from tensorflow.python.util import compat
 from tensorflow.python.util.tf_export import tf_export
 
@@ -30,6 +31,8 @@ class Dimension(object):
     """Creates a new Dimension with the given value."""
     if value is None:
       self._value = None
+    elif isinstance(value, dtypes.DType):
+      raise TypeError("Cannot convert %s to Dimension" % value)
     else:
       self._value = int(value)
       if (not isinstance(value, compat.bytes_or_text_types) and
index 4e8ce4d..4f23922 100644 (file)
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.core.framework import tensor_shape_pb2
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import test_util
 from tensorflow.python.platform import googletest
@@ -184,6 +185,10 @@ class DimensionTest(test_util.TensorFlowTestCase):
     self.assertEqual(str(tensor_shape.Dimension(7)), "7")
     self.assertEqual(str(tensor_shape.Dimension(None)), "?")
 
+  def testUnsupportedType(self):
+    with self.assertRaises(TypeError):
+      tensor_shape.Dimension(dtypes.string)
+      
   def testMod(self):
     four = tensor_shape.Dimension(4)
     nine = tensor_shape.Dimension(9)