Status IrEmitterNested::EmitTargetElementLoop(
const HloInstruction& hlo,
const llvm_ir::ElementGenerator& element_generator) {
+ // For MOF we give the loop emitter an array for every output it should
+ // generate.
+ if (hlo.IsMultiOutputFusion()) {
+ std::vector<llvm_ir::IrArray> target_arrays;
+ for (int64 i = 0, e = ShapeUtil::TupleElementCount(hlo.shape()); i != e;
+ ++i) {
+ target_arrays.push_back(GetIrArray(hlo, hlo, {i}));
+ }
+ return llvm_ir::LoopEmitter(element_generator, target_arrays, &ir_builder_)
+ .EmitLoop();
+ }
return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo, hlo),
&ir_builder_)
.EmitLoop();
*result, *Literal::MakeTupleOwned(Literal::CreateR0<int32>(42))));
}
+XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) {
+ const char* testcase = R"(
+ HloModule m
+
+ fused_computation {
+ p = f32[] parameter(0)
+ multiply = f32[] multiply(p, p)
+ less-than = pred[] less-than(p, multiply)
+ ROOT tuple = (pred[], f32[]) tuple(less-than, multiply)
+ }
+
+ ENTRY PredFloatMOF {
+ p0 = f32[] parameter(0)
+ fusion = (pred[], f32[]) fusion(p0), kind=kLoop, calls=fused_computation
+ gte0 = pred[] get-tuple-element(fusion), index=0
+ gte1 = f32[] get-tuple-element(fusion), index=1
+ const = f32[] constant(0)
+ ROOT select = f32[] select(gte0, gte1, const)
+ })";
+ auto module =
+ HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
+ .ValueOrDie();
+ auto param = Literal::CreateR0<float>(2.0);
+ TF_ASSERT_OK_AND_ASSIGN(auto result,
+ Execute(std::move(module), {param.get()}));
+ EXPECT_TRUE(LiteralTestUtil::Equal(*result, *Literal::CreateR0<float>(4.0)));
+}
+
} // namespace
} // namespace xla