[XLA:GPU] Add lowering for input fusions with multiple reduce outputs
author Benjamin Kramer <kramerb@google.com>
Tue, 22 May 2018 18:52:51 +0000 (11:52 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Tue, 22 May 2018 18:55:39 +0000 (11:55 -0700)
This is limited to reduces that have the same shapes and reduced dimensions.
Most of the change makes the individual emission code emit multiple reductions
in the same loop. Multi-output fusion is still required for this to provide a
speedup.
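
For illustration (mirroring the test cases added below; Add and Max are the
scalar computations defined there, and the constant name is illustrative),
the supported pattern is a kInput fusion whose tuple root collects reduces
that agree on shape and reduce dimensions:

    const char* const kExampleFusion = R"(
      fused_reduce {
        p0 = f32[2,2,2]{2,1,0} parameter(0)
        c0 = f32[] constant(0)
        r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={2}, to_apply=Add
        c1 = f32[] constant(5)
        r2 = f32[2,2]{1,0} reduce(p0, c1), dimensions={2}, to_apply=Max
        ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2)
      }
    )";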

PiperOrigin-RevId: 197599248

tensorflow/compiler/xla/service/gpu/ir_emitter.h
tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
tensorflow/compiler/xla/tests/multioutput_fusion_test.cc

index b0accc0..e55dfc6 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
@@ -120,10 +120,11 @@ class IrEmitter : public DfsHloVisitorWithDefault {
   llvm::Value* GetBasePointer(const HloInstruction& inst) const {
     return bindings_.GetBasePointer(inst);
   }
-  // A convenient helper for calling BufferAssignment::GetUniqueTopLevelSlice.
-  BufferAllocation::Slice GetAllocationSlice(const HloInstruction& hlo) const {
+  // A convenient helper for calling BufferAssignment::GetUniqueSlice.
+  BufferAllocation::Slice GetAllocationSlice(
+      const HloInstruction& hlo, const ShapeIndex& index = {}) const {
     return ir_emitter_context_->buffer_assignment()
-        .GetUniqueTopLevelSlice(&hlo)
+        .GetUniqueSlice(&hlo, index)
         .ConsumeValueOrDie();
   }
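
A minimal usage sketch of the extended helper, assuming hypothetical reduce and
fusion instructions: the default index preserves the old top-level behavior,
while a non-empty ShapeIndex selects one tuple element of a multi-output fusion.

    // Same as the previous GetUniqueTopLevelSlice-based behavior.
    BufferAllocation::Slice top_level = GetAllocationSlice(*reduce);
    // Slice backing tuple element 1 of a multi-output fusion.
    BufferAllocation::Slice element1 = GetAllocationSlice(*fusion, /*index=*/{1});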
 
index 55d4c1d..d07d197 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -79,6 +79,7 @@ namespace {
 
 using llvm_ir::IrName;
 using tensorflow::gtl::ArraySlice;
+using tensorflow::gtl::InlinedVector;
 using tensorflow::gtl::nullopt;
 using tensorflow::gtl::optional;
 using tensorflow::strings::StrCat;
@@ -499,12 +500,24 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
   // initializes the output array to the initial value of the reduce.
   if (HloInstruction::FusionKind::kInput == fusion->fusion_kind()) {
     switch (root->opcode()) {
+      case HloOpcode::kTuple:
       case HloOpcode::kReduce: {
         VLOG(3) << "Emitting fused reduction to vector: " << fusion->ToString();
-        TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> initializer_thunk,
-                            BuildInitializerThunk(fusion));
         std::vector<std::unique_ptr<Thunk>> thunks;
-        thunks.push_back(std::move(initializer_thunk));
+        ArraySlice<HloInstruction*> reduces =
+            root->opcode() == HloOpcode::kTuple
+                ? root->operands()
+                : ArraySlice<HloInstruction*>(&root, 1);
+
+        // For multi-output fusion emit an initializer for each tuple element.
+        // Otherwise it's sufficient to just initialize the single output.
+        for (int i = 0, e = reduces.size(); i != e; ++i) {
+          TF_ASSIGN_OR_RETURN(
+              std::unique_ptr<Thunk> initializer_thunk,
+              BuildInitializerThunk(
+                  fusion, reduces[i] == root ? ShapeIndex() : ShapeIndex({i})));
+          thunks.push_back(std::move(initializer_thunk));
+        }
         thunks.push_back(BuildKernelThunk(fusion));
         thunk_sequence_->emplace_back(
             MakeUnique<SequentialThunk>(std::move(thunks), fusion));
@@ -518,11 +531,34 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
         FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter);
         TF_RETURN_IF_ERROR(root->Accept(&fused_emitter));
 
-        Shape input_shape = root->operand(0)->shape();
-        return EmitReductionToVector(
-            root, input_shape, fused_emitter.GetGenerator(root->operand(0)),
-            fused_emitter.GetGenerator(root->operand(1)), root->dimensions(),
-            root->to_apply());
+        // For multi-output fusion CHECK the constraints and feed all the
+        // reduces into a single loop code generator. Single-output reduce
+        // fusion is a special case of that.
+        InlinedVector<llvm_ir::ElementGenerator, 1> input_gens;
+        InlinedVector<llvm_ir::ElementGenerator, 1> init_value_gens;
+        InlinedVector<HloComputation*, 1> reducers;
+        for (const HloInstruction* reduce : reduces) {
+          CHECK_EQ(HloOpcode::kReduce, reduce->opcode());
+          // TODO(kramerb): CHECK that layouts are equal. Currently this
+          // breaks multioutputfusion_test. The test has pre-fused
+          // instructions, but layout_assignment will not assign any layouts
+          // for instructions inside of a fused computation. It just removes
+          // the layouts instead.
+          CHECK(ShapeUtil::Compatible(reduces[0]->shape(), reduce->shape()));
+          CHECK(ShapeUtil::Compatible(reduces[0]->operand(0)->shape(),
+                                      reduce->operand(0)->shape()));
+          CHECK(ShapeUtil::Compatible(reduces[0]->operand(1)->shape(),
+                                      reduce->operand(1)->shape()));
+          CHECK(reduces[0]->dimensions() == reduce->dimensions());
+          input_gens.push_back(fused_emitter.GetGenerator(reduce->operand(0)));
+          init_value_gens.push_back(
+              fused_emitter.GetGenerator(reduce->operand(1)));
+          reducers.push_back(reduce->to_apply());
+        }
+        const Shape& input_shape = reduces[0]->operand(0)->shape();
+        return EmitReductionToVector(reduces[0], input_shape, input_gens,
+                                     init_value_gens, reduces[0]->dimensions(),
+                                     reducers);
       }
       default:
         LOG(FATAL) << "Bad opcode for input fusion: "
@@ -909,8 +945,9 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
 
 Status IrEmitterUnnested::EmitReductionToScalar(
     HloInstruction* reduce, const Shape& input_shape,
-    const llvm_ir::ElementGenerator& input_gen,
-    const llvm_ir::ElementGenerator& init_value_gen, HloComputation* reducer) {
+    tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
+    tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
+    tensorflow::gtl::ArraySlice<HloComputation*> reducers) {
   // Number of elements processed by a single thread.
   constexpr int64 kTileSize = 16;
   int64 num_elems = ShapeUtil::ElementsIn(input_shape);
@@ -962,14 +999,19 @@ Status IrEmitterUnnested::EmitReductionToScalar(
   //
   auto loop_body_emitter =
       [=](const llvm_ir::IrArray::Index& tile_index) -> Status {
+    const int num_reduces = reducers.size();
     llvm::Type* element_ir_type =
         llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_);
-    llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca(
-        element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result");
-    {
-      TF_ASSIGN_OR_RETURN(llvm::Value * init_ir_value,
-                          init_value_gen(llvm_ir::IrArray::Index({})));
+    std::vector<llvm::Value*> partial_reduction_result_addresses;
+    for (int i = 0; i != num_reduces; ++i) {
+      llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca(
+          element_ir_type, /*ArraySize=*/nullptr,
+          "partial_reduction_result." + llvm::Twine(i));
+      TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value,
+                          init_value_gens[i](llvm_ir::IrArray::Index({})));
       ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address);
+      partial_reduction_result_addresses.push_back(
+          partial_reduction_result_address);
     }
 
     llvm::Value* x_in_tiles = tile_index[0];
@@ -1002,11 +1044,16 @@ Status IrEmitterUnnested::EmitReductionToScalar(
       llvm_ir::IrArray::Index input_index(
           /*linear=*/x, input_shape, &ir_builder_);
       llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type);
-      TF_ASSIGN_OR_RETURN(llvm::Value * input_ir_value, input_gen(input_index));
-      ir_builder_.CreateStore(input_ir_value, input_address);
-      return (EmitCallToNestedComputation(
-          *reducer, {partial_reduction_result_address, input_address},
-          partial_reduction_result_address));
+      for (int i = 0; i != num_reduces; ++i) {
+        TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value,
+                            input_gens[i](input_index));
+        ir_builder_.CreateStore(input_ir_value, input_address);
+        TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
+            *reducers[i],
+            {partial_reduction_result_addresses[i], input_address},
+            partial_reduction_result_addresses[i]));
+      }
+      return Status::OK();
     };
 
     // x_end = kTileSize + x_in_tiles * kTileSize, i.e., the location that's
@@ -1041,20 +1088,24 @@ Status IrEmitterUnnested::EmitReductionToScalar(
                                       : element_ir_type;
     for (int shuffle_distance = kWarpSize / 2; shuffle_distance >= 1;
          shuffle_distance /= 2) {
-      llvm::Value* partial_reduction_result = ir_builder_.CreateLoad(
-          ir_builder_.CreateBitCast(partial_reduction_result_address,
-                                    shuffle_ir_type->getPointerTo()),
-          "partial_reduction_result");
       llvm::Value* result_from_other_lane = ir_builder_.CreateAlloca(
           element_ir_type, nullptr, "result_from_other_lane");
-      ir_builder_.CreateStore(
-          EmitShuffleDown(partial_reduction_result,
-                          ir_builder_.getInt32(shuffle_distance), &ir_builder_),
-          ir_builder_.CreateBitCast(result_from_other_lane,
-                                    shuffle_ir_type->getPointerTo()));
-      TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
-          *reducer, {partial_reduction_result_address, result_from_other_lane},
-          partial_reduction_result_address));
+      for (int i = 0; i != num_reduces; ++i) {
+        llvm::Value* partial_reduction_result = ir_builder_.CreateLoad(
+            ir_builder_.CreateBitCast(partial_reduction_result_addresses[i],
+                                      shuffle_ir_type->getPointerTo()),
+            "partial_reduction_result");
+        ir_builder_.CreateStore(
+            EmitShuffleDown(partial_reduction_result,
+                            ir_builder_.getInt32(shuffle_distance),
+                            &ir_builder_),
+            ir_builder_.CreateBitCast(result_from_other_lane,
+                                      shuffle_ir_type->getPointerTo()));
+        TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
+            *reducers[i],
+            {partial_reduction_result_addresses[i], result_from_other_lane},
+            partial_reduction_result_addresses[i]));
+      }
     }
 
     const HloInstruction* output =
@@ -1070,14 +1121,25 @@ Status IrEmitterUnnested::EmitReductionToScalar(
         "lane_id_is_zero", &ir_builder_);
     llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block,
                                    &ir_builder_);
-    llvm::Value* output_address =
-        GetIrArray(*output, *output)
-            .EmitArrayElementAddress(
-                llvm_ir::IrArray::Index(/*linear=*/ir_builder_.getInt64(0),
-                                        output->shape(), &ir_builder_),
-                &ir_builder_, "output_element_address");
-    return EmitAtomicOperationForNestedComputation(
-        *reducer, output_address, partial_reduction_result_address);
+
+    for (int i = 0; i != num_reduces; ++i) {
+      ShapeIndex output_shape_index;
+      if (output->IsMultiOutputFusion()) {
+        output_shape_index = {i};
+      }
+      llvm::Value* output_address =
+          GetIrArray(*output, *output, output_shape_index)
+              .EmitArrayElementAddress(
+                  llvm_ir::IrArray::Index(
+                      /*linear=*/ir_builder_.getInt64(0),
+                      ShapeUtil::GetSubshape(output->shape(),
+                                             output_shape_index),
+                      &ir_builder_),
+                  &ir_builder_, "output_element_address");
+      TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation(
+          *reducers[i], output_address, partial_reduction_result_addresses[i]));
+    }
+    return Status::OK();
   };
 
   // Emit a parallel loop that iterates through all input tiles, one per thread.
@@ -1097,8 +1159,9 @@ Status IrEmitterUnnested::EmitReductionToScalar(
 
 Status IrEmitterUnnested::EmitColumnReduction(
     int64 height, int64 width, HloInstruction* reduce, const Shape& input_shape,
-    const llvm_ir::ElementGenerator& input_gen,
-    const llvm_ir::ElementGenerator& init_value_gen, HloComputation* reducer) {
+    tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
+    tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
+    tensorflow::gtl::ArraySlice<HloComputation*> reducers) {
   // Divide the input matrix into tiles of size Kx1. For example, when the
   // input matrix is 4x4 and K=2, the tiled matrix looks like
   //
@@ -1140,15 +1203,20 @@ Status IrEmitterUnnested::EmitColumnReduction(
   // }
   auto loop_body_emitter =
       [=](const llvm_ir::IrArray::Index& tile_index) -> Status {
+    const int num_reduces = reducers.size();
     // Emit the loop body that reduces one tile.
     llvm::Type* element_ir_type =
         llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_);
-    llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca(
-        element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result");
-    {
-      TF_ASSIGN_OR_RETURN(llvm::Value * init_ir_value,
-                          init_value_gen(llvm_ir::IrArray::Index({})));
+    std::vector<llvm::Value*> partial_reduction_result_addresses;
+    for (int i = 0; i != num_reduces; ++i) {
+      llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca(
+          element_ir_type, /*ArraySize=*/nullptr,
+          "partial_reduction_result." + llvm::Twine(i));
+      TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value,
+                          init_value_gens[i](llvm_ir::IrArray::Index({})));
       ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address);
+      partial_reduction_result_addresses.push_back(
+          partial_reduction_result_address);
     }
 
     // Emit an inner for-loop that partially reduces the elements in the given
@@ -1206,13 +1274,17 @@ Status IrEmitterUnnested::EmitColumnReduction(
                 .SourceIndexOfTranspose(normalized_input_shape, input_shape,
                                         transpose_dimension_mapping,
                                         &ir_builder_);
-        TF_ASSIGN_OR_RETURN(llvm::Value * input_ir_value,
-                            input_gen(input_index));
-        ir_builder_.CreateStore(input_ir_value, input_address);
+        for (int i = 0; i != num_reduces; ++i) {
+          TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value,
+                              input_gens[i](input_index));
+          ir_builder_.CreateStore(input_ir_value, input_address);
+          TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
+              *reducers[i],
+              {partial_reduction_result_addresses[i], input_address},
+              partial_reduction_result_addresses[i]));
+        }
+        return Status::OK();
       }
-      return (EmitCallToNestedComputation(
-          *reducer, {partial_reduction_result_address, input_address},
-          partial_reduction_result_address));
     };
 
     // y_end = kTileSize + y_in_tiles * kTileSize, i.e., the y location that's
@@ -1241,13 +1313,24 @@ Status IrEmitterUnnested::EmitColumnReduction(
                                    &ir_builder_);
     const HloInstruction* output =
         reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce;
-    llvm::Value* output_address =
-        GetIrArray(*output, *output)
-            .EmitArrayElementAddress(
-                llvm_ir::IrArray::Index(x, output->shape(), &ir_builder_),
-                &ir_builder_, "output_element_address");
-    return EmitAtomicOperationForNestedComputation(
-        *reducer, output_address, partial_reduction_result_address);
+    for (int i = 0; i != num_reduces; ++i) {
+      ShapeIndex output_shape_index;
+      if (output->IsMultiOutputFusion()) {
+        output_shape_index = {i};
+      }
+      llvm::Value* output_address =
+          GetIrArray(*output, *output, output_shape_index)
+              .EmitArrayElementAddress(
+                  llvm_ir::IrArray::Index(
+                      x,
+                      ShapeUtil::GetSubshape(output->shape(),
+                                             output_shape_index),
+                      &ir_builder_),
+                  &ir_builder_, "output_element_address");
+      TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation(
+          *reducers[i], output_address, partial_reduction_result_addresses[i]));
+    }
+    return Status::OK();
   };
 
   // Emit a parallel loop that iterate through all input tiles.
@@ -1267,8 +1350,10 @@ Status IrEmitterUnnested::EmitColumnReduction(
 
 Status IrEmitterUnnested::EmitRowReduction(
     int64 depth, int64 height, int64 width, HloInstruction* reduce,
-    const Shape& input_shape, const llvm_ir::ElementGenerator& input_gen,
-    const llvm_ir::ElementGenerator& init_value_gen, HloComputation* reducer) {
+    const Shape& input_shape,
+    tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
+    tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
+    tensorflow::gtl::ArraySlice<HloComputation*> reducers) {
   // A naive algorithm is:
   // 1. Divide the input tensor into tiles of size 1x1xK.
   // 2. Partially reduces each tile to a scalar using one thread.
@@ -1358,15 +1443,20 @@ Status IrEmitterUnnested::EmitRowReduction(
 
   auto loop_body_emitter =
       [=](const llvm_ir::IrArray::Index& tile_index) -> Status {
+    const int num_reduces = reducers.size();
     // Emit the loop body that reduces one tile.
     llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType(
         input_shape.element_type(), ir_emitter_context_->llvm_module());
-    llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca(
-        element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result");
-    {
-      TF_ASSIGN_OR_RETURN(llvm::Value * init_ir_value,
-                          init_value_gen(llvm_ir::IrArray::Index({})));
+    std::vector<llvm::Value*> partial_reduction_result_addresses;
+    for (int i = 0; i != num_reduces; ++i) {
+      llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca(
+          element_ir_type, /*ArraySize=*/nullptr,
+          "partial_reduction_result." + llvm::Twine(i));
+      TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value,
+                          init_value_gens[i](llvm_ir::IrArray::Index({})));
       ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address);
+      partial_reduction_result_addresses.push_back(
+          partial_reduction_result_address);
     }
 
     // Emit an inner for-loop that partially reduces the elements in the given
@@ -1449,13 +1539,17 @@ Status IrEmitterUnnested::EmitRowReduction(
                 .SourceIndexOfTranspose(normalized_input_shape, input_shape,
                                         transpose_dimension_mapping,
                                         &ir_builder_);
-        TF_ASSIGN_OR_RETURN(llvm::Value * input_ir_value,
-                            input_gen(input_index));
-        ir_builder_.CreateStore(input_ir_value, input_address);
+        for (int i = 0; i != num_reduces; ++i) {
+          TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value,
+                              input_gens[i](input_index));
+          ir_builder_.CreateStore(input_ir_value, input_address);
+          TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
+              *reducers[i],
+              {partial_reduction_result_addresses[i], input_address},
+              partial_reduction_result_addresses[i]));
+        }
+        return Status::OK();
       }
-      return EmitCallToNestedComputation(
-          *reducer, {partial_reduction_result_address, input_address},
-          partial_reduction_result_address);
     };
 
     llvm::Value* tile_in_bounds = ir_builder_.CreateOr(
@@ -1483,20 +1577,24 @@ Status IrEmitterUnnested::EmitRowReduction(
                                       : element_ir_type;
     for (int shuffle_distance = 16; shuffle_distance >= 1;
          shuffle_distance /= 2) {
-      llvm::Value* partial_reduction_result = ir_builder_.CreateLoad(
-          ir_builder_.CreateBitCast(partial_reduction_result_address,
-                                    shuffle_ir_type->getPointerTo()),
-          "partial_reduction_result");
       llvm::Value* result_from_other_lane = ir_builder_.CreateAlloca(
           element_ir_type, nullptr, "result_from_other_lane");
-      ir_builder_.CreateStore(
-          EmitShuffleDown(partial_reduction_result,
-                          ir_builder_.getInt32(shuffle_distance), &ir_builder_),
-          ir_builder_.CreateBitCast(result_from_other_lane,
-                                    shuffle_ir_type->getPointerTo()));
-      TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
-          *reducer, {partial_reduction_result_address, result_from_other_lane},
-          partial_reduction_result_address));
+      for (int i = 0; i != num_reduces; ++i) {
+        llvm::Value* partial_reduction_result = ir_builder_.CreateLoad(
+            ir_builder_.CreateBitCast(partial_reduction_result_addresses[i],
+                                      shuffle_ir_type->getPointerTo()),
+            "partial_reduction_result");
+        ir_builder_.CreateStore(
+            EmitShuffleDown(partial_reduction_result,
+                            ir_builder_.getInt32(shuffle_distance),
+                            &ir_builder_),
+            ir_builder_.CreateBitCast(result_from_other_lane,
+                                      shuffle_ir_type->getPointerTo()));
+        TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
+            *reducers[i],
+            {partial_reduction_result_addresses[i], result_from_other_lane},
+            partial_reduction_result_addresses[i]));
+      }
     }
 
     const HloInstruction* output =
@@ -1510,13 +1608,24 @@ Status IrEmitterUnnested::EmitRowReduction(
         "lane_id_is_zero", &ir_builder_);
     llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block,
                                    &ir_builder_);
-    llvm::Value* output_address =
-        GetIrArray(*output, *output)
-            .EmitArrayElementAddress(
-                llvm_ir::IrArray::Index(y, output->shape(), &ir_builder_),
-                &ir_builder_, "output_element_address");
-    return EmitAtomicOperationForNestedComputation(
-        *reducer, output_address, partial_reduction_result_address);
+    for (int i = 0; i != num_reduces; ++i) {
+      ShapeIndex output_shape_index;
+      if (output->IsMultiOutputFusion()) {
+        output_shape_index = {i};
+      }
+      llvm::Value* output_address =
+          GetIrArray(*output, *output, output_shape_index)
+              .EmitArrayElementAddress(
+                  llvm_ir::IrArray::Index(
+                      y,
+                      ShapeUtil::GetSubshape(output->shape(),
+                                             output_shape_index),
+                      &ir_builder_),
+                  &ir_builder_, "output_element_address");
+      TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation(
+          *reducers[i], output_address, partial_reduction_result_addresses[i]));
+    }
+    return Status::OK();
   };
 
   // Emit a parallel loop that iterates through every input tiles.
@@ -1543,10 +1652,10 @@ Status IrEmitterUnnested::EmitRowReduction(
 //               elementwise.
 Status IrEmitterUnnested::EmitReductionToVector(
     HloInstruction* reduce, const Shape& input_shape,
-    const llvm_ir::ElementGenerator& input_gen,
-    const llvm_ir::ElementGenerator& init_value_gen,
+    tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
+    tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
     tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
-    HloComputation* reducer) {
+    tensorflow::gtl::ArraySlice<HloComputation*> reducers) {
   // This emission requires "reduce" to have an input layout. It is either set
   // by LayoutAssignment (for a top-level kReduce) or by InstructionFusion (for
   // a fused kReduce).
@@ -1581,8 +1690,8 @@ Status IrEmitterUnnested::EmitReductionToVector(
   // `EmitReductionToVector`, we only need to check whether the minormost
   // dimension of the input is to keep.
   if (input_dims_to_keep.empty()) {
-    return EmitReductionToScalar(reduce, input_shape, input_gen, init_value_gen,
-                                 reducer);
+    return EmitReductionToScalar(reduce, input_shape, input_gens,
+                                 init_value_gens, reducers);
   } else if (input_dims_to_keep.front() ==
              LayoutUtil::Minor(input_shape.layout(), 0)) {
     // Column reduction. Treat the result of "input" as a matrix whose width
@@ -1599,8 +1708,8 @@ Status IrEmitterUnnested::EmitReductionToVector(
         height *= input_shape.dimensions(input_dim);
       }
     }
-    return EmitColumnReduction(height, width, reduce, input_shape, input_gen,
-                               init_value_gen, reducer);
+    return EmitColumnReduction(height, width, reduce, input_shape, input_gens,
+                               init_value_gens, reducers);
   } else {
     // Reduce the row dimension of a matrix or reduce dimension 0 and 2 in a
     // 3D tensor. The size of dimension 1 (the height) is the size of the
@@ -1626,7 +1735,7 @@ Status IrEmitterUnnested::EmitReductionToVector(
     }
     const int64 height = ShapeUtil::ElementsIn(reduce->shape());
     return EmitRowReduction(depth, height, width, reduce, input_shape,
-                            input_gen, init_value_gen, reducer);
+                            input_gens, init_value_gens, reducers);
   }
 }
 
@@ -1650,16 +1759,15 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
         MakeUnique<SequentialThunk>(std::move(thunks), reduce));
 
     return EmitReductionToVector(
-        reduce, input->shape(),
-        [&](const llvm_ir::IrArray::Index& index) {
+        reduce, input->shape(), {[&](const llvm_ir::IrArray::Index& index) {
           return GetIrArray(*input, *reduce)
               .EmitReadArrayElement(index, &ir_builder_);
-        },
-        [&](const llvm_ir::IrArray::Index& index) {
+        }},
+        {[&](const llvm_ir::IrArray::Index& index) {
           return GetIrArray(*init_value, *reduce)
               .EmitReadArrayElement(index, &ir_builder_);
-        },
-        dimensions_to_reduce, reducer);
+        }},
+        dimensions_to_reduce, {reducer});
   }
 
   thunk_sequence_->emplace_back(BuildKernelThunk(reduce));
@@ -2324,7 +2432,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildFftThunk(
 }
 
 StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
-    const HloInstruction* hlo) {
+    const HloInstruction* hlo, const ShapeIndex& index) {
   bool fused = HloOpcode::kFusion == hlo->opcode();
   const HloInstruction* inst = fused ? hlo->fused_expression_root() : hlo;
   const HloInstruction* init_value = [&] {
@@ -2333,6 +2441,11 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
         return inst->operand(2);
       case HloOpcode::kReduce:
         return inst->operand(1);
+      case HloOpcode::kTuple:
+        CHECK(hlo->IsMultiOutputFusion() &&
+              inst->operand(index.back())->opcode() == HloOpcode::kReduce);
+        // For multi-output fusion look through the tuple.
+        return inst->operand(index.back())->operand(1);
       default:
         LOG(FATAL) << "Opcode " << inst->opcode()
                    << " should not need an initializer.";
@@ -2356,7 +2469,7 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
     ArraySlice<uint8> literal_bytes(
         reinterpret_cast<const uint8*>(literal.untyped_data()), num_bytes);
     if (c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) {
-      return {MakeUnique<MemzeroThunk>(GetAllocationSlice(*hlo), hlo)};
+      return {MakeUnique<MemzeroThunk>(GetAllocationSlice(*hlo, index), hlo)};
     }
 
     // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by
@@ -2372,8 +2485,8 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
         pattern16 = literal_bytes.front();
       }
       uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16);
-      return {MakeUnique<Memset32BitValueThunk>(pattern32,
-                                                GetAllocationSlice(*hlo), hlo)};
+      return {MakeUnique<Memset32BitValueThunk>(
+          pattern32, GetAllocationSlice(*hlo, index), hlo)};
     }
 
     // If the literal is an even multiple of 32 bits wide, we can emit a 32-bit
@@ -2383,8 +2496,8 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
                literal_bytes.size() - 4) == 0) {
       uint32 word;
       memcpy(&word, literal_bytes.data(), sizeof(word));
-      return {MakeUnique<Memset32BitValueThunk>(word, GetAllocationSlice(*hlo),
-                                                hlo)};
+      return {MakeUnique<Memset32BitValueThunk>(
+          word, GetAllocationSlice(*hlo, index), hlo)};
     }
   }
 
index 14780de..a1d4dca 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -110,28 +110,31 @@ class IrEmitterUnnested : public IrEmitter {
   // `EmitReductionToVector`. Note that input shape might not be
   // [height x width], but can be bitcast to [height x weight] with "height"
   // being the major dimension.
-  Status EmitColumnReduction(int64 height, int64 width, HloInstruction* reduce,
-                             const Shape& input_shape,
-                             const llvm_ir::ElementGenerator& input_gen,
-                             const llvm_ir::ElementGenerator& init_value_gen,
-                             HloComputation* reducer);
+  Status EmitColumnReduction(
+      int64 height, int64 width, HloInstruction* reduce,
+      const Shape& input_shape,
+      tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
+      tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
+      tensorflow::gtl::ArraySlice<HloComputation*> reducers);
 
   // Emits code that reduces a 3D tensor of shape [depth x height x width] to a
   // vector of shape [height]. Other parameters have the same meaning as those
   // of `EmitReductionToVector`. Note that input shape might not be
   // [depth x height x width], but can be bitcast to [depth x height x weight]
   // with "depth" being the most major dimension.
-  Status EmitRowReduction(int64 depth, int64 height, int64 width,
-                          HloInstruction* reduce, const Shape& input_shape,
-                          const llvm_ir::ElementGenerator& input_gen,
-                          const llvm_ir::ElementGenerator& init_value_gen,
-                          HloComputation* reducer);
+  Status EmitRowReduction(
+      int64 depth, int64 height, int64 width, HloInstruction* reduce,
+      const Shape& input_shape,
+      tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
+      tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
+      tensorflow::gtl::ArraySlice<HloComputation*> reducers);
 
   // Emits code that reduces a tensor of arbitrary rank to a scalar.
-  Status EmitReductionToScalar(HloInstruction* reduce, const Shape& input_shape,
-                               const llvm_ir::ElementGenerator& input_gen,
-                               const llvm_ir::ElementGenerator& init_value_gen,
-                               HloComputation* reducer);
+  Status EmitReductionToScalar(
+      HloInstruction* reduce, const Shape& input_shape,
+      tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
+      tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
+      tensorflow::gtl::ArraySlice<HloComputation*> reducers);
 
   // Figures out whether `reduce` is a row or column reduction, and which
   // dimensions to reduce, and calls either `EmitRowReduction` or
@@ -141,13 +144,16 @@ class IrEmitterUnnested : public IrEmitter {
   // generate elements of the input and the initial value. Other parameters mean
   // the same as for `HandleReduce`.
   //
+  // Multiple reduces can be emitted in the same loop, assuming they have the
+  // same input and output shapes, and the same reduce dimensions.
+  //
   // Prerequisite: `IsReductionToVector(*reduce)`
   Status EmitReductionToVector(
       HloInstruction* reduce, const Shape& input_shape,
-      const llvm_ir::ElementGenerator& input_gen,
-      const llvm_ir::ElementGenerator& init_value_gen,
+      tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
+      tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
       tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
-      HloComputation* reducer);
+      tensorflow::gtl::ArraySlice<HloComputation*> reducers);
 
   // Returns a KernelThunk that invokes the kernel emitted for `inst`. The
   // caller needs to make sure `inst` outlives the lifetime of the returned
@@ -166,7 +172,7 @@ class IrEmitterUnnested : public IrEmitter {
   // Returns a thunk that, given a reduce or select-and-scatter op, initializes
   // its memory to the appropriate initial value.
   StatusOr<std::unique_ptr<Thunk>> BuildInitializerThunk(
-      const HloInstruction* hlo);
+      const HloInstruction* hlo, const ShapeIndex& index = {});
 
   // Returns a thunk that calls host-to-device cuMemcpy to implement `inst`.
   std::unique_ptr<Thunk> BuildHostToDeviceCopyThunk(const HloInstruction* inst);
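
A condensed sketch of how a caller assembles the per-reduce arguments for these
signatures, following the HandleFusion change above (variable names mirror that
code):

    InlinedVector<llvm_ir::ElementGenerator, 1> input_gens;
    InlinedVector<llvm_ir::ElementGenerator, 1> init_value_gens;
    InlinedVector<HloComputation*, 1> reducers;
    for (const HloInstruction* reduce : reduces) {
      input_gens.push_back(fused_emitter.GetGenerator(reduce->operand(0)));
      init_value_gens.push_back(fused_emitter.GetGenerator(reduce->operand(1)));
      reducers.push_back(reduce->to_apply());
    }
    return EmitReductionToVector(reduces[0], reduces[0]->operand(0)->shape(),
                                 input_gens, init_value_gens,
                                 reduces[0]->dimensions(), reducers);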
index ec7ca20..3cbb245 100644
--- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
@@ -273,5 +273,112 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) {
       *result, *Literal::CreateR1<float>({0.0, 4.0, 9.0})));
 }
 
+const char* const kScalarOps = R"(
+    HloModule m
+
+    Add {
+      lhsadd = f32[] parameter(0)
+      rhsadd = f32[] parameter(1)
+      ROOT add = f32[] add(lhsadd, rhsadd)
+    }
+
+    Max {
+      lhsmax = f32[] parameter(0)
+      rhsmax = f32[] parameter(1)
+      ROOT max = f32[] maximum(lhsmax, rhsmax)
+    }
+)";
+
+XLA_TEST_F(MultiOutputFusionTest,
+           DISABLED_ON_CPU(MultiOutputReduceFusionMinor)) {
+  const string testcase = tensorflow::strings::StrCat(kScalarOps, R"(
+    fused_reduce {
+      p0 = f32[2,2,2]{2,1,0} parameter(0)
+      c0 = f32[] constant(0)
+      r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={2}, to_apply=Add
+      mul = f32[2,2,2]{2,1,0} multiply(p0, p0)
+      c1 = f32[] constant(5)
+      r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=Max
+      ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2)
+    }
+
+    ENTRY reduce {
+      p = f32[2,2,2]{2,1,0} parameter(0)
+      ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput,
+                                                        calls=fused_reduce
+    })");
+  auto module =
+      HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
+          .ValueOrDie();
+  auto param = Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
+  TF_ASSERT_OK_AND_ASSIGN(auto result,
+                          Execute(std::move(module), {param.get()}));
+  EXPECT_TRUE(LiteralTestUtil::Equal(
+      *result,
+      *Literal::MakeTupleOwned(Literal::CreateR2<float>({{3, 7}, {11, 15}}),
+                               Literal::CreateR2<float>({{5, 16}, {36, 64}}))));
+}
+
+XLA_TEST_F(MultiOutputFusionTest,
+           DISABLED_ON_CPU(MultiOutputReduceFusionMajor)) {
+  const string testcase = tensorflow::strings::StrCat(kScalarOps, R"(
+    fused_reduce {
+      p0 = f32[2,2,2]{2,1,0} parameter(0)
+      c0 = f32[] constant(0)
+      r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={0}, to_apply=Add
+      mul = f32[2,2,2]{2,1,0} multiply(p0, p0)
+      c1 = f32[] constant(5)
+      r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={0}, to_apply=Max
+      ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2)
+    }
+
+    ENTRY reduce {
+      p = f32[2,2,2]{2,1,0} parameter(0)
+      ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput,
+                                                        calls=fused_reduce
+    })");
+  auto module =
+      HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
+          .ValueOrDie();
+  auto param = Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
+  TF_ASSERT_OK_AND_ASSIGN(auto result,
+                          Execute(std::move(module), {param.get()}));
+  EXPECT_TRUE(LiteralTestUtil::Equal(
+      *result, *Literal::MakeTupleOwned(
+                   Literal::CreateR2<float>({{6, 8}, {10, 12}}),
+                   Literal::CreateR2<float>({{25, 36}, {49, 64}}))));
+}
+
+XLA_TEST_F(MultiOutputFusionTest,
+           DISABLED_ON_CPU(MultiOutputReduceFusionScalar)) {
+  const string testcase = tensorflow::strings::StrCat(kScalarOps, R"(
+    fused_reduce {
+      p0 = f32[2,2,2]{2,1,0} parameter(0)
+      c0 = f32[] constant(0)
+      r1 = f32[2]{0} reduce(p0, c0), dimensions={0,2}, to_apply=Add
+      mul = f32[2,2,2]{2,1,0} multiply(p0, p0)
+      c1 = f32[] constant(5)
+      r2 = f32[2]{0} reduce(mul, c1), dimensions={0,2}, to_apply=Max
+      r3 = f32[2]{0} reduce(mul, c1), dimensions={0,2}, to_apply=Add
+      ROOT tuple = (f32[2]{0}, f32[2]{0}, f32[2]{0}) tuple(r1, r2, r3)
+    }
+
+    ENTRY reduce {
+      p = f32[2,2,2]{2,1,0} parameter(0)
+      ROOT fusion = (f32[2]{0}, f32[2]{0}, f32[2]{0}) fusion(p), kind=kInput,
+                                                        calls=fused_reduce
+    })");
+  auto module =
+      HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
+          .ValueOrDie();
+  auto param = Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
+  TF_ASSERT_OK_AND_ASSIGN(auto result,
+                          Execute(std::move(module), {param.get()}));
+  EXPECT_TRUE(LiteralTestUtil::Equal(
+      *result, *Literal::MakeTupleOwned(Literal::CreateR1<float>({14, 22}),
+                                        Literal::CreateR1<float>({36, 64}),
+                                        Literal::CreateR1<float>({391, 463}))));
+}
+
 }  // namespace
 }  // namespace xla