* Add python built-in types support for `tf.as_dtype`
This fix tries to address the issue raised in 17641 where
it was not possible to use `tf.as_dtype(float)` the same
way as numpy `np.dtype(float)`.
This fix adds the built-in types support for `tf.as_dtype`,
so that it is possible to specify:
```
dtypes.as_dtype(float) # dtypes.float64
dtypes.as_dtype(int) # dtypes.int32
dtypes.as_dtype(long) # dtypes.int64
dtypes.as_dtype(complex) # dtypes.complex128
dtypes.as_dtype(bool) # dtypes.bool
```
This fix fixes 17641.
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Add test cases for built-in types support with `tf.as_dtype`
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Fix failed test cases with added built-in types support of tf.as_dtype
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Fix python 3 build
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Restrict the changes to float and bool based on review feedback
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
])
tf_export("QUANTIZED_DTYPES").export_constant(__name__, "QUANTIZED_DTYPES")
+_PYTHON_TO_TF = {
+ float: float32,
+ bool: bool,
+}
@tf_export("as_dtype")
def as_dtype(type_value):
except KeyError:
pass
+ try:
+ return _PYTHON_TO_TF[type_value]
+ except KeyError:
+ pass
+
if isinstance(type_value, np.dtype):
# The numpy dtype for strings is variable length. We can not compare
# dtype with a single constant (np.string does not exist) to decide
self.assertNotEqual(dtypes.int32, int)
self.assertNotEqual(dtypes.float64, 2.1)
+ def testPythonTypesConversion(self):
+ self.assertIs(dtypes.float32, dtypes.as_dtype(float))
+ self.assertIs(dtypes.bool, dtypes.as_dtype(bool))
if __name__ == "__main__":
googletest.main()