[XLA] Add more supported dtypes to the local Python client.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 27 Feb 2018 00:01:04 +0000 (16:01 -0800)
committerGunhan Gulsoy <gunan@google.com>
Tue, 27 Feb 2018 22:33:33 +0000 (14:33 -0800)
PiperOrigin-RevId: 187096144

tensorflow/compiler/xla/python/xla_client.py

index 3b8ec85..90cda42 100644 (file)
@@ -30,9 +30,9 @@ from tensorflow.compiler.xla import xla_data_pb2
 from tensorflow.compiler.xla.python import pywrap_xla as c_api
 
 
-# Most functions are snake_case for consistency with other modules,
-# whereas method names of ComputationBuilder and LocalComputation are
-# CamelCase for consistency with XLA.
+# Most functions are snake_case for consistency with other modules, whereas
+# method names of ComputationBuilder and LocalComputation are CamelCase for
+# consistency with XLA.
 # pylint: disable=invalid-name
 
 
@@ -123,24 +123,34 @@ _BINARY_OPS = [
     'Pow',
 ]
 
+
 XLA_ELEMENT_TYPE_TO_DTYPE = {
-    xla_data_pb2.F32: np.dtype(np.float32),
-    xla_data_pb2.F64: np.dtype(np.float64),
-    xla_data_pb2.S32: np.dtype(np.int32),
-    xla_data_pb2.S64: np.dtype(np.int64),
-    xla_data_pb2.U32: np.dtype(np.uint32),
-    xla_data_pb2.U64: np.dtype(np.uint64),
-    xla_data_pb2.PRED: np.dtype(np.bool),
+    xla_data_pb2.PRED: np.dtype('bool'),
+    xla_data_pb2.S8: np.dtype('int8'),
+    xla_data_pb2.S16: np.dtype('int16'),
+    xla_data_pb2.S32: np.dtype('int32'),
+    xla_data_pb2.S64: np.dtype('int64'),
+    xla_data_pb2.U8: np.dtype('uint8'),
+    xla_data_pb2.U16: np.dtype('uint16'),
+    xla_data_pb2.U32: np.dtype('uint32'),
+    xla_data_pb2.U64: np.dtype('uint64'),
+    xla_data_pb2.F16: np.dtype('float16'),
+    xla_data_pb2.F32: np.dtype('float32'),
+    xla_data_pb2.F64: np.dtype('float64'),
+    xla_data_pb2.C64: np.dtype('complex64'),
     xla_data_pb2.TUPLE: np.dtype(np.object),
 }
 
 # Note the conversion on the key. Numpy has a known issue wherein dtype hashing
 # doesn't work as expected (https://github.com/numpy/numpy/issues/7242). Thus,
 # when keying by dtype in this dict, we use the string form of dtypes.
-DTYPE_TO_XLA_ELEMENT_TYPE = {
-    str(v): k
-    for k, v in XLA_ELEMENT_TYPE_TO_DTYPE.items()
-}
+DTYPE_TO_XLA_ELEMENT_TYPE = {str(dt): et
+                             for et, dt in XLA_ELEMENT_TYPE_TO_DTYPE.items()}
+
+
+def dtype_to_etype(dtype):
+  """Convenience function for reading DTYPE_TO_XLA_ELEMENT_TYPE."""
+  return DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))]
 
 
 class LocalBuffer(object):