[XLA] In HloEvaluator, fix an issue for HandleAbs to handle complex numbers
authorKay Zhu <kayzhu@google.com>
Tue, 27 Feb 2018 00:24:54 +0000 (16:24 -0800)
committerGunhan Gulsoy <gunan@google.com>
Tue, 27 Feb 2018 22:33:33 +0000 (14:33 -0800)
more correctly:

- abs([complex numbers]) would yield floats. However since the specilization for
HandleAbs is based on the return type (float), we'd CHECK fail due to float !=
complex when accessing the elements of the operand (complex).
- enable unary_op_test for interpreter.

PiperOrigin-RevId: 187099576

tensorflow/compiler/xla/service/hlo_evaluator.cc
tensorflow/compiler/xla/tests/BUILD

index fd06b19..cf8b359 100644 (file)
@@ -57,6 +57,12 @@ struct is_complex_t : public std::false_type {};
 template <>
 struct is_complex_t<complex64> : public std::true_type {};
 
+template <typename T>
+struct is_complex64_t : public std::false_type {};
+
+template <>
+struct is_complex64_t<complex64> : public std::true_type {};
+
 template <typename OperandT>
 StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
                                            const Literal& lhs_literal,
@@ -248,17 +254,37 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
 
   template <
       typename NativeT,
-      typename std::enable_if<std::is_signed<NativeT>::value ||
-                              is_complex_t<NativeT>::value>::type* = nullptr>
+      typename std::enable_if<std::is_signed<NativeT>::value>::type* = nullptr>
   Status HandleAbs(HloInstruction* abs) {
     TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs],
-                        ElementWiseUnaryOp(abs, [](ElementwiseT elem_operand) {
+                        ElementWiseUnaryOp(abs, [](NativeT elem_operand) {
                           return std::abs(elem_operand);
                         }));
     return Status::OK();
   }
 
+  template <
+      typename NativeT,
+      typename std::enable_if<is_complex64_t<NativeT>::value>::type* = nullptr>
+  Status HandleAbs(HloInstruction* abs) {
+    const Literal& operand_literal =
+        parent_->GetEvaluatedLiteralFor(abs->operand(0));
+    TF_ASSIGN_OR_RETURN(
+        parent_->evaluated_[abs],
+        (ElementWiseUnaryOpImpl<float, NativeT>(
+            abs, [](NativeT elem_operand) { return std::abs(elem_operand); },
+            operand_literal)));
+
+    return Status::OK();
+  }
+
   Status HandleAbs(HloInstruction* abs) override {
+    // If the operand is of C64 type, the return type of abs will be F32.
+    // However, ElementwiseT would still be the return type, F32, and thus
+    // specifying the ElementwiseT explicitly as C64 is needed below.
+    if (abs->operand(0)->shape().element_type() == C64) {
+      return HandleAbs<complex64>(abs);
+    }
     return HandleAbs<ElementwiseT>(abs);
   }
 
index 33fde97..f3ecfc1 100644 (file)
@@ -494,6 +494,7 @@ xla_test(
 xla_test(
     name = "unary_op_test",
     srcs = ["unary_op_test.cc"],
+    tags = ["enable_for_xla_interpreter"],
     deps = [
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/client:computation_builder",