From: David Majnemer Date: Tue, 27 Mar 2018 22:37:35 +0000 (-0700) Subject: [XLA] Fold reduce-window(convert(pad(X))) into reduce-window(convert(X)) X-Git-Tag: tflite-v0.1.7~67^2^2~88 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=05ddf373980fae94a2c73cf93161332484e102fd;p=platform%2Fupstream%2Ftensorflow.git [XLA] Fold reduce-window(convert(pad(X))) into reduce-window(convert(X)) ReduceWindow operations are done in higher precision to avoid accumulation error. Convert operations can find their way between a ReduceWindow and a Pad which can prevent a Pad from combining with a ReduceWindow. Fix this by looking past the Convert while also checking that the Convert'd Pad's init value is identical to the reduce-window value. PiperOrigin-RevId: 190686175 --- diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index f9fabd8..0e4624f 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -1731,18 +1731,29 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( function)); } - VLOG(10) << "Considering folding Pad: " << operand->ToString() - << "\ninto reduce-window: " << reduce_window->ToString(); - // This optimization folds a pad op into reduce_window. - if (operand->opcode() != HloOpcode::kPad) { + HloInstruction* pad; + const HloInstruction* convert = nullptr; + if (operand->opcode() == HloOpcode::kPad) { + pad = operand; + } else if (operand->opcode() == HloOpcode::kConvert && + operand->operand(0)->opcode() == HloOpcode::kPad) { + convert = operand; + pad = operand->mutable_operand(0); + } else { VLOG(10) << "Not folding pad into reduce-window as there is no pad."; return Status::OK(); } + VLOG(10) << "Considering folding Pad: " << pad->ToString() + << "\ninto reduce-window: " << reduce_window->ToString() + << (convert != nullptr ? 
tensorflow::strings::StrCat( + "\nvia convert: ", convert->ToString()) + : ""); + // Do not fold interior padding into ReduceWindow since the backends do not // support it. - const PaddingConfig& pad_config = operand->padding_config(); + const PaddingConfig& pad_config = pad->padding_config(); if (HasInteriorPadding(pad_config)) { VLOG(10) << "Not folding pad into reduce-window due to interior padding."; return Status::OK(); @@ -1750,14 +1761,27 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( // If reduce_window already has padding, the pad value of the pad op and the // init value of reduce_window must match to allow folding the pad. - const HloInstruction* pad_value = operand->operand(1); + const HloInstruction* pad_value = pad->operand(1); const HloInstruction* reduce_init_value = reduce_window->operand(1); if (pad_value != reduce_init_value) { + auto literals_are_equivalent = [&] { + auto& pad_literal = pad_value->literal(); + auto& reduce_init_literal = reduce_init_value->literal(); + if (pad_literal == reduce_init_literal) { + return true; + } + auto converted_pad_literal = pad_literal.ConvertToShape( + reduce_init_value->shape(), /*round_f32_to_bf16=*/true); + if (!converted_pad_literal.ok()) { + return false; + } + return *converted_pad_literal.ValueOrDie() == reduce_init_literal; + }; // The pad value is usually a constant, so we handle that case and do not // try to get more fancy about proving equivalence in cases beyond that. if (pad_value->opcode() != HloOpcode::kConstant || reduce_init_value->opcode() != HloOpcode::kConstant || - pad_value->literal() != reduce_init_value->literal()) { + !literals_are_equivalent()) { VLOG(10) << "Not folding pad into reduce-window due to different pad " "values."; return Status::OK(); @@ -1766,7 +1790,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( // If the pad puts a single non-identity value in each window that we're // reducing, then this is a broadcast. 
- HloInstruction* pad_operand = operand->mutable_operand(0); + HloInstruction* pad_operand = pad->mutable_operand(0); auto is_effective_broadcast = [&] { if (window_util::HasStride(window)) { VLOG(10) << "Window has stride."; return false; } @@ -1810,6 +1834,18 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( VLOG(10) << "Found window covers a single unpadded element."; return true; }; + + HloInstruction* new_reduce_window_operand; + if (convert != nullptr) { + new_reduce_window_operand = + computation_->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(pad_operand->shape(), + convert->shape().element_type()), + pad_operand)); + } else { + new_reduce_window_operand = pad_operand; + } + if (is_effective_broadcast()) { VLOG(10) << "Replacing pad/reduce-window with (implicit) broadcast."; auto fadd = [this](std::unique_ptr<HloInstruction> x) { @@ -1818,7 +1854,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( return ReplaceWithNewInstruction( reduce_window, HloInstruction::CreateBroadcastSequence( /*output_shape=*/reduce_window->shape(), - /*operand=*/pad_operand, fadd)); + /*operand=*/new_reduce_window_operand, fadd)); } // Carry out the folding of the pad into reduce_window. 
@@ -1835,10 +1871,11 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( window_dim.set_padding_high(window_dim.padding_high() + pad_dim.edge_padding_high()); } + return ReplaceWithNewInstruction( reduce_window, HloInstruction::CreateReduceWindow( /*shape=*/reduce_window->shape(), - /*operand=*/pad_operand, + /*operand=*/new_reduce_window_operand, /*init_value=*/reduce_window->mutable_operand(1), /*window=*/new_window, /*reduce_computation=*/function)); diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 3b80a82..20c5495 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -2338,6 +2338,91 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { EXPECT_EQ(root->window().dimensions(3).padding_high(), 102); } +// Test that ReduceWindow(Convert(Pad(op, x)), y) can simplify to +// ReduceWindow(Convert(op), x). +TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { + HloModule module(TestName()); + HloComputation::Builder builder(TestName()); + + // Create operand to the pad. + HloInstruction* parameter = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(BF16, {1, 2, 3, 4}), "p0")); + + // Create the pad. + PaddingConfig padding = MakeNoPaddingConfig(4); + padding.mutable_dimensions(1)->set_edge_padding_low(1); + padding.mutable_dimensions(3)->set_edge_padding_high(2); + + HloInstruction* pad_value = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<float>(5.0f))); + HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad( + ShapeUtil::MakeShape(BF16, {1, 3, 3, 5}), parameter, pad_value, padding)); + + HloInstruction* convert = + builder.AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(pad->shape(), F32), pad)); + + // Create add computation. 
+ HloComputation* add_computation = nullptr; + { + HloComputation::Builder builder(TestName() + ".add"); + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + HloInstruction* p0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "p0")); + HloInstruction* p1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "p1")); + builder.AddInstruction( + HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); + add_computation = module.AddEmbeddedComputation(builder.Build()); + } + + // Create the reduce-window. + Window window; + for (int64 i = 0; i < ShapeUtil::Rank(pad->shape()); ++i) { + auto* dim = window.add_dimensions(); + dim->set_size(1); + dim->set_padding_low(10); + dim->set_padding_high(100); + dim->set_window_dilation(1); + dim->set_base_dilation(1); + } + const Shape reduce_window_shape = + ShapeUtil::MakeShape(F32, {111, 113, 113, 115}); + HloInstruction* reduce_init_value = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<float>(5.0f))); + HloInstruction* reduce_window = + builder.AddInstruction(HloInstruction::CreateReduceWindow( + reduce_window_shape, convert, reduce_init_value, window, + add_computation)); + + // Build the computation and run the simplifier. + auto computation = module.AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root, reduce_window); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + + // Running simplification again should not result in any further changes. 
+ ASSERT_FALSE(simplifier.Run(&module).ValueOrDie()); + + // Verify the result + root = computation->root_instruction(); + EXPECT_THAT(root, op::ReduceWindow(op::Convert(parameter), op::Constant())); + EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reduce_window_shape)) + << ShapeUtil::HumanString(root->shape()) << " vs " + << ShapeUtil::HumanString(reduce_window_shape); + EXPECT_EQ(root->window().dimensions(0).padding_low(), 10); + EXPECT_EQ(root->window().dimensions(1).padding_low(), 11); + EXPECT_EQ(root->window().dimensions(2).padding_low(), 10); + EXPECT_EQ(root->window().dimensions(3).padding_low(), 10); + EXPECT_EQ(root->window().dimensions(0).padding_high(), 100); + EXPECT_EQ(root->window().dimensions(1).padding_high(), 100); + EXPECT_EQ(root->window().dimensions(2).padding_high(), 100); + EXPECT_EQ(root->window().dimensions(3).padding_high(), 102); +} + TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) { HloComputation::Builder builder(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {448, 2048, 1, 1});