[XLA] Better support for mul reductions in MakeFakeArguments()
authorMichael Kuperstein <mkuper@google.com>
Thu, 5 Apr 2018 21:54:36 +0000 (14:54 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 5 Apr 2018 21:56:57 +0000 (14:56 -0700)
Mul reductions want a 1 as their init value, not a 0 or a random value.

PiperOrigin-RevId: 191802819

tensorflow/compiler/xla/tests/test_utils.cc

index 821432e..68f75d5 100644 (file)
@@ -160,27 +160,38 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteralInternal(
   return std::move(literal);
 }
 
-// Matches binary addition computations.
-bool LooksLikeSum(const HloComputation& computation) {
+enum class ConstantType { kUnknown, kZero, kOne };
+
+// Return the constant type required by this computation, if known.
+ConstantType GetInitValue(const HloComputation& computation) {
   const HloInstruction* const root = computation.root_instruction();
-  return root->opcode() == HloOpcode::kAdd &&
-         computation.num_parameters() == 2 &&
-         root->operand(0)->opcode() == HloOpcode::kParameter &&
-         root->operand(1)->opcode() == HloOpcode::kParameter &&
-         root->operand(0) != root->operand(1);
+  if (computation.num_parameters() != 2 ||
+      root->operand(0)->opcode() != HloOpcode::kParameter ||
+      root->operand(1)->opcode() != HloOpcode::kParameter ||
+      root->operand(0) == root->operand(1)) {
+    return ConstantType::kUnknown;
+  }
+
+  switch (root->opcode()) {
+    case HloOpcode::kAdd:
+      return ConstantType::kZero;
+    case HloOpcode::kMultiply:
+      return ConstantType::kOne;
+    default:
+      return ConstantType::kUnknown;
+  }
 }
 
-// Reduce, ReduceWindow, and SelectAndScatter ops may use binary addition,
-// which requires an init_value of 0 rather than a random value.
-bool NeedsZeroInitValue(const HloUse& use) {
+// Reduce, ReduceWindow, and SelectAndScatter ops may need a non-random
+// initialization value.
+bool NeedsInitValue(const HloUse& use) {
   const HloInstruction* const instruction = use.instruction;
   const HloOpcode opcode = instruction->opcode();
   const int64 op_num = use.operand_number;
   return (
       ((opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow) &&
-       op_num == 1 && LooksLikeSum(*instruction->to_apply())) ||
-      (opcode == HloOpcode::kSelectAndScatter && op_num == 2 &&
-       LooksLikeSum(*instruction->scatter())));
+       op_num == 1) ||
+      (opcode == HloOpcode::kSelectAndScatter && op_num == 2));
 }
 
 // Generate random values that are constrained to the input_shape minus the
@@ -222,7 +233,7 @@ std::vector<HloInstruction*> FindConstrainedUses(
         auto fused_uses = FindConstrainedUses(dataflow, *to_analyze);
         constrained_uses.insert(constrained_uses.end(), fused_uses.begin(),
                                 fused_uses.end());
-      } else if (NeedsZeroInitValue(use)) {
+      } else if (NeedsInitValue(use)) {
         constrained_uses.push_back(instruction);
       } else if (opcode == HloOpcode::kConvert ||
                  opcode == HloOpcode::kReducePrecision) {
@@ -243,7 +254,8 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
     const tensorflow::gtl::ArraySlice<HloInstruction*> constrained_uses,
     const HloInstruction& param, std::minstd_rand0* engine) {
   HloInstruction* needs_index = nullptr;
-  HloInstruction* needs_zero = nullptr;
+  HloInstruction* needs_constant = nullptr;
+  ConstantType constant_type = ConstantType::kUnknown;
   for (HloInstruction* use : constrained_uses) {
     switch (use->opcode()) {
       case HloOpcode::kDynamicSlice:
@@ -258,8 +270,13 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
 
       case HloOpcode::kReduce:
       case HloOpcode::kReduceWindow:
+        needs_constant = use;
+        constant_type = GetInitValue(*use->to_apply());
+        break;
+
       case HloOpcode::kSelectAndScatter:
-        needs_zero = use;
+        needs_constant = use;
+        constant_type = GetInitValue(*use->scatter());
         break;
 
       default:
@@ -268,17 +285,26 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
             use->ToString().c_str());
     }
   }
-  if (needs_index != nullptr && needs_zero != nullptr) {
+  if (needs_index != nullptr && needs_constant != nullptr) {
     return Unimplemented(
         "Conflicting operand generation constraints.\nNeeds index: %s\nNeeds "
-        "zero: %s\n",
-        needs_index->ToString().c_str(), needs_zero->ToString().c_str());
+        "constant: %s\n",
+        needs_index->ToString().c_str(), needs_constant->ToString().c_str());
   }
   if (needs_index != nullptr) {
     return MakeRandomNonwrappingSliceIndex(needs_index->operand(0)->shape(),
                                            needs_index->shape(), engine);
-  } else if (needs_zero != nullptr) {
-    return Literal::CreateFromShape(param.shape());
+  } else if (needs_constant != nullptr) {
+    switch (constant_type) {
+      case ConstantType::kZero:
+        return Literal::Zero(param.shape().element_type()).CloneToUnique();
+      case ConstantType::kOne:
+        return Literal::One(param.shape().element_type()).CloneToUnique();
+      case ConstantType::kUnknown:
+        // We want the identity element for the computation, but we don't really
+        // know what it is - so any value we generate will be just as wrong.
+        return MakeFakeLiteralInternal(param.shape(), engine);
+    }
   } else {
     return MakeFakeLiteralInternal(param.shape(), engine);
   }