enable_dot_strength_reduction_(enable_dot_strength_reduction),
enable_conv_simplification_(enable_conv_simplification) {}
+ // Transforms dots where at least one input is a vector or has a degenerate
+ // dimension into a multiply and reduce. This should enable more fusion than
+ // leaving the nodes as Dot operations.
+ StatusOr<bool> HandleDotStrengthReduction(HloInstruction* dot);
+
+ // Reshapes an instruction to rank 1 if it is not already rank 1.
+ HloInstruction* Flatten(HloInstruction* hlo) {
+ if (ShapeUtil::Rank(hlo->shape()) == 1) {
+ return hlo;
+ }
+ return computation_->AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(hlo->shape().element_type(),
+ {ShapeUtil::ElementsIn(hlo->shape())}),
+ hlo));
+ }
+
+ // Helper method to perform an add-reduction along a single dimension.
+ HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) {
+ HloInstruction* zero = computation_->AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0(0.0f)));
+ HloComputation* add_reduce_computation = CreateScalarBinaryComputation(
+ computation_->parent(), F32, HloOpcode::kAdd);
+ Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape());
+ return computation_->AddInstruction(HloInstruction::CreateReduce(
+ shape, hlo, zero, {dim}, add_reduce_computation));
+ }
+
// Convenience method for replacing an instruction with a bitcast.
void ReplaceWithBitcast(HloInstruction* instruction);
return Status::OK();
}
+StatusOr<bool> AlgebraicSimplifierVisitor::HandleDotStrengthReduction(
+ HloInstruction* dot) {
+ HloInstruction* lhs = dot->mutable_operand(0);
+ HloInstruction* rhs = dot->mutable_operand(1);
+ int64 lhs_collapsing_dim =
+ dot->dot_dimension_numbers().lhs_contracting_dimensions(0);
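+ // If the operand is a rank-2 transpose, fold the transpose away by using
+ // its input directly and flipping the contracting dimension.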
+ if (lhs->IsRank2Transpose()) {
+ lhs = lhs->mutable_operand(0);
+ lhs_collapsing_dim = 1 - lhs_collapsing_dim;
+ }
+ const int64 lhs_kept_dim = 1 - lhs_collapsing_dim;
+
+ int64 rhs_collapsing_dim =
+ dot->dot_dimension_numbers().rhs_contracting_dimensions(0);
+ if (rhs->IsRank2Transpose()) {
+ rhs = rhs->mutable_operand(0);
+ rhs_collapsing_dim = 1 - rhs_collapsing_dim;
+ }
+ const int64 rhs_kept_dim = 1 - rhs_collapsing_dim;
+
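+ // Reshapes `hlo` to the dot's shape when their dimensions differ, e.g. to
+ // reintroduce degenerate dimensions removed by Flatten.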
+ auto reshape_if_necessary = [&](HloInstruction* hlo) {
+ if (ShapeUtil::SameDimensions(hlo->shape(), dot->shape())) {
+ return hlo;
+ }
+ return computation_->AddInstruction(
+ HloInstruction::CreateReshape(dot->shape(), hlo));
+ };
+
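+ // Broadcasts `hlo` into `shape`, mapping its dimension 0 to dimension
+ // `dim` of the result.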
+ auto broadcast_to_dim = [&](HloInstruction* hlo, const Shape& shape,
+ int64 dim) {
+ return computation_->AddInstruction(
+ HloInstruction::CreateBroadcast(shape, hlo, {dim}));
+ };
+
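+ // Creates an element-wise multiply of two same-shaped operands.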
+ auto multiply = [&](HloInstruction* local_lhs, HloInstruction* local_rhs) {
+ return computation_->AddInstruction(HloInstruction::CreateBinary(
+ local_lhs->shape(), HloOpcode::kMultiply, local_lhs, local_rhs));
+ };
+
+ // Strength reduce dot(a[K], b[K]) =
+ // reshape(result.shape,
+ // reduce_sum(multiply(a, b), {0}))
+ if (ShapeUtil::Rank(rhs->shape()) == 1 &&
+ ShapeUtil::Rank(lhs->shape()) == 1) {
+ TF_RETURN_IF_ERROR(
+ ReplaceInstruction(dot, reshape_if_necessary(AddReduce(
+ multiply(Flatten(lhs), Flatten(rhs)), 0))));
+ return true;
+ }
+
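+ // If both operands have a single element (only degenerate dimensions), the
+ // dot is just a multiply.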
+ if (ShapeUtil::IsEffectiveScalar(rhs->shape()) &&
+ ShapeUtil::IsEffectiveScalar(lhs->shape())) {
+ TF_RETURN_IF_ERROR(ReplaceInstruction(
+ dot, reshape_if_necessary(multiply(Flatten(lhs), Flatten(rhs)))));
+ return true;
+ }
+
+ // Simplify outer product into a multiply with explicit broadcasting.
+ //
+ //  dot(a[M, 1], b[1, N]) =
+ //    multiply(broadcast(reshape(a, [M]), {0}), broadcast(reshape(b, [N]), {1}))
+ if (ShapeUtil::Rank(rhs->shape()) == 2 &&
+ rhs->shape().dimensions(rhs_collapsing_dim) == 1) {
+ TF_RETURN_IF_ERROR(ReplaceInstruction(
+ dot, multiply(broadcast_to_dim(Flatten(lhs), dot->shape(), 0),
+ broadcast_to_dim(Flatten(rhs), dot->shape(), 1))));
+ return true;
+ }
+
+ // Strength reduce dot(a[1, K], b) =
+ //  reshape(result.shape,
+ //    reduce_sum(
+ //      multiply(broadcast(reshape(a, [K]), {0}), b),
+ //      {0}))
+ if (ShapeUtil::Rank(lhs->shape()) == 1 ||
+ (ShapeUtil::Rank(lhs->shape()) == 2 &&
+ lhs->shape().dimensions(lhs_kept_dim) == 1)) {
+ if (ShapeUtil::Rank(rhs->shape()) == 1) {
+ TF_RETURN_IF_ERROR(ReplaceInstruction(
+ dot,
+ reshape_if_necessary(AddReduce(multiply(Flatten(lhs), rhs), 0))));
+ return true;
+ }
+ TF_RETURN_IF_ERROR(ReplaceInstruction(
+ dot, reshape_if_necessary(
+ AddReduce(multiply(broadcast_to_dim(Flatten(lhs), rhs->shape(),
+ rhs_collapsing_dim),
+ rhs),
+ rhs_collapsing_dim))));
+ return true;
+ }
+
+ // Strength reduce dot(a, b[K, 1]) =
+ //  reshape(result.shape,
+ //    reduce_sum(multiply(a, broadcast(reshape(b, [K]), {1})), {1}))
+ if (ShapeUtil::Rank(rhs->shape()) == 1 ||
+ (ShapeUtil::Rank(rhs->shape()) == 2 &&
+ rhs->shape().dimensions(rhs_kept_dim) == 1)) {
+ TF_RETURN_IF_ERROR(ReplaceInstruction(
+ dot, reshape_if_necessary(AddReduce(
+ multiply(lhs, broadcast_to_dim(Flatten(rhs), lhs->shape(),
+ lhs_collapsing_dim)),
+ lhs_collapsing_dim))));
+ return true;
+ }
+ return false;
+}
+
Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
auto lhs = dot->mutable_operand(0);
auto rhs = dot->mutable_operand(1);
dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {}));
}
+ if (enable_dot_strength_reduction_ && !is_layout_sensitive_) {
+ TF_ASSIGN_OR_RETURN(bool did_strength_reduction,
+ HandleDotStrengthReduction(dot));
+ if (did_strength_reduction) {
+ return Status::OK();
+ }
+ }
+
// Simplify dot(transpose(a), transpose(b)) to transpose(dot(b,a)).
if (lhs->IsRank2Transpose() && rhs->IsRank2Transpose()) {
DotDimensionNumbers dot_dimension_numbers;
dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0}));
}
- if (!enable_dot_strength_reduction_) {
- return Status::OK();
- }
-
- // Simplify outer product into multiply with implicit broadcasting.
- //
- // A dot(a[M, 1], b[1, N]) = multiply(a [M,1], b [1, N])
- if (ShapeUtil::Rank(rhs->shape()) == 2 && rhs->shape().dimensions(0) == 1) {
- return ReplaceWithNewInstruction(
- dot, HloInstruction::CreateBinary(dot->shape(), HloOpcode::kMultiply,
- lhs, rhs));
- }
-
- // The following graph transformations take Dots where at least one input is a
- // vector or has a degenerate dimension and converts it into a multiply and
- // reduce. This should enable more fusion than leaving the nodes as Dot
- // operations.
-
- // Strength reduce dot(a[K] , b[K]) =
- // reshape(result.shape,
- // reduce_sum(multiply(a, b), {0}))
- if (ShapeUtil::Rank(rhs->shape()) == 1 &&
- ShapeUtil::Rank(lhs->shape()) == 1) {
- auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary(
- rhs->shape(), HloOpcode::kMultiply, lhs, rhs));
- HloComputation* add_reduce_computation = CreateScalarBinaryComputation(
- computation_->parent(), F32, HloOpcode::kAdd);
- auto zero = computation_->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(0.0f)));
- auto reduce = computation_->AddInstruction(HloInstruction::CreateReduce(
- ShapeUtil::MakeShape(dot->shape().element_type(), {}), multiply, zero,
- {0}, add_reduce_computation));
- return ReplaceWithNewInstruction(
- dot, HloInstruction::CreateReshape(dot->shape(), reduce));
- }
-
- // Strength reduce dot(a[1, K], b) =
- // reshape(result.shape,
- // reduce_sum(
- // multiply(broadcast(reshape(a, [K]), {0}), b),
- // {0})
- // )
- // )
- if (ShapeUtil::Rank(lhs->shape()) == 1 ||
- (ShapeUtil::Rank(lhs->shape()) == 2 && lhs->shape().dimensions(0) == 1)) {
- auto new_lhs = computation_->AddInstruction(HloInstruction::CreateReshape(
- ShapeUtil::MakeShape(lhs->shape().element_type(),
- {ShapeUtil::ElementsIn(lhs->shape())}),
- lhs));
- HloComputation* add_reduce_computation = CreateScalarBinaryComputation(
- computation_->parent(), F32, HloOpcode::kAdd);
- auto zero = computation_->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(0.0f)));
- HloInstruction* reduce;
- if (ShapeUtil::Rank(rhs->shape()) == 1) {
- auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary(
- rhs->shape(), HloOpcode::kMultiply, new_lhs, rhs));
- reduce = computation_->AddInstruction(HloInstruction::CreateReduce(
- ShapeUtil::MakeShape(dot->shape().element_type(), {}), multiply, zero,
- {0}, add_reduce_computation));
- } else {
- new_lhs = computation_->AddInstruction(
- HloInstruction::CreateBroadcast(rhs->shape(), new_lhs, {0}));
- auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary(
- rhs->shape(), HloOpcode::kMultiply, new_lhs, rhs));
-
- reduce = computation_->AddInstruction(HloInstruction::CreateReduce(
- ShapeUtil::MakeShape(dot->shape().element_type(),
- {rhs->shape().dimensions(1)}),
- multiply, zero, {0}, add_reduce_computation));
- }
- return ReplaceWithNewInstruction(
- dot, HloInstruction::CreateReshape(dot->shape(), reduce));
- }
-
- // Strength reduce dot(a, b[K, 1]) =
- // reshape(result.shape,
- // reduce_sum(multiply(a, broadcast(reshape([K],b), {1})), {0})
- // )
- if (ShapeUtil::Rank(rhs->shape()) == 1 ||
- (ShapeUtil::Rank(rhs->shape()) == 2 && rhs->shape().dimensions(1) == 1)) {
- auto new_rhs = computation_->AddInstruction(HloInstruction::CreateReshape(
- ShapeUtil::MakeShape(rhs->shape().element_type(),
- {ShapeUtil::ElementsIn(rhs->shape())}),
- rhs));
- new_rhs = computation_->AddInstruction(
- HloInstruction::CreateBroadcast(lhs->shape(), new_rhs, {1}));
- auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary(
- lhs->shape(), HloOpcode::kMultiply, lhs, new_rhs));
- HloComputation* add_reduce_computation = CreateScalarBinaryComputation(
- computation_->parent(), F32, HloOpcode::kAdd);
- auto zero = computation_->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(0.0f)));
- auto reduce = computation_->AddInstruction(HloInstruction::CreateReduce(
- ShapeUtil::MakeShape(dot->shape().element_type(),
- {lhs->shape().dimensions(0)}),
- multiply, zero, {1}, add_reduce_computation));
- return ReplaceWithNewInstruction(
- dot, HloInstruction::CreateReshape(dot->shape(), reduce));
- }
return Status::OK();
}
op::DynamicSlice(op::Parameter(), op::Parameter()));
}
+class DotStrengthReductionTest
+ : public AlgebraicSimplifierTest,
+ public ::testing::WithParamInterface<
+ ::testing::tuple<int, int, int, bool, bool>> {};
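+
+ // Exercises strength reduction of dot(lhs[m, k], rhs[k, n]) across all
+ // combinations of degenerate and non-degenerate dimensions, optionally with
+ // transposed operands.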
+TEST_P(DotStrengthReductionTest, DotStrengthReduction) {
+ int m, k, n;
+ bool transpose_lhs, transpose_rhs;
+ std::tie(m, k, n, transpose_lhs, transpose_rhs) = GetParam();
+
+ Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n});
+ Shape lhs_shape = ShapeUtil::MakeShape(F32, {m, k});
+ Shape transposed_lhs_shape = ShapeUtil::MakeShape(F32, {k, m});
+ Shape rhs_shape = ShapeUtil::MakeShape(F32, {k, n});
+ Shape transposed_rhs_shape = ShapeUtil::MakeShape(F32, {n, k});
+ HloComputation::Builder builder(TestName());
+
+ auto lhs = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, transpose_lhs ? transposed_lhs_shape : lhs_shape, "lhs"));
+ if (transpose_lhs) {
+ lhs = builder.AddInstruction(
+ HloInstruction::CreateTranspose(lhs_shape, lhs, {1, 0}));
+ }
+ auto rhs = builder.AddInstruction(HloInstruction::CreateParameter(
+ 1, transpose_rhs ? transposed_rhs_shape : rhs_shape, "rhs"));
+ if (transpose_rhs) {
+ rhs = builder.AddInstruction(
+ HloInstruction::CreateTranspose(rhs_shape, rhs, {1, 0}));
+ }
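+ // Contract the k dimension: lhs dimension 1 against rhs dimension 0.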
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ builder.AddInstruction(
+ HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums));
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get()));
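+ // The dot should be strength-reduced whenever any dimension is degenerate;
+ // dot(transpose, transpose) is also rewritten, so the computation changes
+ // in that case too.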
+ const bool dot_should_be_transformed = m == 1 || k == 1 || n == 1;
+ const bool computation_should_be_modified =
+ dot_should_be_transformed || (transpose_lhs && transpose_rhs);
+ EXPECT_EQ(changed, computation_should_be_modified);
+ bool has_no_dot = true;
+ for (const auto& hlo : computation->instructions()) {
+ if (hlo->opcode() == HloOpcode::kDot) {
+ has_no_dot = false;
+ break;
+ }
+ }
+ EXPECT_EQ(has_no_dot, dot_should_be_transformed);
+}
+
+INSTANTIATE_TEST_CASE_P(
+ DotStrengthReductionTestInstantiation, DotStrengthReductionTest,
+ ::testing::Combine(::testing::Values(1, 2), ::testing::Values(1, 2),
+ ::testing::Values(1, 2), ::testing::Bool(),
+ ::testing::Bool()));
+
} // namespace
} // namespace xla