Status HandleSubtract(HloInstruction* sub) override;
+ Status HandleMap(HloInstruction* map) override;
+
Status HandleMaximum(HloInstruction* maximum) override;
Status HandleMinimum(HloInstruction* minimum) override;
return true;
}
+Status AlgebraicSimplifierVisitor::HandleMap(HloInstruction* map) {
+ auto* map_computation = map->to_apply();
+ auto* map_root = map_computation->root_instruction();
+ if (map_root->opcode() == HloOpcode::kParameter) {
+ ReplaceInstructionIfSameShape(
+ map, map->mutable_operand(map_root->parameter_number()));
+ return Status::OK();
+ }
+ if (map_root->opcode() == HloOpcode::kConstant) {
+ if (!ShapeUtil::IsScalar(map_root->shape())) {
+ return Status::OK();
+ }
+ auto clone = map_root->CloneWithNewOperands(map_root->shape(), {});
+ if (ShapeUtil::IsScalar(map->shape())) {
+ return ReplaceWithNewInstruction(map, std::move(clone));
+ }
+ return ReplaceWithNewInstruction(
+ map,
+ HloInstruction::CreateBroadcast(
+ map->shape(), computation_->AddInstruction(std::move(clone)), {}));
+ }
+ std::vector<HloInstruction*> new_operands;
+ for (auto* root_operand : map_root->operands()) {
+ if (root_operand->opcode() != HloOpcode::kParameter) {
+ return Status::OK();
+ }
+ new_operands.push_back(
+ map->mutable_operand(root_operand->parameter_number()));
+ }
+ auto clone = map_root->CloneWithNewOperands(map->shape(), new_operands);
+ return ReplaceWithNewInstruction(map, std::move(clone));
+}
+
Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum) {
// Match the following tree:
// min_operand operand
EXPECT_EQ(root, param0);
}
+TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) {
+ HloComputation::Builder builder(TestName());
+ // Create add computation.
+ HloComputation* add_computation = nullptr;
+ {
+ HloComputation::Builder builder(TestName() + ".add");
+ const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
+ HloInstruction* p0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "p0"));
+ HloInstruction* p1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, scalar_shape, "p1"));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
+ add_computation = module().AddEmbeddedComputation(builder.Build());
+ }
+ Shape r2f32 = ShapeUtil::MakeShape(F32, {32, 1});
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r2f32, "param0"));
+ HloInstruction* zero = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ builder.AddInstruction(
+ HloInstruction::CreateMap(r2f32, {param0, zero}, add_computation));
+
+ auto computation = module().AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kMap);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+ root = computation->root_instruction();
+ EXPECT_THAT(root, op::Add(param0, zero));
+}
+
TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) {
Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2});
HloComputation::Builder builder(TestName());