Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
auto operand = broadcast->mutable_operand(0);
+ auto dims = broadcast->dimensions();
// A degenerate broadcast of a reshape that does not change the number of
// elements can be replaced by a reshape.
- if (std::is_sorted(broadcast->dimensions().begin(),
- broadcast->dimensions().end()) &&
+ if (std::is_sorted(dims.begin(), dims.end()) &&
ShapeUtil::ElementsIn(broadcast->shape()) ==
ShapeUtil::ElementsIn(operand->shape())) {
VLOG(10) << "transform broadcast(X) -> reshape(X) where "
VLOG(10) << "transform broadcast(X) -> transpose(X) where "
"n(broadcast(X)) == n(X)";
return ReplaceWithNewInstruction(
- broadcast, HloInstruction::CreateTranspose(broadcast->shape(), operand,
- broadcast->dimensions()));
+ broadcast,
+ HloInstruction::CreateTranspose(broadcast->shape(), operand, dims));
}
// A broadcast of a reshape which merely inserts 1-sized dimensions can
if (merely_inserts_or_deletes_1_sized_dimensions &&
deleted_indices.empty()) {
std::reverse(inserted_indices.begin(), inserted_indices.end());
- auto dims = broadcast->dimensions();
for (auto inserted_index : inserted_indices) {
dims.erase(dims.begin() + inserted_index);
}
return user->ReplaceAllUsesWith(new_broadcast);
}
}
+ return Status::OK();
+ }
+
+ // Merge two consecutive broadcasts into a single one.
+ if (operand->opcode() == HloOpcode::kBroadcast) {
+ std::vector<int64> new_dimensions(operand->dimensions().size());
+ for (auto dim : operand->dimensions()) {
+ new_dimensions.push_back(dims[dim]);
+ }
+ return ReplaceWithNewInstruction(
+ broadcast,
+ HloInstruction::CreateBroadcast(
+ broadcast->shape(), operand->mutable_operand(0), new_dimensions));
}
return Status::OK();
}
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/str_util.h"

using ::testing::ElementsAre;
namespace xla {
namespace {
op::DynamicSlice(op::Parameter(), op::Parameter()));
}
+// Test that two consecutive broadcasts can be merged to one.
+TEST_F(AlgebraicSimplifierTest, MergeBroadcasts) {
+ HloComputation::Builder builder(TestName());
+ Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
+ HloInstruction* input_array = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR1<float>({3, 4})));
+ HloInstruction* inner_bcast = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(r2f32, input_array, {1}));
+ Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 2, 2});
+ builder.AddInstruction(
+ HloInstruction::CreateBroadcast(r3f32, inner_bcast, {0, 2}));
+
+ auto computation = module().AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+ root = computation->root_instruction();
+ EXPECT_THAT(root, op::Broadcast(op::Constant()));
+ EXPECT_THAT(root->dimensions(), ElementsAre(2));
+}
+
+// Test that two consecutive broadcasts can be merged to one.
+TEST_F(AlgebraicSimplifierTest, MergeBroadcasts2) {
+ HloComputation::Builder builder(TestName());
+ Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 3});
+ Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 5, 3});
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r2f32, "param0"));
+ // The initial dimensions go to places 0 and 2 in the 3-dim array,
+ // and to places 1 and 3 in the 4-dim array,
+ HloInstruction* inner_bcast = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(r3f32, param0, {0, 2}));
+ Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 2, 5, 3});
+ builder.AddInstruction(
+ HloInstruction::CreateBroadcast(r4f32, inner_bcast, {1, 2, 3}));
+
+ auto computation = module().AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+ root = computation->root_instruction();
+ EXPECT_THAT(root, op::Broadcast(op::Parameter(0)));
+ EXPECT_THAT(root->dimensions(), ElementsAre(1, 3));
+}
+
struct PadReduceWindowEffectiveBroadcastCase {
std::vector<int64> input_spatials;
std::vector<int64> symmetric_pad_spatials;