[TF] Add half precision to the supported data types for TensorFlow operations.
author Bixia Zheng <bixia@google.com>
Fri, 6 Apr 2018 23:04:46 +0000 (16:04 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Fri, 6 Apr 2018 23:07:34 +0000 (16:07 -0700)
Enable most of the half-precision XLA compiler tests for the CPU backend,
except for two, which are disabled and tracked in a bug.

PiperOrigin-RevId: 191954183
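
For illustration only (not part of the commit): once DT_HALF is registered
for these ops, float16 tensors flow through kernels such as MatMul and
FloorMod end to end. A minimal sketch, assuming a TensorFlow build that
includes this change and using the modern tf.linalg.matmul / tf.math.floormod
endpoints for the MatMul and FloorMod registrations touched below:

    import numpy as np
    import tensorflow as tf

    x = tf.constant(np.arange(6, dtype=np.float16).reshape(2, 3))
    y = tf.constant(np.full((3, 2), 0.5, dtype=np.float16))
    print(tf.linalg.matmul(x, y))                             # MatMul now accepts half
    print(tf.math.floormod(x, tf.constant(2.0, tf.float16)))  # FloorMod too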

tensorflow/compiler/tests/build_defs.bzl
tensorflow/compiler/tests/spacetobatch_op_test.py
tensorflow/compiler/tests/unary_ops_test.py
tensorflow/compiler/tf2xla/xla_helpers.cc
tensorflow/core/ops/array_ops.cc
tensorflow/core/ops/math_ops.cc

diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl
index a9db1c1..45b6a6e 100644
@@ -51,7 +51,7 @@ def tf_xla_py_test(name, srcs=[], deps=[], tags=[], data=[], main=None,
     if backend == "cpu":
       backend_args += [
           "--test_device=XLA_CPU",
-          "--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64"
+          "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64"
       ]
     elif backend == "gpu":
       backend_args += [
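
The test harness consumes this --types flag to decide which numpy dtypes each
XLA test exercises (self.float_types in the tests below). A hypothetical
sketch of that mapping — the names here are illustrative, not the actual
XLATestCase code:

    import numpy as np

    _DT_TO_NP = {  # illustrative subset
        "DT_HALF": np.float16,
        "DT_FLOAT": np.float32,
        "DT_DOUBLE": np.float64,
    }

    def parse_types(flag_value):
      # e.g. flag_value = "DT_HALF,DT_FLOAT,DT_DOUBLE,..."
      return {_DT_TO_NP[t] for t in flag_value.split(",") if t in _DT_TO_NP}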
diff --git a/tensorflow/compiler/tests/spacetobatch_op_test.py b/tensorflow/compiler/tests/spacetobatch_op_test.py
index 6083981..ef47187 100644
@@ -163,6 +163,9 @@ class SpaceToBatchNDTest(XLATestCase):
         # error.
         if dtype == dtypes.bfloat16.as_numpy_dtype:
           continue
+        # TODO(b/77694432): Half test failed on CPU; last run on 04-06-2018.
+        if dtype == np.float16 and self.device == "XLA_CPU":
+          continue
         placeholder = array_ops.placeholder(dtype)
         # outputs = space_to_batch(inputs)
         x_tf = array_ops.space_to_batch_nd(placeholder, block_shape, paddings)
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index 17149aa..ba79f39 100644
@@ -154,6 +154,9 @@ class UnaryOpsTest(XLATestCase):
 
   def testFloatOps(self):
     for dtype in self.float_types:
+      # TODO(b/77694432): Half test failed on CPU; last run on 04-06-2018.
+      if dtype == np.float16 and self.device == "XLA_CPU":
+        continue
       x = np.arange(-0.90, 0.90, 0.25)
       self._assertOpOutputMatchesExpected(
           math_ops.acos,
diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc
index 3b0b2f0..62a5114 100644
@@ -122,6 +122,9 @@ xla::ComputationDataHandle XlaHelpers::One(xla::ComputationBuilder* b,
 xla::ComputationDataHandle XlaHelpers::Epsilon(xla::ComputationBuilder* b,
                                                DataType data_type) {
   switch (data_type) {
+    case DT_HALF:
+      return b->ConstantR0<Eigen::half>(
+          static_cast<Eigen::half>(Eigen::NumTraits<Eigen::half>::epsilon()));
     case DT_BFLOAT16:
       return b->ConstantR0<bfloat16>(bfloat16::epsilon());
     case DT_FLOAT:
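
The value the new DT_HALF case returns is the IEEE half-precision machine
epsilon, 2**-10. A quick cross-check in Python, with numpy standing in for
Eigen::NumTraits<Eigen::half>::epsilon():

    import numpy as np

    print(np.finfo(np.float16).eps)          # 0.000977
    assert np.finfo(np.float16).eps == 2.0 ** -10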
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 62ce70e..4b119e2 100644
@@ -622,7 +622,7 @@ REGISTER_OP("OnesLike")
     .Input("x: T")
     .Output("y: T")
     .Attr(
-        "T: {bfloat16, float, double, int8, uint8, int16, uint16, int32, "
+        "T: {bfloat16, half, float, double, int8, uint8, int16, uint16, int32, "
         "int64, complex64, complex128, bool}")
     .SetShapeFn(shape_inference::UnchangedShape);
 
@@ -630,7 +630,9 @@ REGISTER_OP("OnesLike")
 REGISTER_OP("Diag")
     .Input("diagonal: T")
     .Output("output: T")
-    .Attr("T: {bfloat16, float, double, int32, int64, complex64, complex128}")
+    .Attr(
+        "T: {bfloat16, half, float, double, int32, int64, complex64, "
+        "complex128}")
     .SetShapeFn([](InferenceContext* c) {
       ShapeHandle in = c->input(0);
       TF_RETURN_IF_ERROR(c->WithRankAtLeast(in, 1, &in));
@@ -645,7 +647,9 @@ REGISTER_OP("Diag")
 REGISTER_OP("DiagPart")
     .Input("input: T")
     .Output("diagonal: T")
-    .Attr("T: {bfloat16, float, double, int32, int64, complex64, complex128}")
+    .Attr(
+        "T: {bfloat16, half, float, double, int32, int64, complex64, "
+        "complex128}")
     .SetShapeFn([](InferenceContext* c) {
       ShapeHandle in = c->input(0);
       if (!c->RankKnown(in)) {
@@ -789,7 +793,7 @@ REGISTER_OP("ReverseV2")
     .Output("output: T")
     .Attr("Tidx: {int32, int64} = DT_INT32")
     .Attr(
-        "T: {uint8, int8, uint16, int16, int32, int64, bool, half, bfloat16, "
+        "T: {uint8, int8, uint16, int16, int32, int64, bool, bfloat16, half, "
         "float, double, complex64, complex128, string}")
     .SetShapeFn([](InferenceContext* c) {
       ShapeHandle input = c->input(0);
@@ -1165,7 +1169,7 @@ REGISTER_OP("PreventGradient")
 REGISTER_OP("CheckNumerics")
     .Input("tensor: T")
     .Output("output: T")
-    .Attr("T: {half, bfloat16, float, double}")
+    .Attr("T: {bfloat16, half, float, double}")
     .Attr("message: string")
     .SetShapeFn(shape_inference::UnchangedShape);
 
@@ -2450,13 +2454,12 @@ REGISTER_OP("Bitcast")
     .Output("output: type")
     // All supported dtypes are listed here to include qint16 and quint16.
     .Attr(
-        "T: {bfloat16, float, double, int64, int32, uint8, uint16, int8, int16,"
-        " complex64, complex128, qint8, quint8, qint16, quint16, qint32,"
-        " half}")
+        "T: {bfloat16, half, float, double, int64, int32, uint8, uint16, int8, "
+        "int16, complex64, complex128, qint8, quint8, qint16, quint16, qint32}")
     .Attr(
-        "type: {bfloat16, float, double, int64, int32, uint8, uint16, int8, "
-        "int16, complex64, complex128, qint8, quint8, qint16, quint16, qint32,"
-        " half}")
+        "type: {bfloat16, half, float, double, int64, int32, uint8, uint16, "
+        "int8, int16, complex64, complex128, qint8, quint8, qint16, quint16, "
+        "qint32}")
     .SetShapeFn([](InferenceContext* c) {
       ShapeHandle input = c->input(0);
       if (!c->RankKnown(input)) {
@@ -2552,7 +2555,7 @@ REGISTER_OP("QuantizeAndDequantize")
     .Attr("input_min: float = 0")
     .Attr("input_max: float = 0")
     .Output("output: T")
-    .Attr("T: {bfloat16, float, double}")
+    .Attr("T: {bfloat16, half, float, double}")
     .SetShapeFn(shape_inference::UnchangedShape)
     .Deprecated(22, "Replaced by QuantizeAndDequantizeV2");
 
@@ -2565,7 +2568,7 @@ REGISTER_OP("QuantizeAndDequantizeV2")
     .Attr("num_bits: int = 8")
     .Attr("range_given: bool = false")
     .Output("output: T")
-    .Attr("T: {bfloat16, float, double}")
+    .Attr("T: {bfloat16, half, float, double}")
     .SetShapeFn([](InferenceContext* c) {
       ShapeHandle unused;
       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
@@ -2582,7 +2585,7 @@ REGISTER_OP("QuantizeAndDequantizeV3")
     .Attr("signed_input: bool = true")
     .Attr("range_given: bool = true")
     .Output("output: T")
-    .Attr("T: {bfloat16, float, double}")
+    .Attr("T: {bfloat16, half, float, double}")
     .SetShapeFn([](InferenceContext* c) {
       ShapeHandle unused;
       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 8f33d51..1180973 100644
@@ -65,7 +65,7 @@ REGISTER_OP("BatchMatMul")
     .Input("x: T")
     .Input("y: T")
     .Output("output: T")
-    .Attr("T: {half, bfloat16, float, double, int32, complex64, complex128}")
+    .Attr("T: {bfloat16, half, float, double, int32, complex64, complex128}")
     .Attr("adj_x: bool = false")
     .Attr("adj_y: bool = false")
     .SetShapeFn([](InferenceContext* c) {
@@ -133,7 +133,7 @@ _HostCast requires its input and produces its output in host memory.
 REGISTER_OP("Abs")
     .Input("x: T")
     .Output("y: T")
-    .Attr("T: {half, bfloat16, float, double, int32, int64}")
+    .Attr("T: {bfloat16, half, float, double, int32, int64}")
     .SetShapeFn(shape_inference::UnchangedShape);
 
 REGISTER_OP("ComplexAbs")
@@ -148,27 +148,27 @@ REGISTER_OP("ComplexAbs")
   Input("x: T")                                                          \
       .Output("y: T")                                                    \
       .Attr(                                                             \
-          "T: {half, bfloat16, float, double, int32, int64, complex64, " \
+          "T: {bfloat16, half, float, double, int32, int64, complex64, " \
           "complex128}")                                                 \
       .SetShapeFn(shape_inference::UnchangedShape)
 
 #define UNARY_REAL()                              \
   Input("x: T")                                   \
       .Output("y: T")                             \
-      .Attr("T: {half, bfloat16, float, double}") \
+      .Attr("T: {bfloat16, half, float, double}") \
       .SetShapeFn(shape_inference::UnchangedShape)
 
 #define UNARY_COMPLEX()                                                  \
   Input("x: T")                                                          \
       .Output("y: T")                                                    \
-      .Attr("T: {half, bfloat16, float, double, complex64, complex128}") \
+      .Attr("T: {bfloat16, half, float, double, complex64, complex128}") \
       .SetShapeFn(shape_inference::UnchangedShape)
 
 #define UNARY_GRADIENT_COMPLEX()                                         \
   Input("y: T")                                                          \
       .Input("dy: T")                                                    \
       .Output("z: T")                                                    \
-      .Attr("T: {half, bfloat16, float, double, complex64, complex128}") \
+      .Attr("T: {bfloat16, half, float, double, complex64, complex128}") \
       .SetShapeFn(shape_inference::UnchangedShape)
 
 REGISTER_OP("Neg").UNARY();
@@ -246,57 +246,57 @@ REGISTER_OP("Atan").UNARY();
 REGISTER_OP("IsNan")
     .Input("x: T")
     .Output("y: bool")
-    .Attr("T: {half, bfloat16, float, double}")
+    .Attr("T: {bfloat16, half, float, double}")
     .SetShapeFn(shape_inference::UnchangedShape);
 
 REGISTER_OP("IsInf")
     .Input("x: T")
     .Output("y: bool")
-    .Attr("T: {half, bfloat16, float, double}")
+    .Attr("T: {bfloat16, half, float, double}")
     .SetShapeFn(shape_inference::UnchangedShape);
 
 REGISTER_OP("IsFinite")
     .Input("x: T")
     .Output("y: bool")
-    .Attr("T: {half, bfloat16, float, double}")
+    .Attr("T: {bfloat16, half, float, double}")
     .SetShapeFn(shape_inference::UnchangedShape);
 
 REGISTER_OP("Sign")
     .Input("x: T")
     .Output("y: T")
     .Attr(
-        "T: {half, bfloat16, float, double, int32, int64, complex64, "
+        "T: {bfloat16, half, float, double, int32, int64, complex64, "
         "complex128}")
     .SetShapeFn(shape_inference::UnchangedShape);
 
 REGISTER_OP("Floor")
     .Input("x: T")
     .Output("y: T")
-    .Attr("T: {half, bfloat16, float, double}")
+    .Attr("T: {bfloat16, half, float, double}")
     .SetShapeFn(shape_inference::UnchangedShape);
 
 REGISTER_OP("Ceil")
     .Input("x: T")
     .Output("y: T")
-    .Attr("T: {half, bfloat16, float, double}")
+    .Attr("T: {bfloat16, half, float, double}")
     .SetShapeFn(shape_inference::UnchangedShape);
 
 REGISTER_OP("Rint")
     .Input("x: T")
     .Output("y: T")
-    .Attr("T: {bfloat16, float, double}")
+    .Attr("T: {bfloat16, half, float, double}")
     .SetShapeFn(shape_inference::UnchangedShape);
 
 // Declares cwise binary operations signature: 't, 't -> 't.
 
 #define BINARY_MORE()                                                          \
   Input("x: T").Input("y: T").Output("z: T").Attr(                             \
-      "T: {half, bfloat16, float, double, uint8, int8, uint16, int16, int32, " \
+      "T: {bfloat16, half, float, double, uint8, int8, uint16, int16, int32, " \
       "int64, complex64, complex128}")
 
 #define BINARY_FEWER()                                               \
   Input("x: T").Input("y: T").Output("z: T").Attr(                   \
-      "T: {half, bfloat16, float, double, int32, int64, complex64, " \
+      "T: {bfloat16, half, float, double, int32, int64, complex64, " \
       "complex128}")
 
 REGISTER_OP("Add")
@@ -304,7 +304,7 @@ REGISTER_OP("Add")
     .Input("y: T")
     .Output("z: T")
     .Attr(
-        "T: {half, bfloat16, float, double, uint8, int8, int16, int32, int64, "
+        "T: {bfloat16, half, float, double, uint8, int8, int16, int32, int64, "
         "complex64, complex128, string}")
     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
 
@@ -315,7 +315,7 @@ REGISTER_OP("AddV2")
     .Input("y: T")
     .Output("z: T")
     .Attr(
-        "T: {half, bfloat16, float, double, uint8, int8, int16, int32, int64, "
+        "T: {bfloat16, half, float, double, uint8, int8, int16, int32, int64, "
         "complex64, complex128}")
     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
     .SetIsAggregate()
@@ -412,7 +412,7 @@ REGISTER_OP("Maximum")
     .Input("x: T")
     .Input("y: T")
     .Output("z: T")
-    .Attr("T: {half, bfloat16, float, double, int32, int64}")
+    .Attr("T: {bfloat16, half, float, double, int32, int64}")
     .SetIsCommutative()
     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
 
@@ -437,7 +437,7 @@ REGISTER_OP("Minimum")
     .Input("x: T")
     .Input("y: T")
     .Output("z: T")
-    .Attr("T: {half, bfloat16, float, double, int32, int64}")
+    .Attr("T: {bfloat16, half, float, double, int32, int64}")
     .SetIsCommutative()
     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
 
@@ -445,21 +445,21 @@ REGISTER_OP("Mod")
     .Input("x: T")
     .Input("y: T")
     .Output("z: T")
-    .Attr("T: {int32, int64, bfloat16, float, double}")
+    .Attr("T: {int32, int64, float16, half, bfloat16, float, double}")
     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
 
 REGISTER_OP("FloorMod")
     .Input("x: T")
     .Input("y: T")
     .Output("z: T")
-    .Attr("T: {int32, int64, bfloat16, float, double}")
+    .Attr("T: {int32, int64, bfloat16, half, float, double}")
     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
 
 REGISTER_OP("TruncateMod")
     .Input("x: T")
     .Input("y: T")
     .Output("z: T")
-    .Attr("T: {int32, int64, bfloat16, float, double}")
+    .Attr("T: {int32, int64, bfloat16, half, float, double}")
     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
 
 REGISTER_OP("Pow")
@@ -467,7 +467,7 @@ REGISTER_OP("Pow")
     .Input("y: T")
     .Output("z: T")
     .Attr(
-        "T: {half, bfloat16, float, double, int32, int64, complex64, "
+        "T: {bfloat16, float, half, double, int32, int64, complex64, "
         "complex128}")
     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
 
@@ -503,7 +503,7 @@ REGISTER_OP("Atan2")
     .Input("y: T")
     .Input("x: T")
     .Output("z: T")
-    .Attr("T: {bfloat16, float, double}")
+    .Attr("T: {bfloat16, half, float, double}")
     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
 
 REGISTER_OP("Betainc")
@@ -574,7 +574,7 @@ REGISTER_OP("GreaterEqual").COMPARISON();
       .Output("z: bool")                                                   \
       .SetIsCommutative()                                                  \
       .Attr(                                                               \
-          "T: {half, bfloat16, float, double, uint8, int8, int16, int32, " \
+          "T: {bfloat16, half, float, double, uint8, int8, int16, int32, " \
           "int64, complex64, quint8, qint8, qint32, string, bool, "        \
           "complex128}")                                                   \
       .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
@@ -713,7 +713,7 @@ REGISTER_OP("MatMul")
     .Output("product: T")
     .Attr("transpose_a: bool = false")
     .Attr("transpose_b: bool = false")
-    .Attr("T: {half, bfloat16, float, double, int32, complex64, complex128}")
+    .Attr("T: {bfloat16, half, float, double, int32, complex64, complex128}")
     .SetShapeFn(shape_inference::MatMulShape);
 
 REGISTER_OP("SparseMatMul")