[XLA] Fold reduce-window(convert(pad(X))) into reduce-window(convert(X))
authorDavid Majnemer <majnemer@google.com>
Tue, 27 Mar 2018 22:37:35 +0000 (15:37 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 27 Mar 2018 22:40:36 +0000 (15:40 -0700)
ReduceWindow operations are done in higher precision to avoid accumulation
error. Convert operations can find their way between a ReduceWindow and a Pad
which can prevent a Pad from combining with a ReduceWindow.

Fix this by looking past the Convert while also checking that the Convert'd
Pad's init value is identical to the reduce-window value.

PiperOrigin-RevId: 190686175

tensorflow/compiler/xla/service/algebraic_simplifier.cc
tensorflow/compiler/xla/service/algebraic_simplifier_test.cc

index f9fabd8..0e4624f 100644 (file)
@@ -1731,18 +1731,29 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
                                   function));
   }
 
-  VLOG(10) << "Considering folding Pad: " << operand->ToString()
-           << "\ninto reduce-window: " << reduce_window->ToString();
-
   // This optimization folds a pad op into reduce_window.
-  if (operand->opcode() != HloOpcode::kPad) {
+  HloInstruction* pad;
+  const HloInstruction* convert = nullptr;
+  if (operand->opcode() == HloOpcode::kPad) {
+    pad = operand;
+  } else if (operand->opcode() == HloOpcode::kConvert &&
+             operand->operand(0)->opcode() == HloOpcode::kPad) {
+    convert = operand;
+    pad = operand->mutable_operand(0);
+  } else {
     VLOG(10) << "Not folding pad into reduce-window as there is no pad.";
     return Status::OK();
   }
 
+  VLOG(10) << "Considering folding Pad: " << pad->ToString()
+           << "\ninto reduce-window: " << reduce_window->ToString()
+           << (convert != nullptr ? tensorflow::strings::StrCat(
+                                        "\nvia convert: ", convert->ToString())
+                                  : "");
+
   // Do not fold interior padding into ReduceWindow since the backends do not
   // support it.
-  const PaddingConfig& pad_config = operand->padding_config();
+  const PaddingConfig& pad_config = pad->padding_config();
   if (HasInteriorPadding(pad_config)) {
     VLOG(10) << "Not folding pad into reduce-window due to interior padding.";
     return Status::OK();
@@ -1750,14 +1761,27 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
 
   // If reduce_window already has padding, the pad value of the pad op and the
   // init value of reduce_window must match to allow folding the pad.
-  const HloInstruction* pad_value = operand->operand(1);
+  const HloInstruction* pad_value = pad->operand(1);
   const HloInstruction* reduce_init_value = reduce_window->operand(1);
   if (pad_value != reduce_init_value) {
+    auto literals_are_equivalent = [&] {
+      auto& pad_literal = pad_value->literal();
+      auto& reduce_init_literal = reduce_init_value->literal();
+      if (pad_literal == reduce_init_literal) {
+        return true;
+      }
+      auto converted_pad_literal = pad_literal.ConvertToShape(
+          reduce_init_value->shape(), /*round_f32_to_bf16=*/true);
+      if (!converted_pad_literal.ok()) {
+        return false;
+      }
+      return *converted_pad_literal.ValueOrDie() == reduce_init_literal;
+    };
     // The pad value is usually a constant, so we handle that case and do not
     // try to get more fancy about proving equivalence in cases beyond that.
     if (pad_value->opcode() != HloOpcode::kConstant ||
         reduce_init_value->opcode() != HloOpcode::kConstant ||
-        pad_value->literal() != reduce_init_value->literal()) {
+        !literals_are_equivalent()) {
       VLOG(10) << "Not folding pad into reduce-window due to different pad "
                   "values.";
       return Status::OK();
@@ -1766,7 +1790,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
 
   // If the pad puts a single non-identity value in each window that we're
   // reducing, then this is a broadcast.
-  HloInstruction* pad_operand = operand->mutable_operand(0);
+  HloInstruction* pad_operand = pad->mutable_operand(0);
   auto is_effective_broadcast = [&] {
     if (window_util::HasStride(window)) {
       VLOG(10) << "Window has stride.";
@@ -1810,6 +1834,18 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
     VLOG(10) << "Found window covers a single unpadded element.";
     return true;
   };
+
+  HloInstruction* new_reduce_window_operand;
+  if (convert != nullptr) {
+    new_reduce_window_operand =
+        computation_->AddInstruction(HloInstruction::CreateConvert(
+            ShapeUtil::ChangeElementType(pad_operand->shape(),
+                                         convert->shape().element_type()),
+            pad_operand));
+  } else {
+    new_reduce_window_operand = pad_operand;
+  }
+
   if (is_effective_broadcast()) {
     VLOG(10) << "Replacing pad/reduce-window with (implicit) broadcast.";
     auto fadd = [this](std::unique_ptr<HloInstruction> x) {
@@ -1818,7 +1854,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
     return ReplaceWithNewInstruction(
         reduce_window, HloInstruction::CreateBroadcastSequence(
                            /*output_shape=*/reduce_window->shape(),
-                           /*operand=*/pad_operand, fadd));
+                           /*operand=*/new_reduce_window_operand, fadd));
   }
 
   // Carry out the folding of the pad into reduce_window.
@@ -1835,10 +1871,11 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
     window_dim.set_padding_high(window_dim.padding_high() +
                                 pad_dim.edge_padding_high());
   }
+
   return ReplaceWithNewInstruction(
       reduce_window, HloInstruction::CreateReduceWindow(
                          /*shape=*/reduce_window->shape(),
-                         /*operand=*/pad_operand,
+                         /*operand=*/new_reduce_window_operand,
                          /*init_value=*/reduce_window->mutable_operand(1),
                          /*window=*/new_window,
                          /*reduce_computation=*/function));
index 3b80a82..20c5495 100644 (file)
@@ -2338,6 +2338,91 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
   EXPECT_EQ(root->window().dimensions(3).padding_high(), 102);
 }
 
+// Test that ReduceWindow(Convert(Pad(op, x)), y) can simplify to
+// ReduceWindow(Convert(op), x).
+TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) {
+  HloModule module(TestName());
+  HloComputation::Builder builder(TestName());
+
+  // Create operand to the pad.
+  HloInstruction* parameter =
+      builder.AddInstruction(HloInstruction::CreateParameter(
+          0, ShapeUtil::MakeShape(BF16, {1, 2, 3, 4}), "p0"));
+
+  // Create the pad.
+  PaddingConfig padding = MakeNoPaddingConfig(4);
+  padding.mutable_dimensions(1)->set_edge_padding_low(1);
+  padding.mutable_dimensions(3)->set_edge_padding_high(2);
+
+  HloInstruction* pad_value = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR0<float>(5.0f)));
+  HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
+      ShapeUtil::MakeShape(BF16, {1, 3, 3, 5}), parameter, pad_value, padding));
+
+  HloInstruction* convert =
+      builder.AddInstruction(HloInstruction::CreateConvert(
+          ShapeUtil::ChangeElementType(pad->shape(), F32), pad));
+
+  // Create add computation.
+  HloComputation* add_computation = nullptr;
+  {
+    HloComputation::Builder builder(TestName() + ".add");
+    const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
+    HloInstruction* p0 = builder.AddInstruction(
+        HloInstruction::CreateParameter(0, scalar_shape, "p0"));
+    HloInstruction* p1 = builder.AddInstruction(
+        HloInstruction::CreateParameter(1, scalar_shape, "p1"));
+    builder.AddInstruction(
+        HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
+    add_computation = module.AddEmbeddedComputation(builder.Build());
+  }
+
+  // Create the reduce-window.
+  Window window;
+  for (int64 i = 0; i < ShapeUtil::Rank(pad->shape()); ++i) {
+    auto* dim = window.add_dimensions();
+    dim->set_size(1);
+    dim->set_padding_low(10);
+    dim->set_padding_high(100);
+    dim->set_window_dilation(1);
+    dim->set_base_dilation(1);
+  }
+  const Shape reduce_window_shape =
+      ShapeUtil::MakeShape(F32, {111, 113, 113, 115});
+  HloInstruction* reduce_init_value = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR0<float>(5.0f)));
+  HloInstruction* reduce_window =
+      builder.AddInstruction(HloInstruction::CreateReduceWindow(
+          reduce_window_shape, convert, reduce_init_value, window,
+          add_computation));
+
+  // Build the computation and run the simplifier.
+  auto computation = module.AddEntryComputation(builder.Build());
+  HloInstruction* root = computation->root_instruction();
+  EXPECT_EQ(root, reduce_window);
+  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+                                 non_bitcasting_callback());
+  ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+
+  // Running simplification again should not result in any further changes.
+  ASSERT_FALSE(simplifier.Run(&module).ValueOrDie());
+
+  // Verify the result
+  root = computation->root_instruction();
+  EXPECT_THAT(root, op::ReduceWindow(op::Convert(parameter), op::Constant()));
+  EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reduce_window_shape))
+      << ShapeUtil::HumanString(root->shape()) << " vs "
+      << ShapeUtil::HumanString(reduce_window_shape);
+  EXPECT_EQ(root->window().dimensions(0).padding_low(), 10);
+  EXPECT_EQ(root->window().dimensions(1).padding_low(), 11);
+  EXPECT_EQ(root->window().dimensions(2).padding_low(), 10);
+  EXPECT_EQ(root->window().dimensions(3).padding_low(), 10);
+  EXPECT_EQ(root->window().dimensions(0).padding_high(), 100);
+  EXPECT_EQ(root->window().dimensions(1).padding_high(), 100);
+  EXPECT_EQ(root->window().dimensions(2).padding_high(), 100);
+  EXPECT_EQ(root->window().dimensions(3).padding_high(), 102);
+}
+
 TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) {
   HloComputation::Builder builder(TestName());
   const Shape shape = ShapeUtil::MakeShape(F32, {448, 2048, 1, 1});