[XLA] Improve dot strength reductions to support transposes of the right and
authorBlake Hechtman <blakehechtman@google.com>
Mon, 11 Dec 2017 16:09:11 +0000 (08:09 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 11 Dec 2017 16:12:31 +0000 (08:12 -0800)
left hand side of a dot.

PiperOrigin-RevId: 178619673

tensorflow/compiler/xla/service/algebraic_simplifier.cc
tensorflow/compiler/xla/service/algebraic_simplifier_test.cc

index b1d0345e703b6c0038081ecb557b73270727bacd..2c0d1900eb6108eb8028fd89220758df03746647 100644 (file)
@@ -193,6 +193,33 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
         enable_dot_strength_reduction_(enable_dot_strength_reduction),
         enable_conv_simplification_(enable_conv_simplification) {}
 
+  // Transforms Dots where at least one input is a vector or has a degenerate
+  // dimension and converts it into a multiply and reduce. This should enable
+  // more fusion than leaving the nodes as Dot operations.
+  StatusOr<bool> HandleDotStrengthReduction(HloInstruction* dot);
+
+  // Reshapes an instruction to rank 1 if it is not already rank 1.
+  HloInstruction* Flatten(HloInstruction* hlo) {
+    if (ShapeUtil::Rank(hlo->shape()) == 1) {
+      return hlo;
+    }
+    return computation_->AddInstruction(HloInstruction::CreateReshape(
+        ShapeUtil::MakeShape(hlo->shape().element_type(),
+                             {ShapeUtil::ElementsIn(hlo->shape())}),
+        hlo));
+  }
+
+  // Helper method to perform and add reduction in a single dimension.
+  HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) {
+    HloInstruction* zero = computation_->AddInstruction(
+        HloInstruction::CreateConstant(Literal::CreateR0(0.0f)));
+    HloComputation* AddReduce_computation = CreateScalarBinaryComputation(
+        computation_->parent(), F32, HloOpcode::kAdd);
+    Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape());
+    return computation_->AddInstruction(HloInstruction::CreateReduce(
+        shape, hlo, zero, {dim}, AddReduce_computation));
+  }
+
   // Convenience method for replacing an instruction with a bitcast.
   void ReplaceWithBitcast(HloInstruction* instruction);
 
@@ -574,6 +601,116 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
   return Status::OK();
 }
 
+StatusOr<bool> AlgebraicSimplifierVisitor::HandleDotStrengthReduction(
+    HloInstruction* dot) {
+  HloInstruction* lhs = dot->mutable_operand(0);
+  HloInstruction* rhs = dot->mutable_operand(1);
+  int64 lhs_collapsing_dim =
+      dot->dot_dimension_numbers().lhs_contracting_dimensions(0);
+  if (lhs->IsRank2Transpose()) {
+    lhs = lhs->mutable_operand(0);
+    lhs_collapsing_dim = 1 - lhs_collapsing_dim;
+  }
+  const int64 lhs_kept_dim = 1 - lhs_collapsing_dim;
+
+  int64 rhs_collapsing_dim =
+      dot->dot_dimension_numbers().rhs_contracting_dimensions(0);
+  if (rhs->IsRank2Transpose()) {
+    rhs = rhs->mutable_operand(0);
+    rhs_collapsing_dim = 1 - rhs_collapsing_dim;
+  }
+  const int64 rhs_kept_dim = 1 - rhs_collapsing_dim;
+
+  auto reshape_if_necessary = [&](HloInstruction* hlo) {
+    if (ShapeUtil::SameDimensions(hlo->shape(), dot->shape())) {
+      return hlo;
+    }
+    return computation_->AddInstruction(
+        HloInstruction::CreateReshape(dot->shape(), hlo));
+  };
+
+  auto broadcast_to_dim = [&](HloInstruction* hlo, const Shape& shape,
+                              int64 dim) {
+    return computation_->AddInstruction(
+        HloInstruction::CreateBroadcast(shape, hlo, {dim}));
+  };
+
+  auto multiply = [&](HloInstruction* local_lhs, HloInstruction* local_rhs) {
+    return computation_->AddInstruction(HloInstruction::CreateBinary(
+        local_lhs->shape(), HloOpcode::kMultiply, local_lhs, local_rhs));
+  };
+
+  // Strength reduce dot(a[K] , b[K]) =
+  //  reshape(result.shape,
+  //          reduce_sum(multiply(a, b), {0}))
+  if (ShapeUtil::Rank(rhs->shape()) == 1 &&
+      ShapeUtil::Rank(lhs->shape()) == 1) {
+    TF_RETURN_IF_ERROR(
+        ReplaceInstruction(dot, reshape_if_necessary(AddReduce(
+                                    multiply(Flatten(lhs), Flatten(rhs)), 0))));
+    return true;
+  }
+
+  if (ShapeUtil::IsEffectiveScalar(rhs->shape()) &&
+      ShapeUtil::IsEffectiveScalar(lhs->shape())) {
+    TF_RETURN_IF_ERROR(ReplaceInstruction(
+        dot, reshape_if_necessary(multiply(Flatten(lhs), Flatten(rhs)))));
+    return true;
+  }
+
+  // Simplify outer product into multiply with implicit broadcasting.
+  //
+  // A dot(a[M, 1], b[1, N]) = multiply(a [M,1], b [1, N])
+  if (ShapeUtil::Rank(rhs->shape()) == 2 &&
+      rhs->shape().dimensions(rhs_collapsing_dim) == 1) {
+    TF_RETURN_IF_ERROR(ReplaceInstruction(
+        dot, multiply(broadcast_to_dim(Flatten(lhs), dot->shape(), 0),
+                      broadcast_to_dim(Flatten(rhs), dot->shape(), 1))));
+    return true;
+  }
+
+  // Strength reduce dot(a[1, K], b) =
+  //    reshape(result.shape,
+  //      reduce_sum(
+  //        multiply(broadcast(reshape(a, [K]), {0}), b),
+  //        {0})
+  //      )
+  //    )
+  if (ShapeUtil::Rank(lhs->shape()) == 1 ||
+      (ShapeUtil::Rank(lhs->shape()) == 2 &&
+       lhs->shape().dimensions(lhs_kept_dim) == 1)) {
+    if (ShapeUtil::Rank(rhs->shape()) == 1) {
+      TF_RETURN_IF_ERROR(ReplaceInstruction(
+          dot,
+          reshape_if_necessary(AddReduce(multiply(Flatten(lhs), rhs), 0))));
+      return true;
+    }
+    TF_RETURN_IF_ERROR(ReplaceInstruction(
+        dot, reshape_if_necessary(
+                 AddReduce(multiply(broadcast_to_dim(Flatten(lhs), rhs->shape(),
+                                                     rhs_collapsing_dim),
+                                    rhs),
+                           rhs_collapsing_dim))));
+    return true;
+  }
+
+  // Strength reduce dot(a, b[K, 1]) =
+  //  reshape(result.shape,
+  //    reduce_sum(multiply(a, broadcast(reshape([K],b), {1})), {0})
+  //  )
+  if (ShapeUtil::Rank(rhs->shape()) == 1 ||
+      (ShapeUtil::Rank(rhs->shape()) == 2 &&
+       rhs->shape().dimensions(rhs_kept_dim) == 1)) {
+    TF_RETURN_IF_ERROR(ReplaceInstruction(
+        dot, reshape_if_necessary(AddReduce(
+                 multiply(lhs, broadcast_to_dim(Flatten(rhs), lhs->shape(),
+                                                lhs_collapsing_dim)),
+                 lhs_collapsing_dim))));
+    return true;
+  }
+  return false;
+}
+
 Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
   auto lhs = dot->mutable_operand(0);
   auto rhs = dot->mutable_operand(1);
@@ -595,6 +732,14 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
         dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {}));
   }
 
+  if (enable_dot_strength_reduction_ && !is_layout_sensitive_) {
+    TF_ASSIGN_OR_RETURN(bool did_strength_reduction,
+                        HandleDotStrengthReduction(dot));
+    if (did_strength_reduction) {
+      return Status::OK();
+    }
+  }
+
   // Simplify dot(transpose(a), transpose(b)) to transpose(dot(b,a)).
   if (lhs->IsRank2Transpose() && rhs->IsRank2Transpose()) {
     DotDimensionNumbers dot_dimension_numbers;
@@ -608,106 +753,6 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
         dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0}));
   }
 
-  if (!enable_dot_strength_reduction_) {
-    return Status::OK();
-  }
-
-  // Simplify outer product into multiply with implicit broadcasting.
-  //
-  // A dot(a[M, 1], b[1, N]) = multiply(a [M,1], b [1, N])
-  if (ShapeUtil::Rank(rhs->shape()) == 2 && rhs->shape().dimensions(0) == 1) {
-    return ReplaceWithNewInstruction(
-        dot, HloInstruction::CreateBinary(dot->shape(), HloOpcode::kMultiply,
-                                          lhs, rhs));
-  }
-
-  // The following graph transformations take Dots where at least one input is a
-  // vector or has a degenerate dimension and converts it into a multiply and
-  // reduce. This should enable more fusion than leaving the nodes as Dot
-  // operations.
-
-  // Strength reduce dot(a[K] , b[K]) =
-  //  reshape(result.shape,
-  //          reduce_sum(multiply(a, b), {0}))
-  if (ShapeUtil::Rank(rhs->shape()) == 1 &&
-      ShapeUtil::Rank(lhs->shape()) == 1) {
-    auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary(
-        rhs->shape(), HloOpcode::kMultiply, lhs, rhs));
-    HloComputation* add_reduce_computation = CreateScalarBinaryComputation(
-        computation_->parent(), F32, HloOpcode::kAdd);
-    auto zero = computation_->AddInstruction(
-        HloInstruction::CreateConstant(Literal::CreateR0(0.0f)));
-    auto reduce = computation_->AddInstruction(HloInstruction::CreateReduce(
-        ShapeUtil::MakeShape(dot->shape().element_type(), {}), multiply, zero,
-        {0}, add_reduce_computation));
-    return ReplaceWithNewInstruction(
-        dot, HloInstruction::CreateReshape(dot->shape(), reduce));
-  }
-
-  // Strength reduce dot(a[1, K], b) =
-  //    reshape(result.shape,
-  //      reduce_sum(
-  //        multiply(broadcast(reshape(a, [K]), {0}), b),
-  //        {0})
-  //      )
-  //    )
-  if (ShapeUtil::Rank(lhs->shape()) == 1 ||
-      (ShapeUtil::Rank(lhs->shape()) == 2 && lhs->shape().dimensions(0) == 1)) {
-    auto new_lhs = computation_->AddInstruction(HloInstruction::CreateReshape(
-        ShapeUtil::MakeShape(lhs->shape().element_type(),
-                             {ShapeUtil::ElementsIn(lhs->shape())}),
-        lhs));
-    HloComputation* add_reduce_computation = CreateScalarBinaryComputation(
-        computation_->parent(), F32, HloOpcode::kAdd);
-    auto zero = computation_->AddInstruction(
-        HloInstruction::CreateConstant(Literal::CreateR0(0.0f)));
-    HloInstruction* reduce;
-    if (ShapeUtil::Rank(rhs->shape()) == 1) {
-      auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary(
-          rhs->shape(), HloOpcode::kMultiply, new_lhs, rhs));
-      reduce = computation_->AddInstruction(HloInstruction::CreateReduce(
-          ShapeUtil::MakeShape(dot->shape().element_type(), {}), multiply, zero,
-          {0}, add_reduce_computation));
-    } else {
-      new_lhs = computation_->AddInstruction(
-          HloInstruction::CreateBroadcast(rhs->shape(), new_lhs, {0}));
-      auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary(
-          rhs->shape(), HloOpcode::kMultiply, new_lhs, rhs));
-
-      reduce = computation_->AddInstruction(HloInstruction::CreateReduce(
-          ShapeUtil::MakeShape(dot->shape().element_type(),
-                               {rhs->shape().dimensions(1)}),
-          multiply, zero, {0}, add_reduce_computation));
-    }
-    return ReplaceWithNewInstruction(
-        dot, HloInstruction::CreateReshape(dot->shape(), reduce));
-  }
-
-  // Strength reduce dot(a, b[K, 1]) =
-  //  reshape(result.shape,
-  //    reduce_sum(multiply(a, broadcast(reshape([K],b), {1})), {0})
-  //  )
-  if (ShapeUtil::Rank(rhs->shape()) == 1 ||
-      (ShapeUtil::Rank(rhs->shape()) == 2 && rhs->shape().dimensions(1) == 1)) {
-    auto new_rhs = computation_->AddInstruction(HloInstruction::CreateReshape(
-        ShapeUtil::MakeShape(rhs->shape().element_type(),
-                             {ShapeUtil::ElementsIn(rhs->shape())}),
-        rhs));
-    new_rhs = computation_->AddInstruction(
-        HloInstruction::CreateBroadcast(lhs->shape(), new_rhs, {1}));
-    auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary(
-        lhs->shape(), HloOpcode::kMultiply, lhs, new_rhs));
-    HloComputation* add_reduce_computation = CreateScalarBinaryComputation(
-        computation_->parent(), F32, HloOpcode::kAdd);
-    auto zero = computation_->AddInstruction(
-        HloInstruction::CreateConstant(Literal::CreateR0(0.0f)));
-    auto reduce = computation_->AddInstruction(HloInstruction::CreateReduce(
-        ShapeUtil::MakeShape(dot->shape().element_type(),
-                             {lhs->shape().dimensions(0)}),
-        multiply, zero, {1}, add_reduce_computation));
-    return ReplaceWithNewInstruction(
-        dot, HloInstruction::CreateReshape(dot->shape(), reduce));
-  }
   return Status::OK();
 }
 
index 3d70505f6e68d01706b81088285a9ca43f2080bb..7462e397ff07779c04bce18b68419bff9686dbd5 100644 (file)
@@ -2238,5 +2238,63 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) {
               op::DynamicSlice(op::Parameter(), op::Parameter()));
 }
 
+class DotStrengthReductionTest
+    : public AlgebraicSimplifierTest,
+      public ::testing::WithParamInterface<
+          ::testing::tuple<int, int, int, bool, bool>> {};
+TEST_P(DotStrengthReductionTest, DotStrengthReduction) {
+  int m, k, n;
+  bool transpose_lhs, transpose_rhs;
+  std::tie(m, k, n, transpose_lhs, transpose_rhs) = GetParam();
+
+  Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n});
+  Shape lhs_shape = ShapeUtil::MakeShape(F32, {m, k});
+  Shape transposed_lhs_shape = ShapeUtil::MakeShape(F32, {k, m});
+  Shape rhs_shape = ShapeUtil::MakeShape(F32, {k, n});
+  Shape transposed_rhs_shape = ShapeUtil::MakeShape(F32, {n, k});
+  HloComputation::Builder builder(TestName());
+
+  auto lhs = builder.AddInstruction(HloInstruction::CreateParameter(
+      0, transpose_lhs ? transposed_lhs_shape : lhs_shape, "lhs"));
+  if (transpose_lhs) {
+    lhs = builder.AddInstruction(
+        HloInstruction::CreateTranspose(lhs_shape, lhs, {1, 0}));
+  }
+  auto rhs = builder.AddInstruction(HloInstruction::CreateParameter(
+      1, transpose_rhs ? transposed_rhs_shape : rhs_shape, "rhs"));
+  if (transpose_rhs) {
+    rhs = builder.AddInstruction(
+        HloInstruction::CreateTranspose(rhs_shape, rhs, {1, 0}));
+  }
+  DotDimensionNumbers dot_dnums;
+  dot_dnums.add_lhs_contracting_dimensions(1);
+  dot_dnums.add_rhs_contracting_dimensions(0);
+  builder.AddInstruction(
+      HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums));
+  auto module = CreateNewModule();
+  auto computation = module->AddEntryComputation(builder.Build());
+  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+                                 non_bitcasting_callback());
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get()));
+  const bool dot_should_be_transformed = m == 1 || k == 1 || n == 1;
+  const bool computation_should_be_modified =
+      dot_should_be_transformed || (transpose_lhs && transpose_rhs);
+  EXPECT_EQ(changed, computation_should_be_modified);
+  bool has_no_dot = true;
+  for (const auto& hlo : computation->instructions()) {
+    if (hlo->opcode() == HloOpcode::kDot) {
+      has_no_dot = false;
+      break;
+    }
+  }
+  EXPECT_EQ(has_no_dot, dot_should_be_transformed);
+}
+
+INSTANTIATE_TEST_CASE_P(
+    DotStrengthReductionTestInstantiation, DotStrengthReductionTest,
+    ::testing::Combine(::testing::Values(1, 2), ::testing::Values(1, 2),
+                       ::testing::Values(1, 2), ::testing::Bool(),
+                       ::testing::Bool()));
+
 }  // namespace
 }  // namespace xla