[XLA:GPU] Teach ir_emitter_nested how to deal with multi output loop fusion
author: Benjamin Kramer <kramerb@google.com>
Wed, 16 May 2018 17:40:57 +0000 (10:40 -0700)
committer: TensorFlower Gardener <gardener@tensorflow.org>
Wed, 16 May 2018 17:45:23 +0000 (10:45 -0700)
Most of the plumbing is there already, just set up a loop emitter with a target
for each tuple element. For a simple case the output looks reasonable, though I
haven't checked correctness of anything complex.

PiperOrigin-RevId: 196850926

tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc
tensorflow/compiler/xla/tests/multioutput_fusion_test.cc

index 71aada0..f837a60 100644 (file)
@@ -116,6 +116,17 @@ Status IrEmitterNested::HandleParameter(HloInstruction* parameter) {
 Status IrEmitterNested::EmitTargetElementLoop(
     const HloInstruction& hlo,
     const llvm_ir::ElementGenerator& element_generator) {
+  // For MOF we give the loop emitter an array for every output it should
+  // generate.
+  if (hlo.IsMultiOutputFusion()) {
+    std::vector<llvm_ir::IrArray> target_arrays;
+    for (int64 i = 0, e = ShapeUtil::TupleElementCount(hlo.shape()); i != e;
+         ++i) {
+      target_arrays.push_back(GetIrArray(hlo, hlo, {i}));
+    }
+    return llvm_ir::LoopEmitter(element_generator, target_arrays, &ir_builder_)
+        .EmitLoop();
+  }
   return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo, hlo),
                               &ir_builder_)
       .EmitLoop();
index b745522..413107e 100644 (file)
@@ -210,5 +210,33 @@ XLA_TEST_F(MultiOutputFusionTest, FusionNodeIsRoot) {
       *result, *Literal::MakeTupleOwned(Literal::CreateR0<int32>(42))));
 }
 
+XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) {
+  // End-to-end check for multi-output loop fusion (kind=kLoop fusion whose
+  // root is a tuple): the fused computation emits two outputs of different
+  // element types (pred[] and f32[]) from one scalar parameter, exercising
+  // the new per-tuple-element target arrays in IrEmitterNested.
+  const char* testcase = R"(
+    HloModule m
+
+    fused_computation {
+      p = f32[] parameter(0)
+      multiply = f32[] multiply(p, p)
+      less-than = pred[] less-than(p, multiply)
+      ROOT tuple = (pred[], f32[]) tuple(less-than, multiply)
+    }
+
+    ENTRY PredFloatMOF {
+      p0 = f32[] parameter(0)
+      fusion = (pred[], f32[]) fusion(p0), kind=kLoop, calls=fused_computation
+      gte0 = pred[] get-tuple-element(fusion), index=0
+      gte1 = f32[] get-tuple-element(fusion), index=1
+      const = f32[] constant(0)
+      ROOT select = f32[] select(gte0, gte1, const)
+    })";
+  auto module =
+      HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
+          .ValueOrDie();
+  // p0 = 2.0: multiply = 4.0, less-than = (2.0 < 4.0) = true, so the entry's
+  // select picks gte1 (4.0) rather than the 0.0 constant.
+  auto param = Literal::CreateR0<float>(2.0);
+  TF_ASSERT_OK_AND_ASSIGN(auto result,
+                          Execute(std::move(module), {param.get()}));
+  EXPECT_TRUE(LiteralTestUtil::Equal(*result, *Literal::CreateR0<float>(4.0)));
+}
+
+
 }  // namespace
 }  // namespace xla