From: A. Unique TensorFlower Date: Tue, 27 Feb 2018 00:01:04 +0000 (-0800) Subject: [XLA] Add more supported dtypes to the local Python client. X-Git-Tag: upstream/v1.7.0~93 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=6db1b213458ea7f0acd4476f70d930e15af8f35f;p=platform%2Fupstream%2Ftensorflow.git [XLA] Add more supported dtypes to the local Python client. PiperOrigin-RevId: 187096144 --- diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 3b8ec85..90cda42 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -30,9 +30,9 @@ from tensorflow.compiler.xla import xla_data_pb2 from tensorflow.compiler.xla.python import pywrap_xla as c_api -# Most functions are snake_case for consistency with other modules, -# whereas method names of ComputationBuilder and LocalComputation are -# CamelCase for consistency with XLA. +# Most functions are snake_case for consistency with other modules, whereas +# method names of ComputationBuilder and LocalComputation are CamelCase for +# consistency with XLA. # pylint: disable=invalid-name @@ -123,24 +123,34 @@ _BINARY_OPS = [ 'Pow', ] + XLA_ELEMENT_TYPE_TO_DTYPE = { - xla_data_pb2.F32: np.dtype(np.float32), - xla_data_pb2.F64: np.dtype(np.float64), - xla_data_pb2.S32: np.dtype(np.int32), - xla_data_pb2.S64: np.dtype(np.int64), - xla_data_pb2.U32: np.dtype(np.uint32), - xla_data_pb2.U64: np.dtype(np.uint64), - xla_data_pb2.PRED: np.dtype(np.bool), + xla_data_pb2.PRED: np.dtype('bool'), + xla_data_pb2.S8: np.dtype('int8'), + xla_data_pb2.S16: np.dtype('int16'), + xla_data_pb2.S32: np.dtype('int32'), + xla_data_pb2.S64: np.dtype('int64'), + xla_data_pb2.U8: np.dtype('uint8'), + xla_data_pb2.U16: np.dtype('uint16'), + xla_data_pb2.U32: np.dtype('uint32'), + xla_data_pb2.U64: np.dtype('uint64'), + xla_data_pb2.F16: np.dtype('float16'), + xla_data_pb2.F32: np.dtype('float32'), + xla_data_pb2.F64: np.dtype('float64'), + xla_data_pb2.C64: np.dtype('complex64'), xla_data_pb2.TUPLE: np.dtype(np.object), } # Note the conversion on the key. Numpy has a known issue wherein dtype hashing # doesn't work as expected (https://github.com/numpy/numpy/issues/7242). Thus, # when keying by dtype in this dict, we use the string form of dtypes. -DTYPE_TO_XLA_ELEMENT_TYPE = { - str(v): k - for k, v in XLA_ELEMENT_TYPE_TO_DTYPE.items() -} +DTYPE_TO_XLA_ELEMENT_TYPE = {str(dt): et + for et, dt in XLA_ELEMENT_TYPE_TO_DTYPE.items()} + + +def dtype_to_etype(dtype): + """Convenience function for reading DTYPE_TO_XLA_ELEMENT_TYPE.""" + return DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))] class LocalBuffer(object):