Merge consecutive broadcast HLO instructions.
authorDimitris Vardoulakis <dimvar@google.com>
Thu, 22 Mar 2018 21:38:41 +0000 (14:38 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 22 Mar 2018 21:41:16 +0000 (14:41 -0700)
As an optimization, replace consecutive broadcast instructions with a single equivalent broadcast in algebraic simplification.

PiperOrigin-RevId: 190127730

tensorflow/compiler/xla/service/algebraic_simplifier.cc
tensorflow/compiler/xla/service/algebraic_simplifier_test.cc

index 971c293..88f6ff0 100644 (file)
@@ -1121,10 +1121,10 @@ bool OutputIsSubsetOfOperandElements(HloInstruction* instruction,
 
 Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
   auto operand = broadcast->mutable_operand(0);
+  auto dims = broadcast->dimensions();
   // A degenerate broadcast of a reshape that does not change the number of
   // elements can be replaced by a reshape.
-  if (std::is_sorted(broadcast->dimensions().begin(),
-                     broadcast->dimensions().end()) &&
+  if (std::is_sorted(dims.begin(), dims.end()) &&
       ShapeUtil::ElementsIn(broadcast->shape()) ==
           ShapeUtil::ElementsIn(operand->shape())) {
     VLOG(10) << "transform broadcast(X) -> reshape(X) where "
@@ -1142,8 +1142,8 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
     VLOG(10) << "transform broadcast(X) -> transpose(X) where "
                 "n(broadcast(X)) == n(X)";
     return ReplaceWithNewInstruction(
-        broadcast, HloInstruction::CreateTranspose(broadcast->shape(), operand,
-                                                   broadcast->dimensions()));
+        broadcast,
+        HloInstruction::CreateTranspose(broadcast->shape(), operand, dims));
   }
 
   // A broadcast of a reshape which merely inserts 1-sized dimensions can
@@ -1157,7 +1157,6 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
     if (merely_inserts_or_deletes_1_sized_dimensions &&
         deleted_indices.empty()) {
       std::reverse(inserted_indices.begin(), inserted_indices.end());
-      auto dims = broadcast->dimensions();
       for (auto inserted_index : inserted_indices) {
         dims.erase(dims.begin() + inserted_index);
       }
@@ -1201,6 +1200,19 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
         return user->ReplaceAllUsesWith(new_broadcast);
       }
     }
+    return Status::OK();
+  }
+
+  // Merge two consecutive broadcasts into a single one.
+  if (operand->opcode() == HloOpcode::kBroadcast) {
+    std::vector<int64> new_dimensions(operand->dimensions().size());
+    for (auto dim : operand->dimensions()) {
+      new_dimensions.push_back(dims[dim]);
+    }
+    return ReplaceWithNewInstruction(
+        broadcast,
+        HloInstruction::CreateBroadcast(
+            broadcast->shape(), operand->mutable_operand(0), new_dimensions));
   }
   return Status::OK();
 }
index 451294e..3b80a82 100644 (file)
@@ -35,6 +35,8 @@ limitations under the License.
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 
+using ::testing::ElementsAre;
+
 namespace xla {
 namespace {
 
@@ -2462,6 +2464,55 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) {
               op::DynamicSlice(op::Parameter(), op::Parameter()));
 }
 
+// Test that two consecutive broadcasts can be merged to one.
+TEST_F(AlgebraicSimplifierTest, MergeBroadcasts) {
+  HloComputation::Builder builder(TestName());
+  Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
+  HloInstruction* input_array = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR1<float>({3, 4})));
+  HloInstruction* inner_bcast = builder.AddInstruction(
+      HloInstruction::CreateBroadcast(r2f32, input_array, {1}));
+  Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 2, 2});
+  builder.AddInstruction(
+      HloInstruction::CreateBroadcast(r3f32, inner_bcast, {0, 2}));
+
+  auto computation = module().AddEntryComputation(builder.Build());
+  HloInstruction* root = computation->root_instruction();
+  EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
+  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+                                 non_bitcasting_callback());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+  root = computation->root_instruction();
+  EXPECT_THAT(root, op::Broadcast(op::Constant()));
+  EXPECT_THAT(root->dimensions(), ElementsAre(2));
+}
+
+// Test that two consecutive broadcasts can be merged to one.
+TEST_F(AlgebraicSimplifierTest, MergeBroadcasts2) {
+  HloComputation::Builder builder(TestName());
+  Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 3});
+  Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 5, 3});
+  HloInstruction* param0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, r2f32, "param0"));
+  // The initial dimensions go to places 0 and 2 in the 3-dim array,
+  // and to places 1 and 3 in the 4-dim array,
+  HloInstruction* inner_bcast = builder.AddInstruction(
+      HloInstruction::CreateBroadcast(r3f32, param0, {0, 2}));
+  Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 2, 5, 3});
+  builder.AddInstruction(
+      HloInstruction::CreateBroadcast(r4f32, inner_bcast, {1, 2, 3}));
+
+  auto computation = module().AddEntryComputation(builder.Build());
+  HloInstruction* root = computation->root_instruction();
+  EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
+  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+                                 non_bitcasting_callback());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+  root = computation->root_instruction();
+  EXPECT_THAT(root, op::Broadcast(op::Parameter(0)));
+  EXPECT_THAT(root->dimensions(), ElementsAre(1, 3));
+}
+
 struct PadReduceWindowEffectiveBroadcastCase {
   std::vector<int64> input_spatials;
   std::vector<int64> symmetric_pad_spatials;