From 6304da208116ed00ad4ee776787dfa6fe8256f4f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 7 May 2018 03:00:34 -0700 Subject: [PATCH] Improve fusion logic of (a dot b) * alpha The previous approach didn't work because a multiplication by a scalar value will be changed into an explicit broadcast. Another issue that is fixed in this CL is retrieving the constant value from the literal. This depends on the PrimitiveType, before we always assumed it to be double. Also when checking ImplementedAsGemm() we should not call it recursively, but instead just the check related to kDot. Finally add an execution test and adjust the fusion logic test. PiperOrigin-RevId: 195638795 --- tensorflow/compiler/xla/service/gpu/BUILD | 2 + .../compiler/xla/service/gpu/instruction_fusion.cc | 84 +++++++++++++++++----- .../xla/service/gpu/instruction_fusion_test.cc | 46 +++++++++--- .../compiler/xla/service/gpu/ir_emission_utils.cc | 36 ++++++---- .../xla/service/gpu/ir_emitter_unnested.cc | 40 +++++++++-- 5 files changed, 160 insertions(+), 48 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 7cb7f55..7ee039b 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -388,8 +388,10 @@ cc_library( deps = [ ":ir_emission_utils", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:instruction_fusion", + "//tensorflow/compiler/xla/service:pattern_matcher", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index c5eb721..04c7cc3 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -17,7 +17,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { namespace gpu { @@ -46,6 +48,15 @@ bool IsFusile(const HloInstruction& hlo) { hlo.opcode() == HloOpcode::kTranspose; } +bool IsIEEEFloatingPointScalarConstant(const HloInstruction* constant) { + if (constant->opcode() != HloOpcode::kConstant || + !ShapeUtil::IsScalar(constant->shape())) { + return false; + } + auto type = constant->shape().element_type(); + return type == F16 || type == F32 || type == F64; +} + } // namespace /*static*/ bool GpuInstructionFusion::IsExpensive( @@ -66,34 +77,71 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, HloInstruction* producer = consumer->mutable_operand(operand_index); // Check if we can use output fusion for (A @ B) * alpha - if (producer->opcode() == HloOpcode::kDot) { - if (consumer->opcode() == HloOpcode::kMultiply) { - CHECK_EQ(consumer->operand_count(), 2); - int64 other_operand_index = 1 - operand_index; - const HloInstruction* alpha = consumer->operand(other_operand_index); - if (alpha->opcode() == HloOpcode::kConstant && - ShapeUtil::IsScalar(alpha->shape())) { + if (producer->opcode() == HloOpcode::kDot || + (producer->opcode() == HloOpcode::kFusion && + producer->fused_expression_root()->opcode() == HloOpcode::kDot)) { + int64 other_operand_index = 1 - operand_index; + const HloInstruction* alpha = consumer->operand(other_operand_index); + HloInstruction* op1 = nullptr; + HloInstruction* op2 = nullptr; + if (consumer->opcode() == HloOpcode::kFusion && + consumer->fusion_kind() == HloInstruction::FusionKind::kLoop && + consumer->operand_count() == 2 && + Match(consumer->fused_expression_root(), + match::Op() + .WithOpcode(HloOpcode::kMultiply) + .WithOperand(0, match::Op(&op1)) + .WithOperand(1, match::Op(&op2)))) { + CHECK(op1 != nullptr && op2 != nullptr); + // If 'consumer' is a fusion node, it should consist of a broadcast of a + // scalar constant fused into a multiply, but nothing more. So one operand + // should be a parameter, and the other should be a broadcast. + if (op1->opcode() != HloOpcode::kParameter) { + std::swap(op1, op2); + } + if (op1->opcode() != HloOpcode::kParameter || + op2->opcode() != HloOpcode::kBroadcast) { + return false; + } + if (IsIEEEFloatingPointScalarConstant(alpha)) { + return true; + } + } else if (consumer->opcode() == HloOpcode::kMultiply) { + // Fuse if 'alpha' is a broadcast of a scalar constant. + if (alpha->opcode() == HloOpcode::kBroadcast && + alpha->dimensions().empty() && + IsIEEEFloatingPointScalarConstant(alpha->operand(0))) { return true; } } } - // Only allow to fuse transpose into an output fusion. + // Only allow fusing transpose or broadcast into an output fusion that is + // implemented as a Gemm call. if (consumer->opcode() == HloOpcode::kFusion && - consumer->fusion_kind() == HloInstruction::FusionKind::kOutput) { - if (producer->opcode() != HloOpcode::kTranspose) { - return false; - } - // Check that the transpose is the operand of a dot. + consumer->fusion_kind() == HloInstruction::FusionKind::kOutput && + ImplementedAsGemm(*consumer)) { auto producer_operand_index = consumer->operand_index(producer); auto fused_parameter = consumer->fused_parameter(producer_operand_index); const std::vector& fused_parameter_users = fused_parameter->users(); - return (fused_parameter_users.size() == 1 && - fused_parameter_users[0]->opcode() == HloOpcode::kDot); + if (fused_parameter_users.size() != 1) { + return false; + } + if (producer->opcode() == HloOpcode::kTranspose) { + // Check that the transpose is an operand of a dot. + return fused_parameter_users[0]->opcode() == HloOpcode::kDot; + } + if (producer->opcode() == HloOpcode::kBroadcast) { + // Check that the broadcast is a broadcast of a scalar constant into a + // multiply. + return producer->dimensions().empty() && + IsIEEEFloatingPointScalarConstant(producer->operand(0)) && + fused_parameter_users[0]->opcode() == HloOpcode::kMultiply; + } } - // Output fusion is not currently supported on GPUs. + // Other output fusions are not currently supported on GPUs. if (producer->opcode() == HloOpcode::kFusion) { return false; } @@ -134,7 +182,9 @@ HloInstruction::FusionKind GpuInstructionFusion::ChooseKind( if (IsReductionToVector(*consumer)) { return HloInstruction::FusionKind::kInput; } - if (producer->opcode() == HloOpcode::kDot) { + if (producer->opcode() == HloOpcode::kDot || + (producer->opcode() == HloOpcode::kFusion && + producer->fused_expression_root()->opcode() == HloOpcode::kDot)) { return HloInstruction::FusionKind::kOutput; } if (HloOpcode::kFusion == consumer->opcode()) { diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 6c9a805..760e0e9 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -108,8 +108,8 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfDotUnfused) { HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(S32, {1, 1}), "0")); - auto dot1 = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(S32, {1, 1}), HloOpcode::kDot, param0, param0)); + auto dot1 = builder.AddInstruction(HloInstruction::CreateCanonicalDot( + ShapeUtil::MakeShape(S32, {1, 1}), param0, param0)); auto reshape2 = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {1, 1, 1}), dot1)); @@ -125,8 +125,8 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) { HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(S32, {1, 1}), "0")); - auto dot1 = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(S32, {1, 1}), HloOpcode::kDot, param0, param0)); + auto dot1 = builder.AddInstruction(HloInstruction::CreateCanonicalDot( + ShapeUtil::MakeShape(S32, {1, 1}), param0, param0)); auto transpose2 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {1, 1}), dot1, {0, 1})); @@ -232,12 +232,13 @@ TEST_F(InstructionFusionTest, DotOutputFusion) { auto module = tools::Parse(R"( HloModule test_module ENTRY OutputFusion { - constant = f32[] constant(3) + alpha = f32[] constant(3) + broadcast = f32[4,4]{1,0} broadcast(alpha), dimensions={} p0 = f32[4,3]{1,0} parameter(0) p1 = f32[4,3]{1,0} parameter(1) transpose = f32[3,4]{1,0} transpose(p1), dimensions={1, 0} - dot = f32[4,4]{1,0} dot(p0, transpose) - ROOT mul = f32[4,4] multiply(constant, dot) + dot = f32[4,4]{1,0} dot(p0, transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT mul = f32[4,4] multiply(dot, broadcast) })") .ValueOrDie(); @@ -247,10 +248,11 @@ TEST_F(InstructionFusionTest, DotOutputFusion) { HloInstruction* root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Fusion()); + EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kOutput); EXPECT_THAT( root->fused_expression_root(), - op::Multiply(op::Parameter(), - op::Dot(op::Parameter(), op::Transpose(op::Parameter())))); + op::Multiply(op::Dot(op::Parameter(), op::Transpose(op::Parameter())), + op::Broadcast(op::Parameter()))); } // Compute sum(1/p0), where p0 has type f32, twice. Check that the division is @@ -309,5 +311,31 @@ TEST_F(InstructionFusionTest, IntegerDivIsNotCheap) { .ValueOrDie()); } +TEST_F(InstructionFusionTest, DotOutputFusionImpossible) { + auto module = tools::Parse(R"( + HloModule test_module + ENTRY NoOutputFusion { + alpha = f32[] constant(3) + broadcast = f32[4,4]{1,0} broadcast(alpha), dimensions={} + p0 = f32[4,3]{1,0} parameter(0) + p1 = f32[3,4]{1,0} parameter(1) + dot = f32[4,4]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0} + d = f32[4,4]{1,0} multiply(dot, dot) + ROOT mul = f32[4,4] multiply(d, broadcast) + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kLoop); + EXPECT_THAT(root->fused_expression_root(), + op::Multiply(op::Multiply(op::Parameter(), op::Parameter()), + op::Broadcast(op::Parameter()))); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 7773457..8ab7fe9 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -59,6 +59,25 @@ bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, !ShapeUtil::HasZeroElements(lhs_shape) && !ShapeUtil::HasZeroElements(rhs_shape); } + +bool DotImplementedAsGemm(const HloInstruction& dot) { + CHECK_EQ(dot.opcode(), HloOpcode::kDot); + const Shape& lhs_shape = dot.operand(0)->shape(); + const Shape& rhs_shape = dot.operand(1)->shape(); + + // If gemm can accept the operand shapes, use it rather than a custom + // kernel. + if (AreValidGemmShapes(lhs_shape, rhs_shape, dot.shape())) { + // The size of the reduction dimension should match. The shape inference + // guarantees this invariant, so the check here is for programming + // errors. + const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers(); + CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)), + rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0))); + return true; + } + return false; +} } // namespace bool ImplementedAsGemm(const HloInstruction& hlo) { @@ -69,20 +88,7 @@ bool ImplementedAsGemm(const HloInstruction& hlo) { // For certain types of Dot, we can call pre-canned BLAS gemm. if (hlo.opcode() == HloOpcode::kDot) { - const Shape& lhs_shape = hlo.operand(0)->shape(); - const Shape& rhs_shape = hlo.operand(1)->shape(); - - // If gemm can accept the operand shapes, use it rather than a custom - // kernel. - if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape())) { - // The size of the reduction dimension should match. The shape inference - // guarantees this invariant, so the check here is for programming - // errors. - const DotDimensionNumbers& dim_numbers = hlo.dot_dimension_numbers(); - CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)), - rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0))); - return true; - } + return DotImplementedAsGemm(hlo); } if (hlo.opcode() == HloOpcode::kFusion) { @@ -98,7 +104,7 @@ bool ImplementedAsGemm(const HloInstruction& hlo) { dot = hlo.fused_expression_root()->operand(1); } if (dot->opcode() == HloOpcode::kDot) { - return ImplementedAsGemm(*dot); + return DotImplementedAsGemm(*dot); } } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 83d9029..dcaedcb 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -2194,6 +2194,21 @@ std::unique_ptr IrEmitterUnnested::BuildInfeedThunk( /*destination_buffer=*/GetAllocationSlice(*inst), inst); } +namespace { +double GetScalarConstantAsDouble(const Literal& literal) { + switch (literal.shape().element_type()) { + case F16: + return static_cast(literal.Get({0})); + case F32: + return literal.Get({0}); + case F64: + return literal.Get({0}); + default: + LOG(FATAL) << "Unsupported type."; + } +} +} // namespace + std::unique_ptr IrEmitterUnnested::BuildGemmThunk( const HloInstruction* inst) { if (inst->opcode() == HloOpcode::kDot) { @@ -2218,6 +2233,17 @@ std::unique_ptr IrEmitterUnnested::BuildGemmThunk( if (dot->opcode() != HloOpcode::kDot) { std::swap(dot, alpha); } + if (alpha->opcode() == HloOpcode::kBroadcast) { + alpha = alpha->operand(0); + } + alpha = inst->operand(alpha->parameter_number()); + // TODO(b/74185543): Remove the following if block once we support fusion + // with a non-constant as well. Then we will just always use the constant + // on the device. + if (alpha->opcode() == HloOpcode::kCopy) { + alpha = alpha->operand(0); + } + DCHECK(dot->opcode() == HloOpcode::kDot); const HloInstruction* lhs_parameter = StripTranspose(*dot->operand(0)); const HloInstruction* rhs_parameter = StripTranspose(*dot->operand(1)); @@ -2229,13 +2255,13 @@ std::unique_ptr IrEmitterUnnested::BuildGemmThunk( inst->operand(rhs_parameter->parameter_number()); return MakeUnique( - GetAllocationSlice(*lhs), // The buffer assigned to LHS. - GetAllocationSlice(*rhs), // The buffer assigned to RHS. - GetAllocationSlice(*mul), // The output buffer. - lhs->shape(), // The shape of LHS. - rhs->shape(), // The shape of RHS. - inst->shape(), // The shape of the output. - alpha->literal().Get({0}), // alpha. + GetAllocationSlice(*lhs), // The buffer assigned to LHS. + GetAllocationSlice(*rhs), // The buffer assigned to RHS. + GetAllocationSlice(*inst), // The output buffer. + lhs->shape(), // The shape of LHS. + rhs->shape(), // The shape of RHS. + inst->shape(), // The shape of the output. + GetScalarConstantAsDouble(alpha->literal()), // alpha. inst); } -- 2.7.4