Add python built-in types support for `tf.as_dtype` (#17652)
authorYong Tang <yong.tang.github@outlook.com>
Tue, 10 Apr 2018 17:34:32 +0000 (10:34 -0700)
committerDerek Murray <derek.murray@gmail.com>
Tue, 10 Apr 2018 17:34:32 +0000 (10:34 -0700)
* Add python built-in types support for `tf.as_dtype`

This fix tries to address the issue raised in 17641 where
it was not possible to use `tf.as_dtype(float)` the same
way as numpy `np.dtype(float)`.
This fix adds the built-in types support for `tf.as_dtype`,
so that it is possible to specify:
```
dtypes.as_dtype(float)   # dtypes.float64
dtypes.as_dtype(int)     # dtypes.int32
dtypes.as_dtype(long)    # dtypes.int64
dtypes.as_dtype(complex) # dtypes.complex128
dtypes.as_dtype(bool)    # dtypes.bool
```

This fix fixes 17641.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Add test cases for built-in types support with `tf.as_dtype`

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Fix failed test cases with added built-in types support of tf.as_dtype

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Fix python 3 build

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Restrict the changes to float and bool based on review feedback

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
tensorflow/python/framework/dtypes.py
tensorflow/python/framework/dtypes_test.py

index a31c424..6d918f8 100644 (file)
@@ -648,6 +648,10 @@ QUANTIZED_DTYPES = frozenset([
 ])
 tf_export("QUANTIZED_DTYPES").export_constant(__name__, "QUANTIZED_DTYPES")
 
+_PYTHON_TO_TF = {
+    float: float32,
+    bool: bool,
+}
 
 @tf_export("as_dtype")
 def as_dtype(type_value):
@@ -679,6 +683,11 @@ def as_dtype(type_value):
   except KeyError:
     pass
 
+  try:
+    return _PYTHON_TO_TF[type_value]
+  except KeyError:
+    pass
+
   if isinstance(type_value, np.dtype):
     # The numpy dtype for strings is variable length. We can not compare
     # dtype with a single constant (np.string does not exist) to decide
index e49e2fd..478733e 100644 (file)
@@ -295,6 +295,9 @@ class TypesTest(test_util.TensorFlowTestCase):
     self.assertNotEqual(dtypes.int32, int)
     self.assertNotEqual(dtypes.float64, 2.1)
 
+  def testPythonTypesConversion(self):
+    self.assertIs(dtypes.float32, dtypes.as_dtype(float))
+    self.assertIs(dtypes.bool, dtypes.as_dtype(bool))
 
 if __name__ == "__main__":
   googletest.main()