[XLA] Remove maps with a single instruction
authorDavid Majnemer <majnemer@google.com>
Thu, 24 May 2018 22:45:25 +0000 (15:45 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 24 May 2018 22:47:44 +0000 (15:47 -0700)
These maps aren't really pulling their weight, fold them to the instruction
that they compute.

PiperOrigin-RevId: 197967117

tensorflow/compiler/xla/service/algebraic_simplifier.cc
tensorflow/compiler/xla/service/algebraic_simplifier_test.cc

index f732ed8..c65c91e 100644 (file)
@@ -157,6 +157,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
 
   Status HandleSubtract(HloInstruction* sub) override;
 
+  Status HandleMap(HloInstruction* map) override;
+
   Status HandleMaximum(HloInstruction* maximum) override;
   Status HandleMinimum(HloInstruction* minimum) override;
 
@@ -2188,6 +2190,39 @@ bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape(
   return true;
 }
 
+Status AlgebraicSimplifierVisitor::HandleMap(HloInstruction* map) {
+  auto* map_computation = map->to_apply();
+  auto* map_root = map_computation->root_instruction();
+  if (map_root->opcode() == HloOpcode::kParameter) {
+    ReplaceInstructionIfSameShape(
+        map, map->mutable_operand(map_root->parameter_number()));
+    return Status::OK();
+  }
+  if (map_root->opcode() == HloOpcode::kConstant) {
+    if (!ShapeUtil::IsScalar(map_root->shape())) {
+      return Status::OK();
+    }
+    auto clone = map_root->CloneWithNewOperands(map_root->shape(), {});
+    if (ShapeUtil::IsScalar(map->shape())) {
+      return ReplaceWithNewInstruction(map, std::move(clone));
+    }
+    return ReplaceWithNewInstruction(
+        map,
+        HloInstruction::CreateBroadcast(
+            map->shape(), computation_->AddInstruction(std::move(clone)), {}));
+  }
+  std::vector<HloInstruction*> new_operands;
+  for (auto* root_operand : map_root->operands()) {
+    if (root_operand->opcode() != HloOpcode::kParameter) {
+      return Status::OK();
+    }
+    new_operands.push_back(
+        map->mutable_operand(root_operand->parameter_number()));
+  }
+  auto clone = map_root->CloneWithNewOperands(map->shape(), new_operands);
+  return ReplaceWithNewInstruction(map, std::move(clone));
+}
+
 Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum) {
   // Match the following tree:
   //          min_operand     operand
index 4e08287..d5f0afe 100644 (file)
@@ -143,6 +143,39 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) {
   EXPECT_EQ(root, param0);
 }
 
+TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) {
+  HloComputation::Builder builder(TestName());
+  // Create add computation.
+  HloComputation* add_computation = nullptr;
+  {
+    HloComputation::Builder builder(TestName() + ".add");
+    const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
+    HloInstruction* p0 = builder.AddInstruction(
+        HloInstruction::CreateParameter(0, scalar_shape, "p0"));
+    HloInstruction* p1 = builder.AddInstruction(
+        HloInstruction::CreateParameter(1, scalar_shape, "p1"));
+    builder.AddInstruction(
+        HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
+    add_computation = module().AddEmbeddedComputation(builder.Build());
+  }
+  Shape r2f32 = ShapeUtil::MakeShape(F32, {32, 1});
+  HloInstruction* param0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, r2f32, "param0"));
+  HloInstruction* zero = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+  builder.AddInstruction(
+      HloInstruction::CreateMap(r2f32, {param0, zero}, add_computation));
+
+  auto computation = module().AddEntryComputation(builder.Build());
+  HloInstruction* root = computation->root_instruction();
+  EXPECT_EQ(root->opcode(), HloOpcode::kMap);
+  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+                                 non_bitcasting_callback());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+  root = computation->root_instruction();
+  EXPECT_THAT(root, op::Add(param0, zero));
+}
+
 TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) {
   Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2});
   HloComputation::Builder builder(TestName());