[XLA:GPU] Allow merging into input fusion nodes in FusionMerger.
authorJustin Lebar <jlebar@google.com>
Mon, 5 Mar 2018 19:10:42 +0000 (11:10 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 5 Mar 2018 19:15:15 +0000 (11:15 -0800)
Seems to have been an oversight.  "Input fusion" means that the *output*
of the fusion node is the "real hero".  The inputs aren't special; we
can fuse more stuff in.

PiperOrigin-RevId: 187892975

tensorflow/compiler/xla/service/gpu/BUILD
tensorflow/compiler/xla/service/gpu/fusion_merger.cc
tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc
tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc

index 334efff..cecbc25 100644 (file)
@@ -437,8 +437,10 @@ tf_cc_test(
         ":fusion_merger",
         ":instruction_fusion",
         "//tensorflow/compiler/xla:test_helpers",
+        "//tensorflow/compiler/xla/service:hlo_matchers",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/compiler/xla/tools/parser:hlo_parser",
     ],
 )
 
index 91a916f..3cd30b7 100644 (file)
@@ -223,9 +223,10 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) {
   // Skip 'fusion' instruction if we cannot merge into all of its users.
   // Merging into all users enables the removal of 'fusion' from the
   // computation.
-  if (!c_all_of(fusion->users(), [](const HloInstruction* instruction) {
-        return instruction->opcode() == HloOpcode::kFusion &&
-               instruction->fusion_kind() == HloInstruction::FusionKind::kLoop;
+  if (!c_all_of(fusion->users(), [](const HloInstruction* user) {
+        return user->opcode() == HloOpcode::kFusion &&
+               (user->fusion_kind() == HloInstruction::FusionKind::kLoop ||
+                user->fusion_kind() == HloInstruction::FusionKind::kInput);
       })) {
     VLOG(3) << "Not merging " << fusion->name()
             << ": Some of its users are not loop/input fusion kernels.";
index deef596..c0def27 100644 (file)
@@ -16,13 +16,17 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h"
 
 #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
 #include "tensorflow/compiler/xla/test_helpers.h"
 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
 
 namespace xla {
 namespace gpu {
 namespace {
 
+namespace op = xla::testing::opcode_matchers;
+
 class FusionMergerTest : public HloTestBase {
  protected:
   FusionMergerTest() : module_(CreateNewModule()) {}
@@ -459,6 +463,43 @@ TEST_F(FusionMergerTest, BytesTransferredThresholdNotExeceeded) {
   EXPECT_TRUE(FusionMerger().Run(module_.get()).ValueOrDie());
 }
 
+// Check that we're willing to merge f1_computation into f2_computation, even
+// though f2 is an input fusion node.
+TEST_F(FusionMergerTest, WillMergeIntoInputFusion) {
+  const char* const kModule = R"(
+    HloModule m
+
+    f1_computation {
+      f1_p0 = f32[10]{0} parameter(0)
+      ROOT f1_root = f32[10]{0} add(f1_p0, f1_p0)
+    }
+
+    add_computation {
+      add_lhs = f32[] parameter(0)
+      add_rhs = f32[] parameter(1)
+      ROOT add_root = f32[] add(add_lhs, add_rhs)
+    }
+
+    f2_computation {
+      f2_p0 = f32[10]{0} parameter(0)
+      f2_mul = f32[10]{0} multiply(f2_p0, f2_p0)
+      f2_zero = f32[] constant(0)
+      ROOT f2_root = f32[] reduce(f2_mul, f2_zero), dimensions={0},
+             to_apply=add_computation
+    }
+
+    ENTRY entry {
+      p0 = f32[10]{0} parameter(0)
+      f1 = f32[10]{0} fusion(p0), kind=kLoop, calls=f1_computation
+      ROOT f2 = f32[] fusion(f1), kind=kInput, calls=f2_computation
+    }
+  )";
+  auto module = tools::Parse(kModule).ValueOrDie();
+  EXPECT_TRUE(FusionMerger().Run(module.get()).ValueOrDie());
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              op::Fusion(op::Parameter()));
+}
+
 }  // namespace
 }  // namespace gpu
 }  // namespace xla
index 30c88c0..065b3a0 100644 (file)
@@ -535,6 +535,13 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
           // If no operand has a compatible shape, prefer an operand that has
           // the same rank at least.
           for (const HloInstruction* operand : operands) {
+            // Skip tuple-shaped operands; calling ShapeUtil::Rank on a
+            // tuple-shaped Shape is illegal.  Perhaps more correct would be to
+            // recurse into them, but TODO(kramerb): Remove this code after
+            // assigning layouts to fusion nodes.
+            if (ShapeUtil::IsTuple(operand->shape())) {
+              continue;
+            }
             if (ShapeUtil::Rank(*input_shape) ==
                 ShapeUtil::Rank(operand->shape())) {
               // Do not use CopyLayoutBetweenShapes because input_shape and