const Shape& dot_shape, HloInstruction* lhs, int64 lhs_contracting_dim,
HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped);
+ StatusOr<HloInstruction*> OptimizeDotOfGather(HloInstruction* dot);
+
// Current HloComputation instance the AlgebraicSimplifierVisitor is
// traversing.
HloComputation* computation_;
return add_result;
}
+// Rewrites a 2D, non-batch dot in which one operand is a DynamicSlice of a
+// constant and the other operand is a constant:
+//   dot(DS(ctA), ctB)  or  dot(ctA, DS(ctB))   =>   DS(dot(ctA, ctB))
+// The dot of the two full constants is added to the computation once and the
+// requested row/column is then sliced out of its result.
+//
+// Returns nullptr when `dot` does not match the pattern; otherwise returns the
+// newly added DynamicSlice instruction. `dot` itself is not modified or
+// replaced here.
+StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfGather(
+    HloInstruction* dot) {
+  const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
+  // The index arithmetic below (`1 - contracting_dimension`) is only valid for
+  // rank-2 operands, so both operands are checked in addition to the output.
+  if (dnums.lhs_contracting_dimensions_size() != 1 ||
+      dnums.rhs_contracting_dimensions_size() != 1 ||
+      dnums.lhs_batch_dimensions_size() != 0 ||
+      dnums.rhs_batch_dimensions_size() != 0 ||
+      dot->shape().dimensions_size() != 2 ||  // dot output 2D
+      dot->operand(0)->shape().dimensions_size() != 2 ||
+      dot->operand(1)->shape().dimensions_size() != 2) {
+    VLOG(10) << "DotOfGather: Can only optimize 2D, non-batch dot operations.";
+    return nullptr;
+  }
+
+  // Optimize either dot(DS(ctA), ctB) or dot(ctB, DS(ctA)).
+  // Currently a Gather is a DynamicSlice.
+  auto is_dynamic_slice_constant_combination =
+      [](HloInstruction* a, HloInstruction* b, int64 a_contracting_dimension) {
+        // First operand is a DynamicSlice(Constant).
+        if (a->opcode() != HloOpcode::kDynamicSlice) {
+          return false;
+        }
+        auto* dynamic_slice_op = a->operand(0);
+        if (dynamic_slice_op->opcode() != HloOpcode::kConstant) {
+          return false;
+        }
+        // Second operand is a Constant.
+        if (b->opcode() != HloOpcode::kConstant) {
+          return false;
+        }
+        // The DynamicSlice output is a vector.
+        const Shape& dynamic_slice_shape = a->shape();
+        if (dynamic_slice_shape.dimensions(1 - a_contracting_dimension) != 1) {
+          return false;
+        }
+        // Constant size is the same before and after slice in the contracting
+        // dimension, otherwise we either must precompute for all possible
+        // slice indices or dot is invalid.
+        const Shape& dynamic_slice_op_shape = dynamic_slice_op->shape();
+        if (dynamic_slice_op_shape.dimensions(a_contracting_dimension) !=
+            dynamic_slice_shape.dimensions(a_contracting_dimension)) {
+          return false;
+        }
+        return true;
+      };
+
+  HloInstruction* lhs = dot->mutable_operand(0);
+  HloInstruction* rhs = dot->mutable_operand(1);
+  const int64 lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0);
+  const int64 rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0);
+
+  // Remember which side matched instead of re-deriving it from the opcode.
+  const bool lhs_is_dynamic_slice = is_dynamic_slice_constant_combination(
+      lhs, rhs, /*a_contracting_dimension=*/lhs_contracting_dimension);
+  if (!lhs_is_dynamic_slice &&
+      !is_dynamic_slice_constant_combination(
+          rhs, lhs, /*a_contracting_dimension=*/rhs_contracting_dimension)) {
+    VLOG(10) << "DotOfGather: Can only optimize dot(DS(ctA), ctB)) or "
+                "dot(ctB, DS(ctA)), where the two constants have equal "
+                "contracting dimensions.";
+    return nullptr;
+  }
+
+  // LHS is DynamicSlice:
+  // input: dot(DS(ctA), ctB)
+  //   where DS(ctA) = DS({M x K}, {start, 0}, {1, K}) and ctB = {K x N}.
+  //   => input dimensions: dot({1 x K}, {K x N}) => {1 x N}.
+  // output: DS(dot(ctA, ctB))
+  //   => output dimensions: DS ({M x N}, {start, 0}, {1, N}) => {1 x N}.
+
+  // RHS is DynamicSlice:
+  // input: dot(ctA, DS(ctB))
+  //   where ctA = {M x K} and DS(ctB) = DS({K x N}, {0, start}, {K, 1}).
+  //   => input dimensions: dot({M x K}, {K x 1}) => {M x 1}.
+  // output: DS(dot(ctA, ctB))
+  //   => output dimensions: DS ({M x N}, {0, start}, {M, 1}) => {M x 1}.
+
+  // ctA:
+  HloInstruction* left_operand =
+      lhs_is_dynamic_slice ? lhs->mutable_operand(0) : lhs;
+  // ctB:
+  HloInstruction* right_operand =
+      lhs_is_dynamic_slice ? rhs : rhs->mutable_operand(0);
+
+  // Build ctA x ctB. The result takes the element type of the original dot
+  // rather than assuming F32.
+  const int64 m =
+      left_operand->shape().dimensions(1 - lhs_contracting_dimension);
+  const int64 n =
+      right_operand->shape().dimensions(1 - rhs_contracting_dimension);
+  const Shape memoized_shape =
+      ShapeUtil::MakeShape(dot->shape().element_type(), {m, n});
+  auto* memoized_inst = computation_->AddInstruction(HloInstruction::CreateDot(
+      memoized_shape, left_operand, right_operand, dnums));
+
+  // Get pair {start, 0} or {0, start}.
+  HloInstruction* original_start_indices =
+      lhs_is_dynamic_slice ? lhs->mutable_operand(1) : rhs->mutable_operand(1);
+  // Position of start:
+  const int64 index_of_non_zero_start = lhs_is_dynamic_slice
+                                            ? 1 - lhs_contracting_dimension
+                                            : 1 - rhs_contracting_dimension;
+  // Position of zero:
+  const int64 index_of_zero_start = 1 - index_of_non_zero_start;
+
+  // Slice out start and 0 components and reorder if necessary.
+  const auto indices_type = original_start_indices->shape().element_type();
+  const Shape s_shape = ShapeUtil::MakeShape(indices_type, {1});
+  const Shape d_shape = ShapeUtil::MakeShape(indices_type, {2});
+  HloInstruction* non_zero_start =
+      computation_->AddInstruction(HloInstruction::CreateSlice(
+          s_shape, original_start_indices, {index_of_non_zero_start},
+          {index_of_non_zero_start + 1}, {1}));
+  HloInstruction* zero_start =
+      computation_->AddInstruction(HloInstruction::CreateSlice(
+          s_shape, original_start_indices, {index_of_zero_start},
+          {index_of_zero_start + 1}, {1}));
+  HloInstruction* new_start_indices =
+      lhs_is_dynamic_slice
+          ? computation_->AddInstruction(HloInstruction::CreateConcatenate(
+                d_shape, {non_zero_start, zero_start}, 0))
+          : computation_->AddInstruction(HloInstruction::CreateConcatenate(
+                d_shape, {zero_start, non_zero_start}, 0));
+
+  // Build DynamicSlice(ctA x ctB).
+  const int64 new_slice_m = lhs_is_dynamic_slice ? 1 : m;
+  const int64 new_slice_n = lhs_is_dynamic_slice ? n : 1;
+  auto* memoized_lookup =
+      computation_->AddInstruction(HloInstruction::CreateDynamicSlice(
+          dot->shape(), memoized_inst, new_start_indices,
+          {new_slice_m, new_slice_n}));
+
+  return memoized_lookup;
+}
+
Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
HloInstruction *lhs, *rhs;
CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs))));
return ReplaceInstruction(dot, dot_of_concat_optimized);
}
+ // Simplify dot(ConstA, Gather(Index, ConstB)) to:
+ // Gather(Index, dot*(ConstA, ConstB)), where dot* is an appropriately
+ // batched version of dot.
+ TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_gather_optimized,
+ OptimizeDotOfGather(dot));
+ if (dot_of_gather_optimized) {
+ VLOG(10) << "Replaced dot(constA, gather(i, constB)) with "
+ "gather(i, dot*(constA, constB))";
+ return ReplaceInstruction(dot, dot_of_gather_optimized);
+ }
+
if (enable_dot_strength_reduction_ && !is_layout_sensitive_) {
TF_ASSIGN_OR_RETURN(bool did_strength_reduction,
HandleDotStrengthReduction(dot));
INSTANTIATE_TEST_CASE_P(DotOfConcatSimplificationTestInstantiation,
DotOfConcatSimplificationTest,
::testing::ValuesIn(kDotOfConcatTestSpecs));
+
+// One parameter set for DotOfGatherSimplificationTest. Fields are listed in
+// the order used by the aggregate initializers in the test table.
+struct DotOfGatherTestSpec {
+  // Non-contracting dimension size of the left constant.
+  int64 m;
+  // Contracting dimension size shared by both constants.
+  int64 k;
+  // Non-contracting dimension size of the right constant.
+  int64 n;
+  // Start index for the dynamic slice on the non-contracting dimension.
+  int s;
+  // Left contracting dimension.
+  int64 lcd;
+  // Right contracting dimension.
+  int64 rcd;
+  // True for a negative test case (the rewrite must not fire).
+  bool neg;
+};
+
+// Value-parameterized fixture: each test body runs once per
+// DotOfGatherTestSpec supplied by the instantiation.
+class DotOfGatherSimplificationTest
+    : public HloVerifiedTestBase,
+      public ::testing::WithParamInterface<DotOfGatherTestSpec> {
+};
+
+// The DynamicSlice is the LHS of the dot and the RHS is a plain constant.
+// input: dot(DS(ctA), ctB)
+//   where DS(ctA) = DS({M x K}, {s, 0}, {1, K}) and ctB = {K x N}.
+//   => input dimensions: dot({1 x K}, {K x N}) => {1 x N}.
+// output: DS(dot(ctA, ctB))
+//   => output dimensions: DS ({M x N}, {s, 0}, {1, N}) => {1 x N}.
+TEST_P(DotOfGatherSimplificationTest, ConstantRHS) {
+  HloComputation::Builder builder(TestName());
+
+  const DotOfGatherTestSpec spec = GetParam();
+
+  // The slice start must address an existing index of the non-contracting
+  // dimension, so it has to be strictly smaller than m.
+  ASSERT_LT(spec.s, spec.m);
+
+  // For negative tests, increase k of the dynamic slice argument to prevent the
+  // optimization (constants ctA, ctB must have equal contracting dimensions).
+  const int64 k_increase = spec.neg ? 5 : 0;
+  const int64 lhs_rows = (spec.lcd == 0) ? (spec.k + k_increase) : spec.m;
+  const int64 lhs_cols = (spec.lcd == 0) ? spec.m : (spec.k + k_increase);
+  auto* lhs = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
+          /*from=*/10.0, /*to=*/10000.0, /*rows=*/lhs_rows,
+          /*cols=*/lhs_cols)));
+
+  // The non-zero start goes on the non-contracting dimension of the LHS.
+  const int32 start_row = (spec.lcd == 0) ? 0 : spec.s;
+  const int32 start_col = (spec.lcd == 0) ? spec.s : 0;
+  auto* start_indices =
+      builder.AddInstruction(HloInstruction::CreateConstant(
+          Literal::CreateR1<int32>({start_row, start_col})));
+  const int64 slice_row_size = (spec.lcd == 0) ? spec.k : 1;
+  const int64 slice_col_size = (spec.lcd == 0) ? 1 : spec.k;
+  const Shape ds_shape =
+      ShapeUtil::MakeShape(F32, {slice_row_size, slice_col_size});
+  auto* ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice(
+      ds_shape, lhs, start_indices, {slice_row_size, slice_col_size}));
+
+  const int64 rhs_rows = (spec.rcd == 0) ? spec.k : spec.n;
+  const int64 rhs_cols = (spec.rcd == 0) ? spec.n : spec.k;
+  auto* rhs = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
+          /*from=*/10.0, /*to=*/10000.0, /*rows=*/rhs_rows,
+          /*cols=*/rhs_cols)));
+
+  DotDimensionNumbers dot_dnums;
+  dot_dnums.add_lhs_contracting_dimensions(spec.lcd);
+  dot_dnums.add_rhs_contracting_dimensions(spec.rcd);
+
+  const int64 dot_row_size = 1;
+  const int64 dot_col_size = spec.n;
+  const Shape dot_shape =
+      ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size});
+  builder.AddInstruction(
+      HloInstruction::CreateDot(dot_shape, ds, rhs, dot_dnums));
+
+  auto computation = module().AddEntryComputation(builder.Build());
+  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+                                 non_bitcasting_callback());
+  TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module()));
+  ASSERT_TRUE(run_successful);
+  EXPECT_TRUE(
+      ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape));
+
+  if (spec.neg) {
+    // The rewrite must not have produced a DynamicSlice root.
+    EXPECT_NE(computation->root_instruction()->opcode(),
+              HloOpcode::kDynamicSlice);
+  } else {
+    EXPECT_THAT(computation->root_instruction(),
+                op::DynamicSlice(op::Dot(op::Constant(), op::Constant()),
+                                 op::Concatenate()));
+  }
+}
+
+// The LHS of the dot is a plain constant and the DynamicSlice is the RHS.
+// input: dot(ctA, DS(ctB))
+//   where ctA = {M x K} and DS(ctB) = DS({K x N}, {0, s}, {K, 1}).
+//   => input dimensions: dot({M x K}, {K x 1}) => {M x 1}.
+// output: DS(dot(ctA, ctB))
+//   => output dimensions: DS ({M x N}, {0, s}, {M, 1}) => {M x 1}.
+TEST_P(DotOfGatherSimplificationTest, ConstantLHS) {
+  HloComputation::Builder builder(TestName());
+
+  const DotOfGatherTestSpec spec = GetParam();
+
+  // The slice start must address an existing index of the non-contracting
+  // dimension, so it has to be strictly smaller than n.
+  ASSERT_LT(spec.s, spec.n);
+
+  const int64 lhs_rows = (spec.lcd == 0) ? spec.k : spec.m;
+  const int64 lhs_cols = (spec.lcd == 0) ? spec.m : spec.k;
+  auto* lhs = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
+          /*from=*/10.0, /*to=*/10000.0, /*rows=*/lhs_rows,
+          /*cols=*/lhs_cols)));
+
+  // For negative tests, increase k of the dynamic slice argument to prevent the
+  // optimization (constants ctA, ctB must have equal contracting dimensions).
+  const int64 k_increase = spec.neg ? 5 : 0;
+  const int64 rhs_rows = (spec.rcd == 0) ? (spec.k + k_increase) : spec.n;
+  const int64 rhs_cols = (spec.rcd == 0) ? spec.n : (spec.k + k_increase);
+  auto* rhs = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
+          /*from=*/10.0, /*to=*/10000.0, /*rows=*/rhs_rows,
+          /*cols=*/rhs_cols)));
+
+  // The non-zero start goes on the non-contracting dimension of the RHS.
+  const int32 start_row = (spec.rcd == 0) ? 0 : spec.s;
+  const int32 start_col = (spec.rcd == 0) ? spec.s : 0;
+  auto* start_indices =
+      builder.AddInstruction(HloInstruction::CreateConstant(
+          Literal::CreateR1<int32>({start_row, start_col})));
+  const int64 slice_row_size = (spec.rcd == 0) ? spec.k : 1;
+  const int64 slice_col_size = (spec.rcd == 0) ? 1 : spec.k;
+  const Shape ds_shape =
+      ShapeUtil::MakeShape(F32, {slice_row_size, slice_col_size});
+  auto* ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice(
+      ds_shape, rhs, start_indices, {slice_row_size, slice_col_size}));
+
+  DotDimensionNumbers dot_dnums;
+  dot_dnums.add_lhs_contracting_dimensions(spec.lcd);
+  dot_dnums.add_rhs_contracting_dimensions(spec.rcd);
+
+  const int64 dot_row_size = spec.m;
+  const int64 dot_col_size = 1;
+  const Shape dot_shape =
+      ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size});
+  builder.AddInstruction(
+      HloInstruction::CreateDot(dot_shape, lhs, ds, dot_dnums));
+
+  auto computation = module().AddEntryComputation(builder.Build());
+  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+                                 non_bitcasting_callback());
+  TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module()));
+  ASSERT_TRUE(run_successful);
+  EXPECT_TRUE(
+      ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape));
+
+  if (spec.neg) {
+    // The rewrite must not have produced a DynamicSlice root.
+    EXPECT_NE(computation->root_instruction()->opcode(),
+              HloOpcode::kDynamicSlice);
+  } else {
+    EXPECT_THAT(computation->root_instruction(),
+                op::DynamicSlice(op::Dot(op::Constant(), op::Constant()),
+                                 op::Concatenate()));
+  }
+}
+
+// Builds the parameter list for DotOfGatherSimplificationTest: every positive
+// spec in the table below is followed immediately by a copy with `neg` set, so
+// each configuration is exercised both as a positive and as a negative case.
+std::vector<DotOfGatherTestSpec> DotOfGatherPositiveNegativeTests() {
+  const std::vector<DotOfGatherTestSpec> positives = {
+      // "Classical dot", i.e. matrix multiply:
+      {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/1, /*rcd=*/0,
+       /*neg=*/false},
+      {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/1, /*rcd=*/0,
+       /*neg=*/false},
+      {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/1, /*rcd=*/0,
+       /*neg=*/false},
+      // Note: testing for m=1 and n=1 is unnecessary, as this optimizes to
+      // dot(ct, ct) before DotOfGather optimization kicks in.
+      // Contract on rows:
+      {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/0, /*rcd=*/0,
+       /*neg=*/false},
+      {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/0, /*rcd=*/0,
+       /*neg=*/false},
+      {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/0, /*rcd=*/0,
+       /*neg=*/false},
+      // Reverse matrix multiply:
+      {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/0, /*rcd=*/1,
+       /*neg=*/false},
+      {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/0, /*rcd=*/1,
+       /*neg=*/false},
+      {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/0, /*rcd=*/1,
+       /*neg=*/false},
+      // Contract on columns:
+      {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/1, /*rcd=*/1,
+       /*neg=*/false},
+      {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/1, /*rcd=*/1,
+       /*neg=*/false},
+      {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/1, /*rcd=*/1,
+       /*neg=*/false},
+  };
+  std::vector<DotOfGatherTestSpec> all;
+  all.reserve(2 * positives.size());
+  // Range-for avoids the signed/unsigned comparison of an int index against
+  // size().
+  for (const DotOfGatherTestSpec& positive_test : positives) {
+    all.push_back(positive_test);
+    DotOfGatherTestSpec negative_test = positive_test;
+    negative_test.neg = true;
+    all.push_back(negative_test);
+  }
+  return all;
+}
+
+INSTANTIATE_TEST_CASE_P(
+    DotOfGatherSimplificationTestInstantiation, DotOfGatherSimplificationTest,
+    ::testing::ValuesIn(DotOfGatherPositiveNegativeTests()));
+
} // namespace
} // namespace xla
this->error_spec_);
}
+// dot(DS(ctA), ctB) contracting LHS dim 1 with RHS dim 0 (classic matrix
+// multiply): slicing row 1 of the 2x6 LHS must yield row 1 of ctA x ctB.
+TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSClassicMM) {
+  // The arrays are only read through a const reference, so they live on the
+  // stack instead of behind a unique_ptr.
+  const Array2D<float> constant_lhs_array(
+      {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}});
+  const Array2D<float> constant_rhs_array({{1.0, 2.0, 3.0},
+                                           {4.0, 5.0, 6.0},
+                                           {7.0, 8.0, 9.0},
+                                           {9.0, 8.0, 7.0},
+                                           {6.0, 5.0, 4.0},
+                                           {3.0, 2.0, 1.0}});
+  // Dot result to slice from: {{114, 105, 96}, {96, 105, 114}}
+
+  XlaBuilder builder(TestName());
+  auto lhs_constant = builder.ConstantR2FromArray2D(constant_lhs_array);
+  auto rhs_constant = builder.ConstantR2FromArray2D(constant_rhs_array);
+  auto start_constant = builder.ConstantR1<int32>({1, 0});
+  auto dynamic_slice =
+      builder.DynamicSlice(lhs_constant, start_constant, {1, 6});
+
+  DotDimensionNumbers dot_dnums;
+  dot_dnums.add_lhs_contracting_dimensions(1);
+  dot_dnums.add_rhs_contracting_dimensions(0);
+  // The returned handle is not needed; the dot is the last op added.
+  builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums);
+
+  Array2D<float> expected({{96.0, 105.0, 114.0}});
+  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
+
+// dot(ctA, DS(ctB)) contracting LHS dim 1 with RHS dim 0 (classic matrix
+// multiply): slicing column 1 of the 6x3 RHS must yield column 1 of ctA x ctB.
+TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSClassicMM) {
+  // The arrays are only read through a const reference, so they live on the
+  // stack instead of behind a unique_ptr.
+  const Array2D<float> constant_lhs_array(
+      {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}});
+  const Array2D<float> constant_rhs_array({{1.0, 2.0, 3.0},
+                                           {4.0, 5.0, 6.0},
+                                           {7.0, 8.0, 9.0},
+                                           {9.0, 8.0, 7.0},
+                                           {6.0, 5.0, 4.0},
+                                           {3.0, 2.0, 1.0}});
+  // Dot result to slice from: {{114, 105, 96}, {96, 105, 114}}
+
+  XlaBuilder builder(TestName());
+  auto lhs_constant = builder.ConstantR2FromArray2D(constant_lhs_array);
+  auto rhs_constant = builder.ConstantR2FromArray2D(constant_rhs_array);
+  auto start_constant = builder.ConstantR1<int32>({0, 1});
+  auto dynamic_slice =
+      builder.DynamicSlice(rhs_constant, start_constant, {6, 1});
+
+  DotDimensionNumbers dot_dnums;
+  dot_dnums.add_lhs_contracting_dimensions(1);
+  dot_dnums.add_rhs_contracting_dimensions(0);
+  // The returned handle is not needed; the dot is the last op added.
+  builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums);
+
+  Array2D<float> expected({{105.0}, {105.0}});
+  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
+
+// dot(DS(ctA), ctB) contracting LHS dim 0 with RHS dim 1 ("reverse" matrix
+// multiply): slicing column 1 of the 6x3 LHS must yield row 1 of the 3x2
+// result.
+// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
+TEST_F(DotOperationTest,
+       DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
+           DotOfGatherOptimizationWithConstRHSReverseMM)))) {
+  // The arrays are only read through a const reference, so they live on the
+  // stack instead of behind a unique_ptr.
+  const Array2D<float> constant_lhs_array({{1.0, 2.0, 3.0},
+                                           {4.0, 5.0, 6.0},
+                                           {7.0, 8.0, 9.0},
+                                           {9.0, 8.0, 7.0},
+                                           {6.0, 5.0, 4.0},
+                                           {3.0, 2.0, 1.0}});
+  const Array2D<float> constant_rhs_array(
+      {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}});
+  // Dot result to slice from: {{114, 96}, {105, 105}, {96, 114}}
+
+  XlaBuilder builder(TestName());
+  auto lhs_constant = builder.ConstantR2FromArray2D(constant_lhs_array);
+  auto rhs_constant = builder.ConstantR2FromArray2D(constant_rhs_array);
+  auto start_constant = builder.ConstantR1<int32>({0, 1});
+  auto dynamic_slice =
+      builder.DynamicSlice(lhs_constant, start_constant, {6, 1});
+
+  DotDimensionNumbers dot_dnums;
+  dot_dnums.add_lhs_contracting_dimensions(0);
+  dot_dnums.add_rhs_contracting_dimensions(1);
+  // The returned handle is not needed; the dot is the last op added.
+  builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums);
+
+  Array2D<float> expected({{105.0, 105.0}});
+  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
+
+// dot(ctA, DS(ctB)) contracting LHS dim 0 with RHS dim 1 ("reverse" matrix
+// multiply): slicing row 1 of the 2x6 RHS must yield column 1 of the 3x2
+// result.
+// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
+TEST_F(DotOperationTest,
+       DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
+           DotOfGatherOptimizationWithConstLHSReverseMM)))) {
+  // The arrays are only read through a const reference, so they live on the
+  // stack instead of behind a unique_ptr.
+  const Array2D<float> constant_lhs_array({{1.0, 2.0, 3.0},
+                                           {4.0, 5.0, 6.0},
+                                           {7.0, 8.0, 9.0},
+                                           {9.0, 8.0, 7.0},
+                                           {6.0, 5.0, 4.0},
+                                           {3.0, 2.0, 1.0}});
+  const Array2D<float> constant_rhs_array(
+      {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}});
+  // Dot result to slice from: {{114, 96}, {105, 105}, {96, 114}}
+
+  XlaBuilder builder(TestName());
+  auto lhs_constant = builder.ConstantR2FromArray2D(constant_lhs_array);
+  auto rhs_constant = builder.ConstantR2FromArray2D(constant_rhs_array);
+  auto start_constant = builder.ConstantR1<int32>({1, 0});
+  auto dynamic_slice =
+      builder.DynamicSlice(rhs_constant, start_constant, {1, 6});
+
+  DotDimensionNumbers dot_dnums;
+  dot_dnums.add_lhs_contracting_dimensions(0);
+  dot_dnums.add_rhs_contracting_dimensions(1);
+  // The returned handle is not needed; the dot is the last op added.
+  builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums);
+
+  Array2D<float> expected({{96.0}, {105.0}, {114.0}});
+  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
+
+// dot(DS(ctA), ctB) contracting dim 0 (rows) of both operands: slicing column
+// 1 of the 6x2 LHS must yield row 1 of the 2x3 result.
+// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
+TEST_F(DotOperationTest,
+       DISABLED_ON_CPU(DISABLED_ON_GPU(
+           DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstRHSRows)))) {
+  // The arrays are only read through a const reference, so they live on the
+  // stack instead of behind a unique_ptr.
+  const Array2D<float> constant_lhs_array({{1.0, 2.0},
+                                           {3.0, 4.0},
+                                           {5.0, 6.0},
+                                           {6.0, 5.0},
+                                           {4.0, 3.0},
+                                           {2.0, 1.0}});
+  const Array2D<float> constant_rhs_array({{1.0, 2.0, 3.0},
+                                           {4.0, 5.0, 6.0},
+                                           {7.0, 8.0, 9.0},
+                                           {9.0, 8.0, 7.0},
+                                           {6.0, 5.0, 4.0},
+                                           {3.0, 2.0, 1.0}});
+  // Dot result to slice from: {{132, 129, 126}, {126, 129, 132}}
+
+  XlaBuilder builder(TestName());
+  auto lhs_constant = builder.ConstantR2FromArray2D(constant_lhs_array);
+  auto rhs_constant = builder.ConstantR2FromArray2D(constant_rhs_array);
+  auto start_constant = builder.ConstantR1<int32>({0, 1});
+  auto dynamic_slice =
+      builder.DynamicSlice(lhs_constant, start_constant, {6, 1});
+
+  DotDimensionNumbers dot_dnums;
+  dot_dnums.add_lhs_contracting_dimensions(0);
+  dot_dnums.add_rhs_contracting_dimensions(0);
+  // The returned handle is not needed; the dot is the last op added.
+  builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums);
+
+  Array2D<float> expected({{126.0, 129.0, 132.0}});
+  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
+
+// dot(ctA, DS(ctB)) contracting dim 0 (rows) of both operands: slicing column
+// 1 of the 6x3 RHS must yield column 1 of the 2x3 result.
+// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
+TEST_F(DotOperationTest,
+       DISABLED_ON_CPU(DISABLED_ON_GPU(
+           DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstLHSRows)))) {
+  // The arrays are only read through a const reference, so they live on the
+  // stack instead of behind a unique_ptr.
+  const Array2D<float> constant_lhs_array({{1.0, 2.0},
+                                           {3.0, 4.0},
+                                           {5.0, 6.0},
+                                           {6.0, 5.0},
+                                           {4.0, 3.0},
+                                           {2.0, 1.0}});
+  const Array2D<float> constant_rhs_array({{1.0, 2.0, 3.0},
+                                           {4.0, 5.0, 6.0},
+                                           {7.0, 8.0, 9.0},
+                                           {9.0, 8.0, 7.0},
+                                           {6.0, 5.0, 4.0},
+                                           {3.0, 2.0, 1.0}});
+  // Dot result to slice from: {{132, 129, 126}, {126, 129, 132}}
+
+  XlaBuilder builder(TestName());
+  auto lhs_constant = builder.ConstantR2FromArray2D(constant_lhs_array);
+  auto rhs_constant = builder.ConstantR2FromArray2D(constant_rhs_array);
+  auto start_constant = builder.ConstantR1<int32>({0, 1});
+  auto dynamic_slice =
+      builder.DynamicSlice(rhs_constant, start_constant, {6, 1});
+
+  DotDimensionNumbers dot_dnums;
+  dot_dnums.add_lhs_contracting_dimensions(0);
+  dot_dnums.add_rhs_contracting_dimensions(0);
+  // The returned handle is not needed; the dot is the last op added.
+  builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums);
+
+  Array2D<float> expected({{129.0}, {129.0}});
+  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
+
+// dot(DS(ctA), ctB) contracting dim 1 (columns) of both operands: slicing row
+// 1 of the 2x6 LHS must yield row 1 of the 2x3 result.
+// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
+TEST_F(DotOperationTest,
+       DISABLED_ON_CPU(DISABLED_ON_GPU(
+           DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstRHSCols)))) {
+  // The arrays are only read through a const reference, so they live on the
+  // stack instead of behind a unique_ptr.
+  const Array2D<float> constant_lhs_array(
+      {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}});
+  const Array2D<float> constant_rhs_array({{1.0, 2.0, 3.0, 4.0, 5.0, 6.0},
+                                           {7.0, 8.0, 9.0, 9.0, 8.0, 7.0},
+                                           {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}});
+  // Dot result to slice from: {{91, 168, 56}, {56, 168, 91}}
+
+  XlaBuilder builder(TestName());
+  auto lhs_constant = builder.ConstantR2FromArray2D(constant_lhs_array);
+  auto rhs_constant = builder.ConstantR2FromArray2D(constant_rhs_array);
+  auto start_constant = builder.ConstantR1<int32>({1, 0});
+  auto dynamic_slice =
+      builder.DynamicSlice(lhs_constant, start_constant, {1, 6});
+
+  DotDimensionNumbers dot_dnums;
+  dot_dnums.add_lhs_contracting_dimensions(1);
+  dot_dnums.add_rhs_contracting_dimensions(1);
+  // The returned handle is not needed; the dot is the last op added.
+  builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums);
+
+  Array2D<float> expected({{56.0, 168.0, 91.0}});
+  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
+
+// dot(ctA, DS(ctB)) contracting dim 1 (columns) of both operands: slicing row
+// 1 of the 3x6 RHS must yield column 1 of the 2x3 result.
+// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
+TEST_F(DotOperationTest,
+       DISABLED_ON_CPU(DISABLED_ON_GPU(
+           DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstLHSCols)))) {
+  // The arrays are only read through a const reference, so they live on the
+  // stack instead of behind a unique_ptr.
+  const Array2D<float> constant_lhs_array(
+      {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}});
+  const Array2D<float> constant_rhs_array({{1.0, 2.0, 3.0, 4.0, 5.0, 6.0},
+                                           {7.0, 8.0, 9.0, 9.0, 8.0, 7.0},
+                                           {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}});
+  // Dot result to slice from: {{91, 168, 56}, {56, 168, 91}}
+
+  XlaBuilder builder(TestName());
+  auto lhs_constant = builder.ConstantR2FromArray2D(constant_lhs_array);
+  auto rhs_constant = builder.ConstantR2FromArray2D(constant_rhs_array);
+  auto start_constant = builder.ConstantR1<int32>({1, 0});
+  auto dynamic_slice =
+      builder.DynamicSlice(rhs_constant, start_constant, {1, 6});
+
+  DotDimensionNumbers dot_dnums;
+  dot_dnums.add_lhs_contracting_dimensions(1);
+  dot_dnums.add_rhs_contracting_dimensions(1);
+  // The returned handle is not needed; the dot is the last op added.
+  builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums);
+
+  Array2D<float> expected({{168.0}, {168.0}});
+  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
} // namespace
} // namespace xla