[XLA:GPU] Fuse broadcasts into reduction fusions
authorBenjamin Kramer <kramerb@google.com>
Wed, 7 Mar 2018 14:28:00 +0000 (06:28 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 7 Mar 2018 14:32:08 +0000 (06:32 -0800)
We didn't do this because reconstructing a layout was hard. With
layout_assignment before fusion this becomes much easier. Remove the
limitations.

PiperOrigin-RevId: 188167436

tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc

index 870d241..84504d2 100644 (file)
@@ -71,17 +71,6 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
     return false;
   }
 
-  // We may need to know original operand layout to emit input fusion, and so
-  // far, we merely use the layout of an operand of the fusion node, which means
-  // we must fuse only elementwise operations. This restriction should be lifted
-  // later if we need to fuse other operations, e.g. transpose, for performance.
-  if ((IsReductionToVector(*consumer) ||
-       (HloOpcode::kFusion == consumer->opcode() &&
-        HloInstruction::FusionKind::kInput == consumer->fusion_kind())) &&
-      !producer->IsElementwise()) {
-    return false;
-  }
-
   // Cost condition: not fuse (simple, expensive producers) and (consumers who
   // reuse operand elements).
   if (producer->opcode() != HloOpcode::kFusion &&
index 373e5a5..c81dbb7 100644 (file)
@@ -164,6 +164,36 @@ TEST_F(InstructionFusionTest, GetTupleElementFused) {
   EXPECT_EQ(HloOpcode::kGetTupleElement, fused_root->operand(1)->opcode());
 }
 
+// Tests that broadcasts fused into a fusion with a reduce root.
+TEST_F(InstructionFusionTest, BroadcastIntoReduce) {
+  auto module = tools::Parse(R"(
+    HloModule test_module
+
+    add {
+      lhs = f32[] parameter(0)
+      rhs = f32[] parameter(1)
+      ROOT add = f32[] add(lhs, rhs)
+    }
+
+    ENTRY BroadcastIntoReduce {
+      constant = f32[] constant(1)
+      broadcast = f32[16,16,16,16]{3,2,1,0} broadcast(constant), dimensions={}
+      constant.1 = f32[] constant(0)
+      ROOT reduce = f32[] reduce(broadcast, constant.1), dimensions={0,1,2,3},
+                                                         to_apply=add
+    })")
+                    .ValueOrDie();
+
+  EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true)
+                  .Run(module.get())
+                  .ValueOrDie());
+
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, op::Fusion());
+  EXPECT_THAT(root->fused_expression_root(),
+              op::Reduce(op::Broadcast(op::Parameter()), op::Parameter()));
+}
+
 TEST_F(InstructionFusionTest, BitcastIntoAdd) {
   auto module = tools::Parse(R"(
     HloModule test_module
index 065b3a0..4cfb613 100644 (file)
@@ -517,46 +517,6 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
         TF_RETURN_IF_ERROR(root->Accept(&fused_emitter));
 
         Shape input_shape = root->operand(0)->shape();
-        // EmitReductionToVector requires the input shape to have a layout, but
-        // fused instructions don't have one. So we determine its layout from
-        // the fusion's operands. The choice of the layout only affects
-        // performance but not correctness.
-        auto choose_input_layout = [](
-            tensorflow::gtl::ArraySlice<const HloInstruction*> operands,
-            Shape* input_shape) -> Status {
-          // Prefer the layout of an operand whose shape is compatible with
-          // input_shape.
-          for (const HloInstruction* operand : operands) {
-            if (ShapeUtil::Compatible(*input_shape, operand->shape())) {
-              return LayoutUtil::CopyLayoutBetweenShapes(operand->shape(),
-                                                         input_shape);
-            }
-          }
-          // If no operand has a compatible shape, prefer an operand that has
-          // the same rank at least.
-          for (const HloInstruction* operand : operands) {
-            // Skip tuple-shaped operands; calling ShapeUtil::Rank on a
-            // tuple-shaped Shape is illegal.  Perhaps more correct would be to
-            // recurse into them, but TODO(kramerb): Remove this code after
-            // assigning layouts to fusion nodes.
-            if (ShapeUtil::IsTuple(operand->shape())) {
-              continue;
-            }
-            if (ShapeUtil::Rank(*input_shape) ==
-                ShapeUtil::Rank(operand->shape())) {
-              // Do not use CopyLayoutBetweenShapes because input_shape and
-              // operand->shape() may be incompatible.
-              *input_shape->mutable_layout() = operand->shape().layout();
-              return Status::OK();
-            }
-          }
-          // When all the above fails, which is rare, set the default layout.
-          LayoutUtil::SetToDefaultLayout(input_shape);
-          return Status::OK();
-        };
-        TF_RETURN_IF_ERROR(
-            choose_input_layout(fusion->operands(), &input_shape));
-
         return EmitReductionToVector(
             root, input_shape, fused_emitter.GetGenerator(root->operand(0)),
             fused_emitter.GetGenerator(root->operand(1)), root->dimensions(),