deps = [
":ir_emission_utils",
"//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:instruction_fusion",
+ "//tensorflow/compiler/xla/service:pattern_matcher",
],
)
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
namespace gpu {
hlo.opcode() == HloOpcode::kTranspose;
}
+// Returns true if `constant` is a scalar constant whose element type is an
+// IEEE floating-point type (F16, F32, or F64). Used below to decide whether
+// a multiply-by-alpha can be folded into a GEMM output fusion.
+bool IsIEEEFloatingPointScalarConstant(const HloInstruction* constant) {
+  // Must be a kConstant instruction producing a scalar shape.
+  if (constant->opcode() != HloOpcode::kConstant ||
+      !ShapeUtil::IsScalar(constant->shape())) {
+    return false;
+  }
+  // Restrict to IEEE float types; other element types (integers, complex,
+  // bfloat16, ...) are rejected.
+  auto type = constant->shape().element_type();
+  return type == F16 || type == F32 || type == F64;
+}
+
} // namespace
/*static*/ bool GpuInstructionFusion::IsExpensive(
HloInstruction* producer = consumer->mutable_operand(operand_index);
// Check if we can use output fusion for (A @ B) * alpha
- if (producer->opcode() == HloOpcode::kDot) {
- if (consumer->opcode() == HloOpcode::kMultiply) {
- CHECK_EQ(consumer->operand_count(), 2);
- int64 other_operand_index = 1 - operand_index;
- const HloInstruction* alpha = consumer->operand(other_operand_index);
- if (alpha->opcode() == HloOpcode::kConstant &&
- ShapeUtil::IsScalar(alpha->shape())) {
+ if (producer->opcode() == HloOpcode::kDot ||
+ (producer->opcode() == HloOpcode::kFusion &&
+ producer->fused_expression_root()->opcode() == HloOpcode::kDot)) {
+ int64 other_operand_index = 1 - operand_index;
+ const HloInstruction* alpha = consumer->operand(other_operand_index);
+ HloInstruction* op1 = nullptr;
+ HloInstruction* op2 = nullptr;
+ if (consumer->opcode() == HloOpcode::kFusion &&
+ consumer->fusion_kind() == HloInstruction::FusionKind::kLoop &&
+ consumer->operand_count() == 2 &&
+ Match(consumer->fused_expression_root(),
+ match::Op()
+ .WithOpcode(HloOpcode::kMultiply)
+ .WithOperand(0, match::Op(&op1))
+ .WithOperand(1, match::Op(&op2)))) {
+ CHECK(op1 != nullptr && op2 != nullptr);
+ // If 'consumer' is a fusion node, it should consist of a broadcast of a
+ // scalar constant fused into a multiply, but nothing more. So one operand
+ // should be a parameter, and the other should be a broadcast.
+ if (op1->opcode() != HloOpcode::kParameter) {
+ std::swap(op1, op2);
+ }
+ if (op1->opcode() != HloOpcode::kParameter ||
+ op2->opcode() != HloOpcode::kBroadcast) {
+ return false;
+ }
+ if (IsIEEEFloatingPointScalarConstant(alpha)) {
+ return true;
+ }
+ } else if (consumer->opcode() == HloOpcode::kMultiply) {
+ // Fuse if 'alpha' is a broadcast of a scalar constant.
+ if (alpha->opcode() == HloOpcode::kBroadcast &&
+ alpha->dimensions().empty() &&
+ IsIEEEFloatingPointScalarConstant(alpha->operand(0))) {
return true;
}
}
}
- // Only allow to fuse transpose into an output fusion.
+ // Only allow fusing transpose or broadcast into an output fusion that is
+ // implemented as a Gemm call.
if (consumer->opcode() == HloOpcode::kFusion &&
- consumer->fusion_kind() == HloInstruction::FusionKind::kOutput) {
- if (producer->opcode() != HloOpcode::kTranspose) {
- return false;
- }
- // Check that the transpose is the operand of a dot.
+ consumer->fusion_kind() == HloInstruction::FusionKind::kOutput &&
+ ImplementedAsGemm(*consumer)) {
auto producer_operand_index = consumer->operand_index(producer);
auto fused_parameter = consumer->fused_parameter(producer_operand_index);
const std::vector<HloInstruction*>& fused_parameter_users =
fused_parameter->users();
- return (fused_parameter_users.size() == 1 &&
- fused_parameter_users[0]->opcode() == HloOpcode::kDot);
+ if (fused_parameter_users.size() != 1) {
+ return false;
+ }
+ if (producer->opcode() == HloOpcode::kTranspose) {
+ // Check that the transpose is an operand of a dot.
+ return fused_parameter_users[0]->opcode() == HloOpcode::kDot;
+ }
+ if (producer->opcode() == HloOpcode::kBroadcast) {
+ // Check that the broadcast is a broadcast of a scalar constant into a
+ // multiply.
+ return producer->dimensions().empty() &&
+ IsIEEEFloatingPointScalarConstant(producer->operand(0)) &&
+ fused_parameter_users[0]->opcode() == HloOpcode::kMultiply;
+ }
}
- // Output fusion is not currently supported on GPUs.
+ // Other output fusions are not currently supported on GPUs.
if (producer->opcode() == HloOpcode::kFusion) {
return false;
}
if (IsReductionToVector(*consumer)) {
return HloInstruction::FusionKind::kInput;
}
- if (producer->opcode() == HloOpcode::kDot) {
+ if (producer->opcode() == HloOpcode::kDot ||
+ (producer->opcode() == HloOpcode::kFusion &&
+ producer->fused_expression_root()->opcode() == HloOpcode::kDot)) {
return HloInstruction::FusionKind::kOutput;
}
if (HloOpcode::kFusion == consumer->opcode()) {
HloComputation::Builder builder(TestName());
auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(S32, {1, 1}), "0"));
- auto dot1 = builder.AddInstruction(HloInstruction::CreateBinary(
- ShapeUtil::MakeShape(S32, {1, 1}), HloOpcode::kDot, param0, param0));
+ auto dot1 = builder.AddInstruction(HloInstruction::CreateCanonicalDot(
+ ShapeUtil::MakeShape(S32, {1, 1}), param0, param0));
auto reshape2 = builder.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(S32, {1, 1, 1}), dot1));
HloComputation::Builder builder(TestName());
auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(S32, {1, 1}), "0"));
- auto dot1 = builder.AddInstruction(HloInstruction::CreateBinary(
- ShapeUtil::MakeShape(S32, {1, 1}), HloOpcode::kDot, param0, param0));
+ auto dot1 = builder.AddInstruction(HloInstruction::CreateCanonicalDot(
+ ShapeUtil::MakeShape(S32, {1, 1}), param0, param0));
auto transpose2 = builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(S32, {1, 1}), dot1, {0, 1}));
auto module = tools::Parse(R"(
HloModule test_module
ENTRY OutputFusion {
- constant = f32[] constant(3)
+ alpha = f32[] constant(3)
+ broadcast = f32[4,4]{1,0} broadcast(alpha), dimensions={}
p0 = f32[4,3]{1,0} parameter(0)
p1 = f32[4,3]{1,0} parameter(1)
transpose = f32[3,4]{1,0} transpose(p1), dimensions={1, 0}
- dot = f32[4,4]{1,0} dot(p0, transpose)
- ROOT mul = f32[4,4] multiply(constant, dot)
+ dot = f32[4,4]{1,0} dot(p0, transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ ROOT mul = f32[4,4] multiply(dot, broadcast)
})")
.ValueOrDie();
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Fusion());
+ EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kOutput);
EXPECT_THAT(
root->fused_expression_root(),
- op::Multiply(op::Parameter(),
- op::Dot(op::Parameter(), op::Transpose(op::Parameter()))));
+ op::Multiply(op::Dot(op::Parameter(), op::Transpose(op::Parameter())),
+ op::Broadcast(op::Parameter())));
}
// Compute sum(1/p0), where p0 has type f32, twice. Check that the division is
.ValueOrDie());
}
+// The dot feeds a multiply of the dot with itself, so its user is not a plain
+// (dot, broadcast) multiply; output fusion into a GEMM must not happen.
+// Instead the multiplies and broadcast should end up in a kLoop fusion.
+TEST_F(InstructionFusionTest, DotOutputFusionImpossible) {
+  auto module = tools::Parse(R"(
+  HloModule test_module
+  ENTRY NoOutputFusion {
+    alpha = f32[] constant(3)
+    broadcast = f32[4,4]{1,0} broadcast(alpha), dimensions={}
+    p0 = f32[4,3]{1,0} parameter(0)
+    p1 = f32[3,4]{1,0} parameter(1)
+    dot = f32[4,4]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+    d = f32[4,4]{1,0} multiply(dot, dot)
+    ROOT mul = f32[4,4] multiply(d, broadcast)
+  })")
+                    .ValueOrDie();
+
+  // Fusion should still run and change the module (just not as output fusion).
+  EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true)
+                  .Run(module.get())
+                  .ValueOrDie());
+
+  // Expect a loop fusion (not kOutput) whose root is the nested multiply.
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, op::Fusion());
+  EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kLoop);
+  EXPECT_THAT(root->fused_expression_root(),
+              op::Multiply(op::Multiply(op::Parameter(), op::Parameter()),
+                           op::Broadcast(op::Parameter())));
+}
+
} // namespace gpu
} // namespace xla
!ShapeUtil::HasZeroElements(lhs_shape) &&
!ShapeUtil::HasZeroElements(rhs_shape);
}
+
+// Returns true if `dot` (which must be a kDot instruction) can be lowered to
+// a pre-canned BLAS gemm call, i.e. its operand/result shapes are acceptable
+// to gemm. Factored out so it can be applied both to standalone dots and to
+// dots that are the root (or root operand) of a fusion node.
+bool DotImplementedAsGemm(const HloInstruction& dot) {
+  CHECK_EQ(dot.opcode(), HloOpcode::kDot);
+  const Shape& lhs_shape = dot.operand(0)->shape();
+  const Shape& rhs_shape = dot.operand(1)->shape();
+
+  // If gemm can accept the operand shapes, use it rather than a custom
+  // kernel.
+  if (AreValidGemmShapes(lhs_shape, rhs_shape, dot.shape())) {
+    // The size of the reduction dimension should match. The shape inference
+    // guarantees this invariant, so the check here is for programming
+    // errors.
+    const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers();
+    CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)),
+             rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0)));
+    return true;
+  }
+  return false;
+}
} // namespace
bool ImplementedAsGemm(const HloInstruction& hlo) {
// For certain types of Dot, we can call pre-canned BLAS gemm.
if (hlo.opcode() == HloOpcode::kDot) {
- const Shape& lhs_shape = hlo.operand(0)->shape();
- const Shape& rhs_shape = hlo.operand(1)->shape();
-
- // If gemm can accept the operand shapes, use it rather than a custom
- // kernel.
- if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape())) {
- // The size of the reduction dimension should match. The shape inference
- // guarantees this invariant, so the check here is for programming
- // errors.
- const DotDimensionNumbers& dim_numbers = hlo.dot_dimension_numbers();
- CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)),
- rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0)));
- return true;
- }
+ return DotImplementedAsGemm(hlo);
}
if (hlo.opcode() == HloOpcode::kFusion) {
dot = hlo.fused_expression_root()->operand(1);
}
if (dot->opcode() == HloOpcode::kDot) {
- return ImplementedAsGemm(*dot);
+ return DotImplementedAsGemm(*dot);
}
}
/*destination_buffer=*/GetAllocationSlice(*inst), inst);
}
+namespace {
+// Extracts the value of a scalar floating-point literal (F16/F32/F64) as a
+// double, for use as the gemm `alpha` scaling factor. Any other element type
+// is a programming error and aborts via LOG(FATAL).
+// NOTE(review): assumes `literal` is scalar-shaped — callers are expected to
+// have checked this (e.g. via IsIEEEFloatingPointScalarConstant).
+double GetScalarConstantAsDouble(const Literal& literal) {
+  switch (literal.shape().element_type()) {
+    case F16:
+      // Eigen::half has no implicit conversion to double; widen explicitly.
+      return static_cast<double>(literal.Get<Eigen::half>({0}));
+    case F32:
+      return literal.Get<float>({0});
+    case F64:
+      return literal.Get<double>({0});
+    default:
+      // LOG(FATAL) does not return, so no fall-through value is needed.
+      LOG(FATAL) << "Unsupported type.";
+  }
+}
+} // namespace
+
std::unique_ptr<Thunk> IrEmitterUnnested::BuildGemmThunk(
const HloInstruction* inst) {
if (inst->opcode() == HloOpcode::kDot) {
if (dot->opcode() != HloOpcode::kDot) {
std::swap(dot, alpha);
}
+ if (alpha->opcode() == HloOpcode::kBroadcast) {
+ alpha = alpha->operand(0);
+ }
+ alpha = inst->operand(alpha->parameter_number());
+ // TODO(b/74185543): Remove the following if block once we support fusion
+ // with a non-constant as well. Then we will just always use the constant
+ // on the device.
+ if (alpha->opcode() == HloOpcode::kCopy) {
+ alpha = alpha->operand(0);
+ }
+
DCHECK(dot->opcode() == HloOpcode::kDot);
const HloInstruction* lhs_parameter = StripTranspose(*dot->operand(0));
const HloInstruction* rhs_parameter = StripTranspose(*dot->operand(1));
inst->operand(rhs_parameter->parameter_number());
return MakeUnique<GemmThunk>(
- GetAllocationSlice(*lhs), // The buffer assigned to LHS.
- GetAllocationSlice(*rhs), // The buffer assigned to RHS.
- GetAllocationSlice(*mul), // The output buffer.
- lhs->shape(), // The shape of LHS.
- rhs->shape(), // The shape of RHS.
- inst->shape(), // The shape of the output.
- alpha->literal().Get<double>({0}), // alpha.
+ GetAllocationSlice(*lhs), // The buffer assigned to LHS.
+ GetAllocationSlice(*rhs), // The buffer assigned to RHS.
+ GetAllocationSlice(*inst), // The output buffer.
+ lhs->shape(), // The shape of LHS.
+ rhs->shape(), // The shape of RHS.
+ inst->shape(), // The shape of the output.
+ GetScalarConstantAsDouble(alpha->literal()), // alpha.
inst);
}