From 0f8be44de22a344ce6aac1e2cee8595b7c89d9f8 Mon Sep 17 00:00:00 2001
From: Benjamin Kramer <kramerb@google.com>
Date: Thu, 17 May 2018 11:06:05 -0700
Subject: [PATCH] [XLA:GPU] Unroll multi-output loop fusions

This is easier than I thought because we can assume that all tuple members
have the same number of elements. LLVM doesn't do a great job of vectorizing
the resulting stores, but otherwise this is working fine.

PiperOrigin-RevId: 197019718
---
 .../xla/service/gpu/ir_emitter_unnested.cc        | 19 ++++++++-----------
 .../compiler/xla/tests/multioutput_fusion_test.cc | 25 +++++++++++++------------
 2 files changed, 21 insertions(+), 23 deletions(-)

diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 0d7ba4c..957733f 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -267,7 +267,10 @@ int ComputeMaxUnrollFactor(const HloInstruction* hlo) {
   // Find the largest possible power of two to unroll by.
   // TODO(kramerb): Make this smarter.
-  int64 num_elements = ShapeUtil::ElementsIn(hlo->shape());
+  const Shape& element_shape = hlo->IsMultiOutputFusion()
+                                   ? ShapeUtil::GetSubshape(hlo->shape(), {0})
+                                   : hlo->shape();
+  int64 num_elements = ShapeUtil::ElementsIn(element_shape);
   for (int i = max_unroll_factor; i > 1; i /= 2) {
     if (num_elements % i == 0) {
       return i;
     }
   }
@@ -565,12 +568,8 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
     return Status::OK();
   }
 
-  int unroll_factor = 1;
-  // TODO(kramerb): Unrolling multi-output loop fusions too.
-  if (!fusion->IsMultiOutputFusion()) {
-    CHECK(fusion->fusion_kind() == HloInstruction::FusionKind::kLoop);
-    unroll_factor = ComputeMaxUnrollFactor(fusion);
-  }
+  CHECK(fusion->fusion_kind() == HloInstruction::FusionKind::kLoop);
+  int unroll_factor = ComputeMaxUnrollFactor(fusion);
 
   thunk_sequence_->emplace_back(BuildKernelThunk(fusion, unroll_factor));
   return IrEmitter::HandleFusion(fusion);
@@ -2538,16 +2537,14 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
         .EmitLoop(IrName(&hlo));
   }
 
-  CHECK_EQ(unroll_factor, 1)
-      << "multi-output fusion does not support unrolling";
-
   // For multiple outputs fusion, we need to emit each operand and the root.
   std::vector<llvm_ir::IrArray> output_arrays;
   for (int64 i = 0; i < ShapeUtil::TupleElementCount(hlo.shape()); ++i) {
     output_arrays.push_back(GetIrArray(hlo, hlo, {i}));
   }
   TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, output_arrays,
-                                         launch_dimensions, &ir_builder_)
+                                         launch_dimensions, &ir_builder_,
+                                         unroll_factor)
                          .EmitLoop(IrName(&hlo)));
 
   std::vector<llvm::Value*> tuple_operand_ptrs;
diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
index 39f9bba..ec7ca20 100644
--- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
@@ -215,27 +215,28 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) {
     HloModule m

    fused_computation {
-      p = f32[] parameter(0)
-      multiply = f32[] multiply(p, p)
-      less-than = pred[] less-than(p, multiply)
-      ROOT tuple = (pred[], f32[]) tuple(less-than, multiply)
+      p = f32[4] parameter(0)
+      multiply = f32[4] multiply(p, p)
+      less-than = pred[4] less-than(p, multiply)
+      ROOT tuple = (pred[4], f32[4]) tuple(less-than, multiply)
    }

    ENTRY PredFloatMOF {
-      p0 = f32[] parameter(0)
-      fusion = (pred[], f32[]) fusion(p0), kind=kLoop, calls=fused_computation
-      gte0 = pred[] get-tuple-element(fusion), index=0
-      gte1 = f32[] get-tuple-element(fusion), index=1
-      const = f32[] constant(0)
-      ROOT select = f32[] select(gte0, gte1, const)
+      p0 = f32[4] parameter(0)
+      fusion = (pred[4], f32[4]) fusion(p0), kind=kLoop, calls=fused_computation
+      gte0 = pred[4] get-tuple-element(fusion), index=0
+      gte1 = f32[4] get-tuple-element(fusion), index=1
+      const = f32[4] constant({0, 0, 0, 0})
+      ROOT select = f32[4] select(gte0, gte1, const)
    })";
  auto module =
      HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
          .ValueOrDie();
-  auto param = Literal::CreateR0<float>(2.0);
+  auto param = Literal::CreateR1<float>({1.0, 2.0, 3.0, -1.0});
  TF_ASSERT_OK_AND_ASSIGN(auto result,
                          Execute(std::move(module), {param.get()}));
-  EXPECT_TRUE(LiteralTestUtil::Equal(*result, *Literal::CreateR0<float>(4.0)));
+  EXPECT_TRUE(LiteralTestUtil::Equal(
+      *result, *Literal::CreateR1<float>({0.0, 4.0, 9.0, 1.0})));
 }
 
 XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) {
-- 
2.7.4