Re-land: Optimize dot(DynamicSlice(ConstA), ConstantB) by memoizing dot(ConstA, ConstB)
authorAlina Sbirlea <asbirlea@google.com>
Tue, 8 May 2018 18:54:03 +0000 (11:54 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 8 May 2018 22:42:09 +0000 (15:42 -0700)
Make transformation when ConstA and ConstB are 2D, and DynamicSlice is slicing a full row, column respectively.
Handle:
dot(DynamicSlice(Index, ConstA), ConstB) => DynamicSlice(Index, dot*(ConstA, ConstB));
and
dot(ConstA, DynamicSlice(Index, ConstB)) => DynamicSlice(Index, dot*(ConstA, ConstB));

Reason to roll forward: Previous issue of getting out of memory errors when generating LLVM constants was resolved by CSE-ing constants before allocation.

PiperOrigin-RevId: 195853680

tensorflow/compiler/xla/service/algebraic_simplifier.cc
tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
tensorflow/compiler/xla/tests/dot_operation_test.cc

index 8e785de..4ec79a0 100644 (file)
@@ -291,6 +291,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
       const Shape& dot_shape, HloInstruction* lhs, int64 lhs_contracting_dim,
       HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped);
 
+  StatusOr<HloInstruction*> OptimizeDotOfGather(HloInstruction* dot);
+
   // Current HloComputation instance the AlgebraicSimplifierVisitor is
   // traversing.
   HloComputation* computation_;
@@ -912,6 +914,134 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper(
   return add_result;
 }
 
+StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfGather(
+    HloInstruction* dot) {
+  const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
+  if (dnums.lhs_contracting_dimensions_size() != 1 ||
+      dnums.rhs_contracting_dimensions_size() != 1 ||
+      dnums.lhs_batch_dimensions_size() != 0 ||
+      dnums.rhs_batch_dimensions_size() != 0 ||
+      dot->shape().dimensions_size() != 2) {  // dot output 2D
+    VLOG(10) << "DotOfGather: Can only optimize 2D, non-batch dot operations.";
+    return nullptr;
+  }
+
+  // Optimize either dot(DS(ctA), ctB)) or dot(ctB, DS(ctA)).
+  // Currently a Gather is a DynamicSlice.
+  auto is_dynamic_slice_constant_combination =
+      [](HloInstruction* a, HloInstruction* b, int a_contracting_dimension) {
+        // First operand is a DynamicSlice(Constant).
+        if (a->opcode() != HloOpcode::kDynamicSlice) {
+          return false;
+        }
+        auto* dynamic_slice_op = a->operand(0);
+        if (dynamic_slice_op->opcode() != HloOpcode::kConstant) {
+          return false;
+        }
+        // Second operand is a Constant.
+        if (b->opcode() != HloOpcode::kConstant) {
+          return false;
+        }
+        // The DynamicSlice output is a vector.
+        const Shape& dynamic_slice_shape = a->shape();
+        if (dynamic_slice_shape.dimensions(1 - a_contracting_dimension) != 1) {
+          return false;
+        }
+        // Constant size is the same before and after slice in the contracting
+        // dimension, otherwise we either must precompute for all possible slice
+        // indices or dot is invalid.
+        const Shape& dynamic_slice_op_shape = dynamic_slice_op->shape();
+        if (dynamic_slice_op_shape.dimensions(a_contracting_dimension) !=
+            dynamic_slice_shape.dimensions(a_contracting_dimension)) {
+          return false;
+        }
+        return true;
+      };
+
+  HloInstruction* lhs = dot->mutable_operand(0);
+  HloInstruction* rhs = dot->mutable_operand(1);
+  int lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0);
+  int rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0);
+
+  if (!is_dynamic_slice_constant_combination(
+          lhs, rhs, /*a_contracting_dimension=*/lhs_contracting_dimension) &&
+      !is_dynamic_slice_constant_combination(
+          rhs, lhs, /*a_contracting_dimension=*/rhs_contracting_dimension)) {
+    VLOG(10) << "DotOfGather: Can only optimize dot(DS(ctA), ctB)) or "
+                "dot(ctB, DS(ctA)), where the two constants have equal "
+                "contracting dimensions.";
+    return nullptr;
+  }
+
+  // LHS is DynamicSlice:
+  // input: dot(DS(ctA), ctB))
+  // where DS(ctA) = DS({M x K}, {start, 0}, {1, K}) and ctB = {K x N}.
+  // => input dimensions: dot({1 x K}, {K x N}) => {1 x N}.
+  // output: DS(dot(ctA, ctB))
+  // => output dimensions: DS ({M x N}, {start, 0}, {1, N}) => {1 x N}.
+
+  // RHS is DynamicSlice:
+  // input: dot(ctA, DS(ctB))
+  // where ctA = {M x K} and DS(ctB) = DS({K x N}, {0, start}, {K, 1}).
+  // => input dimensions: dot({M x K}, {K x 1}) => {M x 1}.
+  // output: DS(dot(ctA, ctB))
+  // => output dimensions: DS ({M x N}, {0, start}, {M, 1}) => {M x 1}.
+
+  bool lhs_is_dynamic_slice = lhs->opcode() == HloOpcode::kDynamicSlice;
+
+  // ctA:
+  HloInstruction* left_operand =
+      lhs_is_dynamic_slice ? lhs->mutable_operand(0) : lhs;
+  // ctB:
+  HloInstruction* right_operand =
+      lhs_is_dynamic_slice ? rhs : rhs->mutable_operand(0);
+  // Build ctA x ctB.
+  const int m = left_operand->shape().dimensions(1 - lhs_contracting_dimension);
+  const int n =
+      right_operand->shape().dimensions(1 - rhs_contracting_dimension);
+  auto memoized_shape = ShapeUtil::MakeShape(F32, {m, n});
+  auto* memoized_inst = computation_->AddInstruction(HloInstruction::CreateDot(
+      memoized_shape, left_operand, right_operand, dnums));
+  // Get pair {start, 0} or {0, start}.
+  HloInstruction* original_start_indices =
+      lhs_is_dynamic_slice ? lhs->mutable_operand(1) : rhs->mutable_operand(1);
+  // Position of start:
+  int index_of_non_zero_start = lhs_is_dynamic_slice
+                                    ? 1 - lhs_contracting_dimension
+                                    : 1 - rhs_contracting_dimension;
+  // Position of zero:
+  int index_of_zero_start = 1 - index_of_non_zero_start;
+
+  // Slice out start and 0 components and reorder if necessary.
+  auto indices_type = original_start_indices->shape().element_type();
+  Shape s_shape = ShapeUtil::MakeShape(indices_type, {1});
+  Shape d_shape = ShapeUtil::MakeShape(indices_type, {2});
+  HloInstruction* non_zero_start =
+      computation_->AddInstruction(HloInstruction::CreateSlice(
+          s_shape, original_start_indices, {index_of_non_zero_start},
+          {index_of_non_zero_start + 1}, {1}));
+  HloInstruction* zero_start =
+      computation_->AddInstruction(HloInstruction::CreateSlice(
+          s_shape, original_start_indices, {index_of_zero_start},
+          {index_of_zero_start + 1}, {1}));
+  HloInstruction* new_start_indices =
+      lhs_is_dynamic_slice
+          ? computation_->AddInstruction(HloInstruction::CreateConcatenate(
+                d_shape, {non_zero_start, zero_start}, 0))
+          : computation_->AddInstruction(HloInstruction::CreateConcatenate(
+                d_shape, {zero_start, non_zero_start}, 0));
+
+  // Build DynamicSlice(ctA x ctB).
+  const int new_slice_m = lhs_is_dynamic_slice ? 1 : m;
+  const int new_slice_n = lhs_is_dynamic_slice ? n : 1;
+  auto* memoized_lookup =
+      computation_->AddInstruction(HloInstruction::CreateDynamicSlice(
+          dot->shape(), memoized_inst, new_start_indices,
+          {new_slice_m, new_slice_n}));
+
+  return memoized_lookup;
+}
+
 Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
   HloInstruction *lhs, *rhs;
   CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs))));
@@ -941,6 +1071,17 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
     return ReplaceInstruction(dot, dot_of_concat_optimized);
   }
 
+  // Simplify dot(ConstA, Gather(Index, ConstB)) to:
+  // Gather(Index, dot*(ConstA, ConstB)), where dot* is an appropriately
+  // batched version of dot.
+  TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_gather_optimized,
+                      OptimizeDotOfGather(dot));
+  if (dot_of_gather_optimized) {
+    VLOG(10) << "Replaced dot(constA, gather(i, constB)) with "
+                "gather(i, dot*(constA, constB))";
+    return ReplaceInstruction(dot, dot_of_gather_optimized);
+  }
+
   if (enable_dot_strength_reduction_ && !is_layout_sensitive_) {
     TF_ASSIGN_OR_RETURN(bool did_strength_reduction,
                         HandleDotStrengthReduction(dot));
index d0c99bf..4e08287 100644 (file)
@@ -2963,5 +2963,208 @@ TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceZeroUpdate) {
 INSTANTIATE_TEST_CASE_P(DotOfConcatSimplificationTestInstantiation,
                         DotOfConcatSimplificationTest,
                         ::testing::ValuesIn(kDotOfConcatTestSpecs));
+
+struct DotOfGatherTestSpec {
+  int64 m;
+  int64 k;
+  int64 n;
+  int s;      // start index for dynamic slice on the non-contracting dimension
+  int64 lcd;  // left contracting dimension
+  int64 rcd;  // right contracting dimension
+  bool neg;   // is negative testcase
+};
+
+class DotOfGatherSimplificationTest
+    : public HloVerifiedTestBase,
+      public ::testing::WithParamInterface<DotOfGatherTestSpec> {};
+
+// input: dot(DS(ctA), ctB))
+// where DS(ctA) = DS({M x K}, {s, 0}, {1, K}) and ctB = {K x N}.
+// => input dimensions: dot({1 x K}, {K x N}) => {1 x N}.
+// output: DS(dot(ctA, ctB))
+// => output dimensions: DS ({M x N}, {s, 0}, {1, N}) => {1 x N}.
+TEST_P(DotOfGatherSimplificationTest, ConstantRHS) {
+  HloComputation::Builder builder(TestName());
+
+  DotOfGatherTestSpec spec = GetParam();
+
+  ASSERT_LE(spec.s, spec.m);
+
+  // For negative tests, increase k of the dynamic slice argument to prevent the
+  // optimization (constants ctA, ctB must have equal contracting dimensions).
+  int64 k_increase = spec.neg ? 5 : 0;
+  int64 lhs_rows = (spec.lcd == 0) ? (spec.k + k_increase) : spec.m;
+  int64 lhs_cols = (spec.lcd == 0) ? spec.m : (spec.k + k_increase);
+  Shape lhs_shape = ShapeUtil::MakeShape(F32, {lhs_rows, lhs_cols});
+  auto* lhs = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
+          /*from=*/10.0, /*to=*/10000.0, /*rows=*/lhs_rows,
+          /*cols=*/lhs_cols)));
+
+  int32 start_row = (spec.lcd == 0) ? 0 : spec.s;
+  int32 start_col = (spec.lcd == 0) ? spec.s : 0;
+  const auto start_indices =
+      builder.AddInstruction(HloInstruction::CreateConstant(
+          Literal::CreateR1<int32>({start_row, start_col})));
+  int64 slice_row_size = (spec.lcd == 0) ? spec.k : 1;
+  int64 slice_col_size = (spec.lcd == 0) ? 1 : spec.k;
+  Shape ds_shape = ShapeUtil::MakeShape(F32, {slice_row_size, slice_col_size});
+  auto* ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice(
+      ds_shape, lhs, start_indices, {slice_row_size, slice_col_size}));
+
+  int64 rhs_rows = (spec.rcd == 0) ? spec.k : spec.n;
+  int64 rhs_cols = (spec.rcd == 0) ? spec.n : spec.k;
+  Shape rhs_shape = ShapeUtil::MakeShape(F32, {rhs_rows, rhs_cols});
+  auto* rhs = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
+          /*from=*/10.0, /*to=*/10000.0, /*rows=*/rhs_rows,
+          /*cols=*/rhs_cols)));
+
+  DotDimensionNumbers dot_dnums;
+  dot_dnums.add_lhs_contracting_dimensions(spec.lcd);
+  dot_dnums.add_rhs_contracting_dimensions(spec.rcd);
+
+  int64 dot_row_size = 1;
+  int64 dot_col_size = spec.n;
+  Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size});
+  builder.AddInstruction(
+      HloInstruction::CreateDot(dot_shape, ds, rhs, dot_dnums));
+
+  auto computation = module().AddEntryComputation(builder.Build());
+  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+                                 non_bitcasting_callback());
+  TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module()));
+  ASSERT_TRUE(run_successful);
+  EXPECT_TRUE(
+      ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape));
+
+  if (spec.neg) {
+    EXPECT_NE(computation->root_instruction()->opcode(),
+              HloOpcode::kDynamicSlice);
+  } else {
+    EXPECT_THAT(computation->root_instruction(),
+                op::DynamicSlice(op::Dot(op::Constant(), op::Constant()),
+                                 op::Concatenate()));
+  }
+}
+
+// input: dot(ctA, DS(ctB))
+// where ctA = {M x K} and DS(ctB) = DS({K x N}, {0, s}, {K, 1}).
+// => input dimensions: dot({M x K}, {K x 1}) => {M x 1}.
+// output: DS(dot(ctA, ctB))
+// => output dimensions: DS ({M x N}, {0, s}, {M, 1}) => {M x 1}.
+TEST_P(DotOfGatherSimplificationTest, ConstantLHS) {
+  HloComputation::Builder builder(TestName());
+
+  DotOfGatherTestSpec spec = GetParam();
+
+  ASSERT_LE(spec.s, spec.n);
+
+  int64 lhs_rows = (spec.lcd == 0) ? spec.k : spec.m;
+  int64 lhs_cols = (spec.lcd == 0) ? spec.m : spec.k;
+  Shape lhs_shape = ShapeUtil::MakeShape(F32, {lhs_rows, lhs_cols});
+  auto* lhs = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
+          /*from=*/10.0, /*to=*/10000.0, /*rows=*/lhs_rows,
+          /*cols=*/lhs_cols)));
+
+  // For negative tests increase k of the dynamic slice argument to prevent the
+  // optimization
+  int64 k_increase = spec.neg ? 5 : 0;
+  int64 rhs_rows = (spec.rcd == 0) ? (spec.k + k_increase) : spec.n;
+  int64 rhs_cols = (spec.rcd == 0) ? spec.n : (spec.k + k_increase);
+  Shape rhs_shape = ShapeUtil::MakeShape(F32, {rhs_rows, rhs_cols});
+  auto* rhs = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
+          /*from=*/10.0, /*to=*/10000.0, /*rows=*/rhs_rows,
+          /*cols=*/rhs_cols)));
+
+  int32 start_row = (spec.rcd == 0) ? 0 : spec.s;
+  int32 start_col = (spec.rcd == 0) ? spec.s : 0;
+  const auto start_indices =
+      builder.AddInstruction(HloInstruction::CreateConstant(
+          Literal::CreateR1<int32>({start_row, start_col})));
+  int64 slice_row_size = (spec.rcd == 0) ? spec.k : 1;
+  int64 slice_col_size = (spec.rcd == 0) ? 1 : spec.k;
+  Shape ds_shape = ShapeUtil::MakeShape(F32, {slice_row_size, slice_col_size});
+  auto* ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice(
+      ds_shape, rhs, start_indices, {slice_row_size, slice_col_size}));
+
+  DotDimensionNumbers dot_dnums;
+  dot_dnums.add_lhs_contracting_dimensions(spec.lcd);
+  dot_dnums.add_rhs_contracting_dimensions(spec.rcd);
+
+  int64 dot_row_size = spec.m;
+  int64 dot_col_size = 1;
+  Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size});
+  builder.AddInstruction(
+      HloInstruction::CreateDot(dot_shape, lhs, ds, dot_dnums));
+
+  auto computation = module().AddEntryComputation(builder.Build());
+  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+                                 non_bitcasting_callback());
+  TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module()));
+  ASSERT_TRUE(run_successful);
+  EXPECT_TRUE(
+      ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape));
+
+  if (spec.neg) {
+    EXPECT_NE(computation->root_instruction()->opcode(),
+              HloOpcode::kDynamicSlice);
+  } else {
+    EXPECT_THAT(computation->root_instruction(),
+                op::DynamicSlice(op::Dot(op::Constant(), op::Constant()),
+                                 op::Concatenate()));
+  }
+}
+
+std::vector<DotOfGatherTestSpec> DotOfGatherPositiveNegativeTests() {
+  std::vector<DotOfGatherTestSpec> positives = {
+      // "Classical dot", i.e. matrix multiply:
+      {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/1, /*rcd=*/0,
+       /*neg=*/false},
+      {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/1, /*rcd=*/0,
+       /*neg=*/false},
+      {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/1, /*rcd=*/0,
+       /*neg=*/false},
+      // Note: testing for m=1 and n=1 is unnecessary, as this optimizes to
+      // dot(ct, ct) before DotOfGather optimization kicks in.
+      // Contract on rows:
+      {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/0, /*rcd=*/0,
+       /*neg=*/false},
+      {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/0, /*rcd=*/0,
+       /*neg=*/false},
+      {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/0, /*rcd=*/0,
+       /*neg=*/false},
+      // Reverse matrix multiply:
+      {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/0, /*rcd=*/1,
+       /*neg=*/false},
+      {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/0, /*rcd=*/1,
+       /*neg=*/false},
+      {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/0, /*rcd=*/1,
+       /*neg=*/false},
+      // Contract on columns:
+      {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/1, /*rcd=*/1,
+       /*neg=*/false},
+      {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/1, /*rcd=*/1,
+       /*neg=*/false},
+      {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/1, /*rcd=*/1,
+       /*neg=*/false},
+  };
+  std::vector<DotOfGatherTestSpec> all;
+  for (int i = 0; i < positives.size(); i++) {
+    DotOfGatherTestSpec positive_test = positives[i];
+    all.push_back(positive_test);
+    DotOfGatherTestSpec negative_test = positive_test;
+    negative_test.neg = true;
+    all.push_back(negative_test);
+  }
+  return all;
+}
+
+INSTANTIATE_TEST_CASE_P(
+    DotOfGatherSimplificationTestInstantiation, DotOfGatherSimplificationTest,
+    ::testing::ValuesIn(DotOfGatherPositiveNegativeTests()));
+
 }  // namespace
 }  // namespace xla
index 6b3efba..efa5aed 100644 (file)
@@ -798,5 +798,250 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64,
       this->error_spec_);
 }
 
+TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSClassicMM) {
+  std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>(
+      {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
+  std::unique_ptr<Array2D<float>> constant_rhs_array(
+      new Array2D<float>({{1.0, 2.0, 3.0},
+                          {4.0, 5.0, 6.0},
+                          {7.0, 8.0, 9.0},
+                          {9.0, 8.0, 7.0},
+                          {6.0, 5.0, 4.0},
+                          {3.0, 2.0, 1.0}}));
+  // Dot result to slice from: {{114, 105, 96}, {96, 105, 114}}
+
+  XlaBuilder builder(TestName());
+  auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array);
+  auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array);
+  auto start_constant = builder.ConstantR1<int32>({1, 0});
+  auto dynamic_slice =
+      builder.DynamicSlice(lhs_constant, start_constant, {1, 6});
+
+  DotDimensionNumbers dot_dnums;
+  dot_dnums.add_lhs_contracting_dimensions(1);
+  dot_dnums.add_rhs_contracting_dimensions(0);
+  auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums);
+
+  Array2D<float> expected({{96.0, 105.0, 114.0}});
+  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSClassicMM) {
+  std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>(
+      {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
+  std::unique_ptr<Array2D<float>> constant_rhs_array(
+      new Array2D<float>({{1.0, 2.0, 3.0},
+                          {4.0, 5.0, 6.0},
+                          {7.0, 8.0, 9.0},
+                          {9.0, 8.0, 7.0},
+                          {6.0, 5.0, 4.0},
+                          {3.0, 2.0, 1.0}}));
+  // Dot result to slice from: {{114, 105, 96}, {96, 105, 114}}
+
+  XlaBuilder builder(TestName());
+  auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array);
+  auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array);
+  auto start_constant = builder.ConstantR1<int32>({0, 1});
+  auto dynamic_slice =
+      builder.DynamicSlice(rhs_constant, start_constant, {6, 1});
+
+  DotDimensionNumbers dot_dnums;
+  dot_dnums.add_lhs_contracting_dimensions(1);
+  dot_dnums.add_rhs_contracting_dimensions(0);
+  auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums);
+
+  Array2D<float> expected({{105.0}, {105.0}});
+  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
+
+// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
+TEST_F(DotOperationTest,
+       DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
+           DotOfGatherOptimizationWithConstRHSReverseMM)))) {
+  std::unique_ptr<Array2D<float>> constant_lhs_array(
+      new Array2D<float>({{1.0, 2.0, 3.0},
+                          {4.0, 5.0, 6.0},
+                          {7.0, 8.0, 9.0},
+                          {9.0, 8.0, 7.0},
+                          {6.0, 5.0, 4.0},
+                          {3.0, 2.0, 1.0}}));
+  std::unique_ptr<Array2D<float>> constant_rhs_array(new Array2D<float>(
+      {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
+  // Dot result to slice from: {{114, 96}, {105, 105}, {96, 114}}
+
+  XlaBuilder builder(TestName());
+  auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array);
+  auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array);
+  auto start_constant = builder.ConstantR1<int32>({0, 1});
+  auto dynamic_slice =
+      builder.DynamicSlice(lhs_constant, start_constant, {6, 1});
+
+  DotDimensionNumbers dot_dnums;
+  dot_dnums.add_lhs_contracting_dimensions(0);
+  dot_dnums.add_rhs_contracting_dimensions(1);
+  auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums);
+
+  Array2D<float> expected({{105.0, 105.0}});
+  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
+
+// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
+TEST_F(DotOperationTest,
+       DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
+           DotOfGatherOptimizationWithConstLHSReverseMM)))) {
+  std::unique_ptr<Array2D<float>> constant_lhs_array(
+      new Array2D<float>({{1.0, 2.0, 3.0},
+                          {4.0, 5.0, 6.0},
+                          {7.0, 8.0, 9.0},
+                          {9.0, 8.0, 7.0},
+                          {6.0, 5.0, 4.0},
+                          {3.0, 2.0, 1.0}}));
+  std::unique_ptr<Array2D<float>> constant_rhs_array(new Array2D<float>(
+      {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
+  // Dot result to slice from: {{114, 96}, {105, 105}, {96, 114}}
+
+  XlaBuilder builder(TestName());
+  auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array);
+  auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array);
+  auto start_constant = builder.ConstantR1<int32>({1, 0});
+  auto dynamic_slice =
+      builder.DynamicSlice(rhs_constant, start_constant, {1, 6});
+
+  DotDimensionNumbers dot_dnums;
+  dot_dnums.add_lhs_contracting_dimensions(0);
+  dot_dnums.add_rhs_contracting_dimensions(1);
+  auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums);
+
+  Array2D<float> expected({{96.0}, {105.0}, {114.0}});
+  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
+
+// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
+TEST_F(DotOperationTest,
+       DISABLED_ON_CPU(DISABLED_ON_GPU(
+           DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstRHSRows)))) {
+  std::unique_ptr<Array2D<float>> constant_lhs_array(
+      new Array2D<float>({{1.0, 2.0},
+                          {3.0, 4.0},
+                          {5.0, 6.0},
+                          {6.0, 5.0},
+                          {4.0, 3.0},
+                          {2.0, 1.0}}));
+  std::unique_ptr<Array2D<float>> constant_rhs_array(
+      new Array2D<float>({{1.0, 2.0, 3.0},
+                          {4.0, 5.0, 6.0},
+                          {7.0, 8.0, 9.0},
+                          {9.0, 8.0, 7.0},
+                          {6.0, 5.0, 4.0},
+                          {3.0, 2.0, 1.0}}));
+  // Dot result to slice from: {{132, 129, 126}, {126, 129, 132}}
+
+  XlaBuilder builder(TestName());
+  auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array);
+  auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array);
+  auto start_constant = builder.ConstantR1<int32>({0, 1});
+  auto dynamic_slice =
+      builder.DynamicSlice(lhs_constant, start_constant, {6, 1});
+
+  DotDimensionNumbers dot_dnums;
+  dot_dnums.add_lhs_contracting_dimensions(0);
+  dot_dnums.add_rhs_contracting_dimensions(0);
+  auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums);
+
+  Array2D<float> expected({{126.0, 129.0, 132.0}});
+  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
+
+// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
+TEST_F(DotOperationTest,
+       DISABLED_ON_CPU(DISABLED_ON_GPU(
+           DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstLHSRows)))) {
+  std::unique_ptr<Array2D<float>> constant_lhs_array(
+      new Array2D<float>({{1.0, 2.0},
+                          {3.0, 4.0},
+                          {5.0, 6.0},
+                          {6.0, 5.0},
+                          {4.0, 3.0},
+                          {2.0, 1.0}}));
+  std::unique_ptr<Array2D<float>> constant_rhs_array(
+      new Array2D<float>({{1.0, 2.0, 3.0},
+                          {4.0, 5.0, 6.0},
+                          {7.0, 8.0, 9.0},
+                          {9.0, 8.0, 7.0},
+                          {6.0, 5.0, 4.0},
+                          {3.0, 2.0, 1.0}}));
+  // Dot result to slice from: {{132, 129, 126}, {126, 129, 132}}
+
+  XlaBuilder builder(TestName());
+  auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array);
+  auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array);
+  auto start_constant = builder.ConstantR1<int32>({0, 1});
+  auto dynamic_slice =
+      builder.DynamicSlice(rhs_constant, start_constant, {6, 1});
+
+  DotDimensionNumbers dot_dnums;
+  dot_dnums.add_lhs_contracting_dimensions(0);
+  dot_dnums.add_rhs_contracting_dimensions(0);
+  auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums);
+
+  Array2D<float> expected({{129.0}, {129.0}});
+  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
+
+// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
+TEST_F(DotOperationTest,
+       DISABLED_ON_CPU(DISABLED_ON_GPU(
+           DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstRHSCols)))) {
+  std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>(
+      {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
+  std::unique_ptr<Array2D<float>> constant_rhs_array(
+      new Array2D<float>({{1.0, 2.0, 3.0, 4.0, 5.0, 6.0},
+                          {7.0, 8.0, 9.0, 9.0, 8.0, 7.0},
+                          {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
+  // Dot result to slice from: {{91, 168, 56}, {56, 168, 91}}
+
+  XlaBuilder builder(TestName());
+  auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array);
+  auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array);
+  auto start_constant = builder.ConstantR1<int32>({1, 0});
+  auto dynamic_slice =
+      builder.DynamicSlice(lhs_constant, start_constant, {1, 6});
+
+  DotDimensionNumbers dot_dnums;
+  dot_dnums.add_lhs_contracting_dimensions(1);
+  dot_dnums.add_rhs_contracting_dimensions(1);
+  auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums);
+
+  Array2D<float> expected({{56.0, 168.0, 91.0}});
+  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
+
+// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
+TEST_F(DotOperationTest,
+       DISABLED_ON_CPU(DISABLED_ON_GPU(
+           DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstLHSCols)))) {
+  std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>(
+      {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
+  std::unique_ptr<Array2D<float>> constant_rhs_array(
+      new Array2D<float>({{1.0, 2.0, 3.0, 4.0, 5.0, 6.0},
+                          {7.0, 8.0, 9.0, 9.0, 8.0, 7.0},
+                          {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
+  // Dot result to slice from: {{91, 168, 56}, {56, 168, 91}}
+
+  XlaBuilder builder(TestName());
+  auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array);
+  auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array);
+  auto start_constant = builder.ConstantR1<int32>({1, 0});
+  auto dynamic_slice =
+      builder.DynamicSlice(rhs_constant, start_constant, {1, 6});
+
+  DotDimensionNumbers dot_dnums;
+  dot_dnums.add_lhs_contracting_dimensions(1);
+  dot_dnums.add_rhs_contracting_dimensions(1);
+  auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums);
+
+  Array2D<float> expected({{168.0}, {168.0}});
+  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
 }  // namespace
 }  // namespace xla