function));
}
- VLOG(10) << "Considering folding Pad: " << operand->ToString()
- << "\ninto reduce-window: " << reduce_window->ToString();
-
// This optimization folds a pad op into reduce_window.
- if (operand->opcode() != HloOpcode::kPad) {
+ HloInstruction* pad;
+ const HloInstruction* convert = nullptr;
+ if (operand->opcode() == HloOpcode::kPad) {
+ pad = operand;
+ } else if (operand->opcode() == HloOpcode::kConvert &&
+ operand->operand(0)->opcode() == HloOpcode::kPad) {
+ convert = operand;
+ pad = operand->mutable_operand(0);
+ } else {
VLOG(10) << "Not folding pad into reduce-window as there is no pad.";
return Status::OK();
}
+ VLOG(10) << "Considering folding Pad: " << pad->ToString()
+ << "\ninto reduce-window: " << reduce_window->ToString()
+ << (convert != nullptr ? tensorflow::strings::StrCat(
+ "\nvia convert: ", convert->ToString())
+ : "");
+
// Do not fold interior padding into ReduceWindow since the backends do not
// support it.
- const PaddingConfig& pad_config = operand->padding_config();
+ const PaddingConfig& pad_config = pad->padding_config();
if (HasInteriorPadding(pad_config)) {
VLOG(10) << "Not folding pad into reduce-window due to interior padding.";
return Status::OK();
// If reduce_window already has padding, the pad value of the pad op and the
// init value of reduce_window must match to allow folding the pad.
- const HloInstruction* pad_value = operand->operand(1);
+ const HloInstruction* pad_value = pad->operand(1);
const HloInstruction* reduce_init_value = reduce_window->operand(1);
if (pad_value != reduce_init_value) {
+ auto literals_are_equivalent = [&] {
+ auto& pad_literal = pad_value->literal();
+ auto& reduce_init_literal = reduce_init_value->literal();
+ if (pad_literal == reduce_init_literal) {
+ return true;
+ }
+ auto converted_pad_literal = pad_literal.ConvertToShape(
+ reduce_init_value->shape(), /*round_f32_to_bf16=*/true);
+ if (!converted_pad_literal.ok()) {
+ return false;
+ }
+ return *converted_pad_literal.ValueOrDie() == reduce_init_literal;
+ };
// The pad value is usually a constant, so we handle that case and do not
// try to get more fancy about proving equivalence in cases beyond that.
if (pad_value->opcode() != HloOpcode::kConstant ||
reduce_init_value->opcode() != HloOpcode::kConstant ||
- pad_value->literal() != reduce_init_value->literal()) {
+ !literals_are_equivalent()) {
VLOG(10) << "Not folding pad into reduce-window due to different pad "
"values.";
return Status::OK();
// If the pad puts a single non-identity value in each window that we're
// reducing, then this is a broadcast.
- HloInstruction* pad_operand = operand->mutable_operand(0);
+ HloInstruction* pad_operand = pad->mutable_operand(0);
auto is_effective_broadcast = [&] {
if (window_util::HasStride(window)) {
VLOG(10) << "Window has stride.";
VLOG(10) << "Found window covers a single unpadded element.";
return true;
};
+
+ HloInstruction* new_reduce_window_operand;
+ if (convert != nullptr) {
+ new_reduce_window_operand =
+ computation_->AddInstruction(HloInstruction::CreateConvert(
+ ShapeUtil::ChangeElementType(pad_operand->shape(),
+ convert->shape().element_type()),
+ pad_operand));
+ } else {
+ new_reduce_window_operand = pad_operand;
+ }
+
if (is_effective_broadcast()) {
VLOG(10) << "Replacing pad/reduce-window with (implicit) broadcast.";
auto fadd = [this](std::unique_ptr<HloInstruction> x) {
return ReplaceWithNewInstruction(
reduce_window, HloInstruction::CreateBroadcastSequence(
/*output_shape=*/reduce_window->shape(),
- /*operand=*/pad_operand, fadd));
+ /*operand=*/new_reduce_window_operand, fadd));
}
// Carry out the folding of the pad into reduce_window.
window_dim.set_padding_high(window_dim.padding_high() +
pad_dim.edge_padding_high());
}
+
return ReplaceWithNewInstruction(
reduce_window, HloInstruction::CreateReduceWindow(
/*shape=*/reduce_window->shape(),
- /*operand=*/pad_operand,
+ /*operand=*/new_reduce_window_operand,
/*init_value=*/reduce_window->mutable_operand(1),
/*window=*/new_window,
/*reduce_computation=*/function));
EXPECT_EQ(root->window().dimensions(3).padding_high(), 102);
}
+// Test that ReduceWindow(Convert(Pad(op, x)), y) can simplify to
+// ReduceWindow(Convert(op), x).
+TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) {
+  HloModule module(TestName());
+  HloComputation::Builder builder(TestName());
+
+  // Create operand to the pad.
+  HloInstruction* parameter =
+      builder.AddInstruction(HloInstruction::CreateParameter(
+          0, ShapeUtil::MakeShape(BF16, {1, 2, 3, 4}), "p0"));
+
+  // Create the pad.
+  // Only two dimensions carry edge padding: dim 1 gets 1 element of low
+  // padding and dim 3 gets 2 elements of high padding, so the padded shape is
+  // {1, 2+1, 3, 4+2} = {1, 3, 3, 5}. No interior padding, so the fold is
+  // eligible.
+  PaddingConfig padding = MakeNoPaddingConfig(4);
+  padding.mutable_dimensions(1)->set_edge_padding_low(1);
+  padding.mutable_dimensions(3)->set_edge_padding_high(2);
+
+  // NOTE(review): the pad operand is BF16 but this pad value is an F32
+  // constant (it textually matches the F32 reduce-window init below, so the
+  // simplifier's literal-equivalence check succeeds) -- confirm a BF16 pad
+  // with an F32 pad value is accepted by shape verification, or whether this
+  // constant should be a BF16 literal exercising the ConvertToShape path.
+  HloInstruction* pad_value = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR0<float>(5.0f)));
+  HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
+      ShapeUtil::MakeShape(BF16, {1, 3, 3, 5}), parameter, pad_value, padding));
+
+  // Convert the BF16 pad result to F32; this is the "Convert" the simplifier
+  // must fold through.
+  HloInstruction* convert =
+      builder.AddInstruction(HloInstruction::CreateConvert(
+          ShapeUtil::ChangeElementType(pad->shape(), F32), pad));
+
+  // Create add computation.
+  HloComputation* add_computation = nullptr;
+  {
+    HloComputation::Builder builder(TestName() + ".add");
+    const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
+    HloInstruction* p0 = builder.AddInstruction(
+        HloInstruction::CreateParameter(0, scalar_shape, "p0"));
+    HloInstruction* p1 = builder.AddInstruction(
+        HloInstruction::CreateParameter(1, scalar_shape, "p1"));
+    builder.AddInstruction(
+        HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
+    add_computation = module.AddEmbeddedComputation(builder.Build());
+  }
+
+  // Create the reduce-window.
+  // Every window dimension has size 1 with padding low=10 / high=100, so each
+  // output dimension is input_dim + 10 + 100; with the padded input
+  // {1, 3, 3, 5} that yields the {111, 113, 113, 115} shape asserted below.
+  Window window;
+  for (int64 i = 0; i < ShapeUtil::Rank(pad->shape()); ++i) {
+    auto* dim = window.add_dimensions();
+    dim->set_size(1);
+    dim->set_padding_low(10);
+    dim->set_padding_high(100);
+    dim->set_window_dilation(1);
+    dim->set_base_dilation(1);
+  }
+  const Shape reduce_window_shape =
+      ShapeUtil::MakeShape(F32, {111, 113, 113, 115});
+  HloInstruction* reduce_init_value = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR0<float>(5.0f)));
+  HloInstruction* reduce_window =
+      builder.AddInstruction(HloInstruction::CreateReduceWindow(
+          reduce_window_shape, convert, reduce_init_value, window,
+          add_computation));
+
+  // Build the computation and run the simplifier.
+  auto computation = module.AddEntryComputation(builder.Build());
+  HloInstruction* root = computation->root_instruction();
+  EXPECT_EQ(root, reduce_window);
+  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+                                 non_bitcasting_callback());
+  ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+
+  // Running simplification again should not result in any further changes.
+  ASSERT_FALSE(simplifier.Run(&module).ValueOrDie());
+
+  // Verify the result
+  // The pad must be gone, with the convert moved onto the unpadded parameter;
+  // the pad's edge padding is absorbed into the reduce-window's own padding.
+  root = computation->root_instruction();
+  EXPECT_THAT(root, op::ReduceWindow(op::Convert(parameter), op::Constant()));
+  EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reduce_window_shape))
+      << ShapeUtil::HumanString(root->shape()) << " vs "
+      << ShapeUtil::HumanString(reduce_window_shape);
+  // Dim 1 low padding is 10 + 1 (from the pad's edge_padding_low); dim 3 high
+  // padding is 100 + 2 (from the pad's edge_padding_high); all other window
+  // paddings are unchanged.
+  EXPECT_EQ(root->window().dimensions(0).padding_low(), 10);
+  EXPECT_EQ(root->window().dimensions(1).padding_low(), 11);
+  EXPECT_EQ(root->window().dimensions(2).padding_low(), 10);
+  EXPECT_EQ(root->window().dimensions(3).padding_low(), 10);
+  EXPECT_EQ(root->window().dimensions(0).padding_high(), 100);
+  EXPECT_EQ(root->window().dimensions(1).padding_high(), 100);
+  EXPECT_EQ(root->window().dimensions(2).padding_high(), 100);
+  EXPECT_EQ(root->window().dimensions(3).padding_high(), 102);
+}
+
TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) {
HloComputation::Builder builder(TestName());
const Shape shape = ShapeUtil::MakeShape(F32, {448, 2048, 1, 1});