Automated g4 rollback of changelist 190127730
authorMark Heffernan <meheff@google.com>
Thu, 22 Mar 2018 22:49:57 +0000 (15:49 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 22 Mar 2018 22:52:47 +0000 (15:52 -0700)
PiperOrigin-RevId: 190139303

tensorflow/compiler/xla/service/algebraic_simplifier.cc
tensorflow/compiler/xla/service/algebraic_simplifier_test.cc

index 88f6ff0..971c293 100644 (file)
@@ -1121,10 +1121,10 @@ bool OutputIsSubsetOfOperandElements(HloInstruction* instruction,
 
 Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
   auto operand = broadcast->mutable_operand(0);
-  auto dims = broadcast->dimensions();
   // A degenerate broadcast of a reshape that does not change the number of
   // elements can be replaced by a reshape.
-  if (std::is_sorted(dims.begin(), dims.end()) &&
+  if (std::is_sorted(broadcast->dimensions().begin(),
+                     broadcast->dimensions().end()) &&
       ShapeUtil::ElementsIn(broadcast->shape()) ==
           ShapeUtil::ElementsIn(operand->shape())) {
     VLOG(10) << "transform broadcast(X) -> reshape(X) where "
@@ -1142,8 +1142,8 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
     VLOG(10) << "transform broadcast(X) -> transpose(X) where "
                 "n(broadcast(X)) == n(X)";
     return ReplaceWithNewInstruction(
-        broadcast,
-        HloInstruction::CreateTranspose(broadcast->shape(), operand, dims));
+        broadcast, HloInstruction::CreateTranspose(broadcast->shape(), operand,
+                                                   broadcast->dimensions()));
   }
 
   // A broadcast of a reshape which merely inserts 1-sized dimensions can
@@ -1157,6 +1157,7 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
     if (merely_inserts_or_deletes_1_sized_dimensions &&
         deleted_indices.empty()) {
       std::reverse(inserted_indices.begin(), inserted_indices.end());
+      auto dims = broadcast->dimensions();
       for (auto inserted_index : inserted_indices) {
         dims.erase(dims.begin() + inserted_index);
       }
@@ -1200,19 +1201,6 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
         return user->ReplaceAllUsesWith(new_broadcast);
       }
     }
-    return Status::OK();
-  }
-
-  // Merge two consecutive broadcasts into a single one.
-  if (operand->opcode() == HloOpcode::kBroadcast) {
-    std::vector<int64> new_dimensions(operand->dimensions().size());
-    for (auto dim : operand->dimensions()) {
-      new_dimensions.push_back(dims[dim]);
-    }
-    return ReplaceWithNewInstruction(
-        broadcast,
-        HloInstruction::CreateBroadcast(
-            broadcast->shape(), operand->mutable_operand(0), new_dimensions));
   }
   return Status::OK();
 }
index 3b80a82..451294e 100644 (file)
@@ -35,8 +35,6 @@ limitations under the License.
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 
-using ::testing::ElementsAre;
-
 namespace xla {
 namespace {
 
@@ -2464,55 +2462,6 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) {
               op::DynamicSlice(op::Parameter(), op::Parameter()));
 }
 
-// Test that two consecutive broadcasts can be merged to one.
-TEST_F(AlgebraicSimplifierTest, MergeBroadcasts) {
-  HloComputation::Builder builder(TestName());
-  Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
-  HloInstruction* input_array = builder.AddInstruction(
-      HloInstruction::CreateConstant(Literal::CreateR1<float>({3, 4})));
-  HloInstruction* inner_bcast = builder.AddInstruction(
-      HloInstruction::CreateBroadcast(r2f32, input_array, {1}));
-  Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 2, 2});
-  builder.AddInstruction(
-      HloInstruction::CreateBroadcast(r3f32, inner_bcast, {0, 2}));
-
-  auto computation = module().AddEntryComputation(builder.Build());
-  HloInstruction* root = computation->root_instruction();
-  EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
-  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
-                                 non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
-  root = computation->root_instruction();
-  EXPECT_THAT(root, op::Broadcast(op::Constant()));
-  EXPECT_THAT(root->dimensions(), ElementsAre(2));
-}
-
-// Test that two consecutive broadcasts can be merged to one.
-TEST_F(AlgebraicSimplifierTest, MergeBroadcasts2) {
-  HloComputation::Builder builder(TestName());
-  Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 3});
-  Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 5, 3});
-  HloInstruction* param0 = builder.AddInstruction(
-      HloInstruction::CreateParameter(0, r2f32, "param0"));
-  // The initial dimensions go to places 0 and 2 in the 3-dim array,
-  // and to places 1 and 3 in the 4-dim array,
-  HloInstruction* inner_bcast = builder.AddInstruction(
-      HloInstruction::CreateBroadcast(r3f32, param0, {0, 2}));
-  Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 2, 5, 3});
-  builder.AddInstruction(
-      HloInstruction::CreateBroadcast(r4f32, inner_bcast, {1, 2, 3}));
-
-  auto computation = module().AddEntryComputation(builder.Build());
-  HloInstruction* root = computation->root_instruction();
-  EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
-  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
-                                 non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
-  root = computation->root_instruction();
-  EXPECT_THAT(root, op::Broadcast(op::Parameter(0)));
-  EXPECT_THAT(root->dimensions(), ElementsAre(1, 3));
-}
-
 struct PadReduceWindowEffectiveBroadcastCase {
   std::vector<int64> input_spatials;
   std::vector<int64> symmetric_pad_spatials;