From: A. Unique TensorFlower Date: Tue, 15 May 2018 08:22:13 +0000 (-0700) Subject: Reland improve fusion logic of (a dot b) * alpha X-Git-Tag: upstream/v1.9.0_rc1~116^2^2~7 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=af06f858edec0499d13561e9c5a9867a28833c5d;p=platform%2Fupstream%2Ftensorflow.git Reland improve fusion logic of (a dot b) * alpha The previous fusion approach didn't work because a multiplication by a scalar value will be changed into an explicit broadcast. Another issue that is fixed in this CL is retrieving the constant value from the literal. This depends on the PrimitiveType, before we always assumed it to be double. Also when checking ImplementedAsGemm() we should not call it recursively, but instead just the check related to kDot. Finally add an execution test and adjust the fusion logic test. The fix for the issue that caused the revert is that we check earlier that consumer->operand_count() is 2. Also, we fix the call to Get() to pass {} instead of {0}. And we handle an output fusion node in GemmThunk to extract the dimension numbers from the dot operation. PiperOrigin-RevId: 196631031 --- diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 7cb7f55..7ee039b 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -388,8 +388,10 @@ cc_library( deps = [ ":ir_emission_utils", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:instruction_fusion", + "//tensorflow/compiler/xla/service:pattern_matcher", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index 2ebb40a..79fca43 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -215,6 +215,25 @@ se::blas::ComputationType GetBlasComputationType(PrimitiveType type) { } } +DotDimensionNumbers GetDimensionNumbers(const HloInstruction& hlo_instruction) { + if (hlo_instruction.opcode() == HloOpcode::kDot) { + return hlo_instruction.dot_dimension_numbers(); + } + CHECK_EQ(hlo_instruction.opcode(), HloOpcode::kFusion); + CHECK_EQ(hlo_instruction.fusion_kind(), HloInstruction::FusionKind::kOutput); + CHECK_EQ(hlo_instruction.fused_expression_root()->opcode(), + HloOpcode::kMultiply); + // Try to find the dot inside the output fusion node. + const HloInstruction* dot = + hlo_instruction.fused_expression_root()->operand(0); + if (dot->opcode() != HloOpcode::kDot) { + dot = hlo_instruction.fused_expression_root()->operand(1); + } + CHECK_EQ(dot->opcode(), HloOpcode::kDot); + + return dot->dot_dimension_numbers(); +} + } // namespace GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer, @@ -281,8 +300,7 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, shape.dimensions(!is_row_major)); }; - const DotDimensionNumbers& dim_nums = - hlo_instruction()->dot_dimension_numbers(); + DotDimensionNumbers dim_nums = GetDimensionNumbers(*hlo_instruction()); const MatrixDescriptor lhs_descriptor = make_descriptor( lhs_data, lhs_shape_, dim_nums.lhs_contracting_dimensions(0) == 0); diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index c5eb721..5d5bef6 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -17,7 +17,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { namespace gpu { @@ -46,6 +48,15 @@ bool IsFusile(const HloInstruction& hlo) { hlo.opcode() == HloOpcode::kTranspose; } +bool IsIEEEFloatingPointScalarConstant(const HloInstruction* constant) { + if (constant->opcode() != HloOpcode::kConstant || + !ShapeUtil::IsScalar(constant->shape())) { + return false; + } + auto type = constant->shape().element_type(); + return type == F16 || type == F32 || type == F64; +} + } // namespace /*static*/ bool GpuInstructionFusion::IsExpensive( @@ -66,34 +77,71 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, HloInstruction* producer = consumer->mutable_operand(operand_index); // Check if we can use output fusion for (A @ B) * alpha - if (producer->opcode() == HloOpcode::kDot) { - if (consumer->opcode() == HloOpcode::kMultiply) { - CHECK_EQ(consumer->operand_count(), 2); - int64 other_operand_index = 1 - operand_index; - const HloInstruction* alpha = consumer->operand(other_operand_index); - if (alpha->opcode() == HloOpcode::kConstant && - ShapeUtil::IsScalar(alpha->shape())) { + if (consumer->operand_count() == 2 && + (producer->opcode() == HloOpcode::kDot || + (producer->opcode() == HloOpcode::kFusion && + producer->fused_expression_root()->opcode() == HloOpcode::kDot))) { + int64 other_operand_index = 1 - operand_index; + const HloInstruction* alpha = consumer->operand(other_operand_index); + HloInstruction* op1 = nullptr; + HloInstruction* op2 = nullptr; + if (consumer->opcode() == HloOpcode::kFusion && + consumer->fusion_kind() == HloInstruction::FusionKind::kLoop && + Match(consumer->fused_expression_root(), + match::Op() + .WithOpcode(HloOpcode::kMultiply) + .WithOperand(0, match::Op(&op1)) + .WithOperand(1, match::Op(&op2)))) { + CHECK(op1 != nullptr && op2 != nullptr); + // If 'consumer' is a fusion node, it should consist of a broadcast of a + // scalar constant fused into a multiply, but nothing more. So one operand + // should be a parameter, and the other should be a broadcast. + if (op1->opcode() != HloOpcode::kParameter) { + std::swap(op1, op2); + } + if (op1->opcode() != HloOpcode::kParameter || + op2->opcode() != HloOpcode::kBroadcast) { + return false; + } + if (IsIEEEFloatingPointScalarConstant(alpha)) { + return true; + } + } else if (consumer->opcode() == HloOpcode::kMultiply) { + // Fuse if 'alpha' is a broadcast of a scalar constant. + if (alpha->opcode() == HloOpcode::kBroadcast && + alpha->dimensions().empty() && + IsIEEEFloatingPointScalarConstant(alpha->operand(0))) { return true; } } } - // Only allow to fuse transpose into an output fusion. + // Only allow fusing transpose or broadcast into an output fusion that is + // implemented as a Gemm call. if (consumer->opcode() == HloOpcode::kFusion && - consumer->fusion_kind() == HloInstruction::FusionKind::kOutput) { - if (producer->opcode() != HloOpcode::kTranspose) { - return false; - } - // Check that the transpose is the operand of a dot. + consumer->fusion_kind() == HloInstruction::FusionKind::kOutput && + ImplementedAsGemm(*consumer)) { auto producer_operand_index = consumer->operand_index(producer); auto fused_parameter = consumer->fused_parameter(producer_operand_index); const std::vector& fused_parameter_users = fused_parameter->users(); - return (fused_parameter_users.size() == 1 && - fused_parameter_users[0]->opcode() == HloOpcode::kDot); + if (fused_parameter_users.size() != 1) { + return false; + } + if (producer->opcode() == HloOpcode::kTranspose) { + // Check that the transpose is an operand of a dot. + return fused_parameter_users[0]->opcode() == HloOpcode::kDot; + } + if (producer->opcode() == HloOpcode::kBroadcast) { + // Check that the broadcast is a broadcast of a scalar constant into a + // multiply. + return producer->dimensions().empty() && + IsIEEEFloatingPointScalarConstant(producer->operand(0)) && + fused_parameter_users[0]->opcode() == HloOpcode::kMultiply; + } } - // Output fusion is not currently supported on GPUs. + // Other output fusions are not currently supported on GPUs. if (producer->opcode() == HloOpcode::kFusion) { return false; } @@ -134,7 +182,9 @@ HloInstruction::FusionKind GpuInstructionFusion::ChooseKind( if (IsReductionToVector(*consumer)) { return HloInstruction::FusionKind::kInput; } - if (producer->opcode() == HloOpcode::kDot) { + if (producer->opcode() == HloOpcode::kDot || + (producer->opcode() == HloOpcode::kFusion && + producer->fused_expression_root()->opcode() == HloOpcode::kDot)) { return HloInstruction::FusionKind::kOutput; } if (HloOpcode::kFusion == consumer->opcode()) { diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 6c9a805..760e0e9 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -108,8 +108,8 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfDotUnfused) { HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(S32, {1, 1}), "0")); - auto dot1 = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(S32, {1, 1}), HloOpcode::kDot, param0, param0)); + auto dot1 = builder.AddInstruction(HloInstruction::CreateCanonicalDot( + ShapeUtil::MakeShape(S32, {1, 1}), param0, param0)); auto reshape2 = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {1, 1, 1}), dot1)); @@ -125,8 +125,8 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) { HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(S32, {1, 1}), "0")); - auto dot1 = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(S32, {1, 1}), HloOpcode::kDot, param0, param0)); + auto dot1 = builder.AddInstruction(HloInstruction::CreateCanonicalDot( + ShapeUtil::MakeShape(S32, {1, 1}), param0, param0)); auto transpose2 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {1, 1}), dot1, {0, 1})); @@ -232,12 +232,13 @@ TEST_F(InstructionFusionTest, DotOutputFusion) { auto module = tools::Parse(R"( HloModule test_module ENTRY OutputFusion { - constant = f32[] constant(3) + alpha = f32[] constant(3) + broadcast = f32[4,4]{1,0} broadcast(alpha), dimensions={} p0 = f32[4,3]{1,0} parameter(0) p1 = f32[4,3]{1,0} parameter(1) transpose = f32[3,4]{1,0} transpose(p1), dimensions={1, 0} - dot = f32[4,4]{1,0} dot(p0, transpose) - ROOT mul = f32[4,4] multiply(constant, dot) + dot = f32[4,4]{1,0} dot(p0, transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT mul = f32[4,4] multiply(dot, broadcast) })") .ValueOrDie(); @@ -247,10 +248,11 @@ TEST_F(InstructionFusionTest, DotOutputFusion) { HloInstruction* root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Fusion()); + EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kOutput); EXPECT_THAT( root->fused_expression_root(), - op::Multiply(op::Parameter(), - op::Dot(op::Parameter(), op::Transpose(op::Parameter())))); + op::Multiply(op::Dot(op::Parameter(), op::Transpose(op::Parameter())), + op::Broadcast(op::Parameter()))); } // Compute sum(1/p0), where p0 has type f32, twice. Check that the division is @@ -309,5 +311,31 @@ TEST_F(InstructionFusionTest, IntegerDivIsNotCheap) { .ValueOrDie()); } +TEST_F(InstructionFusionTest, DotOutputFusionImpossible) { + auto module = tools::Parse(R"( + HloModule test_module + ENTRY NoOutputFusion { + alpha = f32[] constant(3) + broadcast = f32[4,4]{1,0} broadcast(alpha), dimensions={} + p0 = f32[4,3]{1,0} parameter(0) + p1 = f32[3,4]{1,0} parameter(1) + dot = f32[4,4]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0} + d = f32[4,4]{1,0} multiply(dot, dot) + ROOT mul = f32[4,4] multiply(d, broadcast) + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kLoop); + EXPECT_THAT(root->fused_expression_root(), + op::Multiply(op::Multiply(op::Parameter(), op::Parameter()), + op::Broadcast(op::Parameter()))); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 9619903..22e7150 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -59,6 +59,25 @@ bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, !ShapeUtil::HasZeroElements(lhs_shape) && !ShapeUtil::HasZeroElements(rhs_shape); } + +bool DotImplementedAsGemm(const HloInstruction& dot) { + CHECK_EQ(dot.opcode(), HloOpcode::kDot); + const Shape& lhs_shape = dot.operand(0)->shape(); + const Shape& rhs_shape = dot.operand(1)->shape(); + + // If gemm can accept the operand shapes, use it rather than a custom + // kernel. + if (AreValidGemmShapes(lhs_shape, rhs_shape, dot.shape())) { + // The size of the reduction dimension should match. The shape inference + // guarantees this invariant, so the check here is for programming + // errors. + const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers(); + CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)), + rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0))); + return true; + } + return false; +} } // namespace bool ImplementedAsGemm(const HloInstruction& hlo) { @@ -69,20 +88,7 @@ bool ImplementedAsGemm(const HloInstruction& hlo) { // For certain types of Dot, we can call pre-canned BLAS gemm. if (hlo.opcode() == HloOpcode::kDot) { - const Shape& lhs_shape = hlo.operand(0)->shape(); - const Shape& rhs_shape = hlo.operand(1)->shape(); - - // If gemm can accept the operand shapes, use it rather than a custom - // kernel. - if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape())) { - // The size of the reduction dimension should match. The shape inference - // guarantees this invariant, so the check here is for programming - // errors. - const DotDimensionNumbers& dim_numbers = hlo.dot_dimension_numbers(); - CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)), - rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0))); - return true; - } + return DotImplementedAsGemm(hlo); } if (hlo.opcode() == HloOpcode::kFusion && @@ -94,7 +100,7 @@ bool ImplementedAsGemm(const HloInstruction& hlo) { dot = hlo.fused_expression_root()->operand(1); } if (dot->opcode() == HloOpcode::kDot) { - return ImplementedAsGemm(*dot); + return DotImplementedAsGemm(*dot); } } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 83d9029..0d7ba4c 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -2194,6 +2194,21 @@ std::unique_ptr IrEmitterUnnested::BuildInfeedThunk( /*destination_buffer=*/GetAllocationSlice(*inst), inst); } +namespace { +double GetScalarConstantAsDouble(const Literal& literal) { + switch (literal.shape().element_type()) { + case F16: + return static_cast(literal.Get({})); + case F32: + return literal.Get({}); + case F64: + return literal.Get({}); + default: + LOG(FATAL) << "Unsupported type."; + } +} +} // namespace + std::unique_ptr IrEmitterUnnested::BuildGemmThunk( const HloInstruction* inst) { if (inst->opcode() == HloOpcode::kDot) { @@ -2218,6 +2233,17 @@ std::unique_ptr IrEmitterUnnested::BuildGemmThunk( if (dot->opcode() != HloOpcode::kDot) { std::swap(dot, alpha); } + if (alpha->opcode() == HloOpcode::kBroadcast) { + alpha = alpha->operand(0); + } + alpha = inst->operand(alpha->parameter_number()); + // TODO(b/74185543): Remove the following if block once we support fusion + // with a non-constant as well. Then we will just always use the constant + // on the device. + if (alpha->opcode() == HloOpcode::kCopy) { + alpha = alpha->operand(0); + } + DCHECK(dot->opcode() == HloOpcode::kDot); const HloInstruction* lhs_parameter = StripTranspose(*dot->operand(0)); const HloInstruction* rhs_parameter = StripTranspose(*dot->operand(1)); @@ -2229,13 +2255,13 @@ std::unique_ptr IrEmitterUnnested::BuildGemmThunk( inst->operand(rhs_parameter->parameter_number()); return MakeUnique( - GetAllocationSlice(*lhs), // The buffer assigned to LHS. - GetAllocationSlice(*rhs), // The buffer assigned to RHS. - GetAllocationSlice(*mul), // The output buffer. - lhs->shape(), // The shape of LHS. - rhs->shape(), // The shape of RHS. - inst->shape(), // The shape of the output. - alpha->literal().Get({0}), // alpha. + GetAllocationSlice(*lhs), // The buffer assigned to LHS. + GetAllocationSlice(*rhs), // The buffer assigned to RHS. + GetAllocationSlice(*inst), // The output buffer. + lhs->shape(), // The shape of LHS. + rhs->shape(), // The shape of RHS. + inst->shape(), // The shape of the output. + GetScalarConstantAsDouble(alpha->literal()), // alpha. inst); }