From: Yuanzhong Xu Date: Mon, 12 Feb 2018 19:26:22 +0000 (-0800) Subject: [XLA] An HLO pass that folds BF16 F32 conversions: if an HLO already supports BF16... X-Git-Tag: upstream/v1.7.0~31^2~788 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=fabf6ddede109bbf18115718224449c314bcf92a;p=platform%2Fupstream%2Ftensorflow.git [XLA] An HLO pass that folds BF16 F32 conversions: if an HLO already supports BF16 input/output, conversions before/after it will be removed and the HLO's input/output types will be converted to BF16. Also updates HloVerifier to allow mixed precision if requested. If an HLO has both both F32 and BF16 inputs, ShapeInference will use F32 as the output type. PiperOrigin-RevId: 185407143 --- diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 93cc5ab1a9..9f5f2f96b7 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -43,6 +43,81 @@ filegroup( ]), ) +cc_library( + name = "bfloat16_support", + srcs = ["bfloat16_support.cc"], + hdrs = ["bfloat16_support.h"], + deps = [ + ":hlo", + ], +) + +cc_library( + name = "bfloat16_conversion_folding", + srcs = ["bfloat16_conversion_folding.cc"], + hdrs = ["bfloat16_conversion_folding.h"], + deps = [ + ":bfloat16_support", + ":hlo", + ":hlo_pass", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "bfloat16_conversion_folding_test", + srcs = ["bfloat16_conversion_folding_test.cc"], + deps = [ + ":bfloat16_conversion_folding", + ":bfloat16_support", + ":hlo", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "bfloat16_normalization", + srcs = ["bfloat16_normalization.cc"], + hdrs = ["bfloat16_normalization.h"], + deps = [ + ":bfloat16_support", + ":hlo", + ":hlo_pass", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "bfloat16_normalization_test", + srcs = ["bfloat16_normalization_test.cc"], + deps = [ + ":bfloat16_normalization", + ":bfloat16_support", + ":hlo", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + ], +) + cc_library( name = "shape_inference", srcs = ["shape_inference.cc"], diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc new file mode 100644 index 0000000000..cde990e176 --- /dev/null +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc @@ -0,0 +1,184 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/bfloat16_conversion_folding.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +class BFloat16ConversionFoldingVisitor : public DfsHloVisitorWithDefault { + public: + explicit BFloat16ConversionFoldingVisitor( + HloComputation* computation, const BFloat16Support* bfloat16_support) + : computation_(computation), bfloat16_support_(bfloat16_support) {} + + Status DefaultAction(HloInstruction* hlo) override; + + static bool Run(HloComputation* computation, + const BFloat16Support* bfloat16_support) { + BFloat16ConversionFoldingVisitor visitor(computation, bfloat16_support); + TF_CHECK_OK(computation->Accept(&visitor)); + return visitor.changed_; + } + + private: + // Checks if the HLO has a BF16 -> F32 conversion as input, or a F32 -> BF16 + // conversion as output, and folds them to the HLO itself if feasible. + Status TryFoldBF16Conversions(HloInstruction* hlo); + + // Folds the F32 -> BF16 conversions from the HLO's output. + // + // Precondition: all of the HLO's users are F32 -> BF16 conversions. + Status FoldOutputConversions(HloInstruction* hlo); + + // Folds the BF16 -> F32 conversion operand to the HLO. + // + // Precondition: the operand is a F32 -> BF16 conversion. + Status FoldOperandConversion(HloInstruction* hlo, int64 operand_index); + + HloComputation* computation_; + const BFloat16Support* bfloat16_support_; + bool changed_ = false; +}; + +Status BFloat16ConversionFoldingVisitor::FoldOutputConversions( + HloInstruction* hlo) { + std::vector materialized_users = hlo->users(); + hlo->mutable_shape()->set_element_type(BF16); + for (auto user : materialized_users) { + CHECK_EQ(user->opcode(), HloOpcode::kConvert); + TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(hlo)); + changed_ = true; + } + return Status::OK(); +} + +Status BFloat16ConversionFoldingVisitor::FoldOperandConversion( + HloInstruction* hlo, int64 operand_index) { + // The operand is a convert from BF16 to F32. + auto operand = hlo->mutable_operand(operand_index); + CHECK_EQ(operand->opcode(), HloOpcode::kConvert); + TF_RETURN_IF_ERROR( + hlo->ReplaceOperandWith(operand_index, operand->mutable_operand(0))); + changed_ = true; + return Status::OK(); +} + +Status BFloat16ConversionFoldingVisitor::TryFoldBF16Conversions( + HloInstruction* hlo) { + std::vector bf16_to_f32_operands; + bool has_other_f32_operands = false; + for (int64 i = 0; i < hlo->operands().size(); ++i) { + auto operand = hlo->operand(i); + if (operand->shape().element_type() == F32) { + if (operand->opcode() == HloOpcode::kConvert && + operand->operand(0)->shape().element_type() == BF16 && + bfloat16_support_->SupportsBF16Operand(*hlo, i)) { + // Operand is a convert from BF16 to F32 and we support BF16 input + // directly in the current HLO at the operand index. + bf16_to_f32_operands.push_back(i); + } else { + has_other_f32_operands = true; + } + continue; + } + } + + bool fold_output_conversion = hlo->user_count() > 0 && + hlo->shape().element_type() == F32 && + bfloat16_support_->SupportsBF16Output(*hlo) && + hlo != computation_->root_instruction(); + if (fold_output_conversion) { + for (auto user : hlo->users()) { + if (user->opcode() == HloOpcode::kConvert && + user->shape().element_type() == BF16) { + continue; + } + // We should not change the output type if any user is not a conversion + // from F32 to BF16. + fold_output_conversion = false; + break; + } + } + + if (!bfloat16_support_->SupportsMixedPrecisions(*hlo)) { + if (has_other_f32_operands || + (!fold_output_conversion && hlo->shape().element_type() == F32)) { + // Some of the operands/output will remain F32, but we cannot use mixed + // precisions, so we cannot do anything here. + return Status::OK(); + } + } + + if (fold_output_conversion) { + TF_RETURN_IF_ERROR(FoldOutputConversions(hlo)); + } + + for (int64 i : bf16_to_f32_operands) { + TF_RETURN_IF_ERROR(FoldOperandConversion(hlo, i)); + } + return Status::OK(); +} + +Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) { + // Do not fold BF16 conversions for instructions related to tuples, entry and + // exit of a computation, fusion, convert, and control flow. + if (hlo->opcode() == HloOpcode::kTuple || // + hlo->opcode() == HloOpcode::kGetTupleElement || // + hlo->opcode() == HloOpcode::kInfeed || // + hlo->opcode() == HloOpcode::kOutfeed || // + hlo->opcode() == HloOpcode::kConstant || // + hlo->opcode() == HloOpcode::kParameter || // + hlo->opcode() == HloOpcode::kFusion || // + hlo->opcode() == HloOpcode::kConvert || // + hlo->opcode() == HloOpcode::kCall || // + hlo->opcode() == HloOpcode::kCustomCall || // + hlo->opcode() == HloOpcode::kWhile || // + hlo->opcode() == HloOpcode::kConditional) { + return Status::OK(); + } + if (hlo == computation_->root_instruction() && + !bfloat16_support_->SupportsMixedPrecisions(*hlo)) { + // If hlo is the root instruction, we cannot change its output, so folding + // can only happen when it supports mixed precision so that we can change + // its operands. + return Status::OK(); + } + return TryFoldBF16Conversions(hlo); +} + +StatusOr BFloat16ConversionFolding::Run(HloModule* module) { + XLA_VLOG_LINES( + 2, "BFloat16ConversionFolding::Run(), before:\n" + module->ToString()); + bool changed = false; + for (auto* comp : module->MakeNonfusionComputations()) { + if (BFloat16ConversionFoldingVisitor::Run(comp, bfloat16_support_)) { + changed = true; + } + } + XLA_VLOG_LINES( + 2, "BFloat16ConversionFolding::Run(), after:\n" + module->ToString()); + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h new file mode 100644 index 0000000000..c939838709 --- /dev/null +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h @@ -0,0 +1,52 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_CONVERSION_FOLDING_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_CONVERSION_FOLDING_H_ + +#include "tensorflow/compiler/xla/service/bfloat16_support.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// A pass which folds F32 <-> BF16 conversions to their operands or users, when +// it is supported by the backend. +// +// This pass follows the passed-in backend-specific BF16 support rules, but can +// introduce mixed precision in individual HLOs which breaks the assumption of +// some other HLO passes. So it should be used at the end of the HLO +// optimization pipeline followed by a DCE pass. If other passes are needed +// after this pass, run BFloat16MixedPrecisionRemoval first to undo some of the +// changed made by this pass. +class BFloat16ConversionFolding : public HloPassInterface { + public: + explicit BFloat16ConversionFolding(const BFloat16Support* bfloat16_support) + : bfloat16_support_(bfloat16_support) {} + + ~BFloat16ConversionFolding() override = default; + tensorflow::StringPiece name() const override { return "bfloat16-fold"; } + + // Run BF16 conversion folding on the given computation. Returns whether the + // computation was changed. + StatusOr Run(HloModule* module) override; + + private: + const BFloat16Support* bfloat16_support_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_CONVERSION_FOLDING_H_ diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc new file mode 100644 index 0000000000..cb37759439 --- /dev/null +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc @@ -0,0 +1,209 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/bfloat16_conversion_folding.h" +#include "tensorflow/compiler/xla/service/bfloat16_support.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +class TestBFloat16Support : public BFloat16Support { + public: + TestBFloat16Support() {} + ~TestBFloat16Support() override {} + + bool SupportsBF16Operand(const HloInstruction& hlo, + int64 operand_index) const override { + if (hlo.opcode() == HloOpcode::kAdd || + hlo.opcode() == HloOpcode::kSubtract || + hlo.opcode() == HloOpcode::kTuple || + hlo.opcode() == HloOpcode::kGetTupleElement) { + return true; + } + return false; + } + + bool SupportsBF16Output(const HloInstruction& hlo) const override { + if (hlo.opcode() == HloOpcode::kAdd || + hlo.opcode() == HloOpcode::kSubtract || + hlo.opcode() == HloOpcode::kTuple || + hlo.opcode() == HloOpcode::kGetTupleElement) { + return true; + } + return false; + } + + bool SupportsMixedPrecisions(const HloInstruction& hlo) const override { + if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kTuple || + hlo.opcode() == HloOpcode::kGetTupleElement) { + return true; + } + return false; + } +}; + +class BFloat16ConversionFoldingTest : public HloTestBase { + protected: + bool FoldConversions(HloModule* module) { + TestBFloat16Support bfloat16_support_; + BFloat16ConversionFolding fold(&bfloat16_support_); + StatusOr result = fold.Run(module); + EXPECT_IS_OK(result.status()); + return result.ValueOrDie(); + } +}; + +TEST_F(BFloat16ConversionFoldingTest, FoldIfSupported) { + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, f32_shape, "b")); + HloInstruction* c = builder.AddInstruction( + HloInstruction::CreateParameter(2, f32_shape, "c")); + + HloInstruction* add0 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, a, b)); + HloInstruction* convert0 = + builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, add0)); + HloInstruction* convert1 = builder.AddInstruction( + HloInstruction::CreateConvert(f32_shape, convert0)); + + HloInstruction* add1 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, convert1, c)); + builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, add1)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(FoldConversions(module.get())); + + EXPECT_EQ(computation->root_instruction(), add1); + EXPECT_EQ(add0->shape().element_type(), BF16); + EXPECT_EQ(add1->shape().element_type(), BF16); + EXPECT_EQ(add1->operand(0), add0); +} + +TEST_F(BFloat16ConversionFoldingTest, DoNotFoldIfUnsupported) { + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, f32_shape, "b")); + HloInstruction* c = builder.AddInstruction( + HloInstruction::CreateParameter(2, f32_shape, "c")); + + HloInstruction* mul0 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_shape, HloOpcode::kMultiply, a, b)); + HloInstruction* convert0 = + builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, mul0)); + HloInstruction* convert1 = builder.AddInstruction( + HloInstruction::CreateConvert(f32_shape, convert0)); + + HloInstruction* mul1 = builder.AddInstruction(HloInstruction::CreateBinary( + f32_shape, HloOpcode::kMultiply, convert1, c)); + HloInstruction* convert2 = + builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, mul1)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_FALSE(FoldConversions(module.get())); + + EXPECT_EQ(computation->root_instruction(), convert2); + EXPECT_EQ(mul0->shape().element_type(), F32); + EXPECT_EQ(mul1->shape().element_type(), F32); + EXPECT_EQ(mul1->operand(0), convert1); +} + +TEST_F(BFloat16ConversionFoldingTest, DoNotFoldUnsupportedMixedPrecision) { + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, f32_shape, "b")); + HloInstruction* c = builder.AddInstruction( + HloInstruction::CreateParameter(2, f32_shape, "c")); + + HloInstruction* sub0 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_shape, HloOpcode::kSubtract, a, b)); + HloInstruction* convert0 = + builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, sub0)); + HloInstruction* convert1 = builder.AddInstruction( + HloInstruction::CreateConvert(f32_shape, convert0)); + + HloInstruction* sub1 = builder.AddInstruction(HloInstruction::CreateBinary( + f32_shape, HloOpcode::kSubtract, convert1, c)); + HloInstruction* convert2 = + builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, sub1)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_FALSE(FoldConversions(module.get())); + + EXPECT_EQ(computation->root_instruction(), convert2); + EXPECT_EQ(sub0->shape().element_type(), F32); + EXPECT_EQ(sub1->shape().element_type(), F32); + EXPECT_EQ(sub1->operand(0), convert1); +} + +TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) { + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, bf16_shape, "b")); + HloInstruction* convert0 = + builder.AddInstruction(HloInstruction::CreateConvert(f32_shape, b)); + + HloInstruction* tuple = + builder.AddInstruction(HloInstruction::CreateTuple({a, convert0})); + HloInstruction* gte = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(f32_shape, tuple, 0)); + HloInstruction* convert1 = + builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, gte)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_FALSE(FoldConversions(module.get())); + + EXPECT_EQ(computation->root_instruction(), convert1); + EXPECT_EQ(gte->shape().element_type(), F32); + EXPECT_EQ(tuple->operand(1), convert0); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.cc b/tensorflow/compiler/xla/service/bfloat16_normalization.cc new file mode 100644 index 0000000000..b032c040e8 --- /dev/null +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.cc @@ -0,0 +1,351 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/bfloat16_normalization.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault { + public: + explicit BFloat16NormalizationVisitor(HloComputation* computation, + const BFloat16Support* bfloat16_support) + : computation_(computation), bfloat16_support_(bfloat16_support) {} + + Status DefaultAction(HloInstruction* hlo) override; + + // Special handling for cross-replica-sum which can have a tuple output. + Status HandleCrossReplicaSum(HloInstruction* crs) override; + + static bool Run(HloComputation* computation, + const BFloat16Support* bfloat16_support) { + BFloat16NormalizationVisitor visitor(computation, bfloat16_support); + TF_CHECK_OK(computation->Accept(&visitor)); + return visitor.changed_; + } + + private: + // Checks if the HLO uses BF16 in an unsupported way, and if so, inserts + // conversions between F32 and BF16 to make it supported. + Status HandleInstruction(HloInstruction* hlo); + + // Inserts a conversion HLO that changes the given HLO's output type. + Status InsertConvertAfterOutput(HloInstruction* hlo, PrimitiveType to, + HloComputation* computation); + + // Changes the output type to the specified type, then inserts a conversion + // to the original type. + Status ChangeOutputTypeThenInsertConvertBack(HloInstruction* hlo, + PrimitiveType to, + HloComputation* computation); + + // Inserts a conversion HLO that changes the given HLO's operand type. + Status InsertConvertBeforeOperand(HloInstruction* hlo, int64 operand_idx, + PrimitiveType to, + HloComputation* computation); + + // Inserts conversion HLOs to replace the called computations' BF16 + // operands/outputs to F32. + Status ConvertCalledComputations( + HloInstruction* hlo, + tensorflow::gtl::ArraySlice bf16_called_comps); + + HloComputation* computation_; + const BFloat16Support* bfloat16_support_; + bool changed_ = false; +}; + +Status BFloat16NormalizationVisitor::InsertConvertAfterOutput( + HloInstruction* hlo, PrimitiveType to, HloComputation* computation) { + bool is_root = computation->root_instruction() == hlo; + std::vector materialized_users = hlo->users(); + // Use inst's shape temporarily, in order to pass checks in ReplaceUseWith. + auto convert = computation->AddInstruction( + HloInstruction::CreateConvert(hlo->shape(), hlo)); + for (auto* user : materialized_users) { + TF_RETURN_IF_ERROR(hlo->ReplaceUseWith(user, convert)); + } + if (is_root) { + computation->set_root_instruction(convert); + } + convert->mutable_shape()->set_element_type(to); + changed_ = true; + return Status::OK(); +} + +Status BFloat16NormalizationVisitor::ChangeOutputTypeThenInsertConvertBack( + HloInstruction* hlo, PrimitiveType to, HloComputation* computation) { + auto original_type = hlo->shape().element_type(); + hlo->mutable_shape()->set_element_type(to); + return InsertConvertAfterOutput(hlo, original_type, computation); +} + +Status BFloat16NormalizationVisitor::InsertConvertBeforeOperand( + HloInstruction* hlo, int64 operand_idx, PrimitiveType to, + HloComputation* computation) { + auto operand = hlo->mutable_operand(operand_idx); + auto convert = computation->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(operand->shape(), to), operand)); + TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(operand_idx, convert)); + changed_ = true; + return Status::OK(); +} + +Status BFloat16NormalizationVisitor::ConvertCalledComputations( + HloInstruction* hlo, + tensorflow::gtl::ArraySlice bf16_called_comps) { + std::map cloned_computations; + for (auto& comp : bf16_called_comps) { + auto cloned = comp->parent()->AddEmbeddedComputation(comp->Clone()); + cloned_computations[comp] = cloned; + changed_ = true; + } + hlo->ReplaceCalledComputations([&](HloComputation* comp) { + auto it = cloned_computations.find(comp); + if (it != cloned_computations.end()) { + return it->second; + } + return comp; + }); + for (auto& comp_pair : cloned_computations) { + auto comp = comp_pair.second; + if (comp->root_instruction()->shape().element_type() == BF16) { + TF_RETURN_IF_ERROR( + InsertConvertAfterOutput(comp->root_instruction(), F32, comp)); + } + for (auto* param : comp->parameter_instructions()) { + if (param->shape().element_type() == BF16) { + // This changes the parameter to F32 then inserts a convert after it. + TF_RETURN_IF_ERROR( + ChangeOutputTypeThenInsertConvertBack(param, F32, comp)); + } + } + } + return Status::OK(); +} + +Status BFloat16NormalizationVisitor::HandleCrossReplicaSum( + HloInstruction* crs) { + if (!ShapeUtil::IsTuple(crs->shape())) { + return HandleInstruction(crs); + } + + std::vector operand_types(crs->operand_count()); + std::vector output_types(crs->operand_count()); + bool has_f32 = false; + bool has_bf16 = false; + bool has_bf16_output = false; + for (int64 i = 0; i < crs->operand_count(); ++i) { + operand_types[i] = crs->operand(i)->shape().element_type(); + output_types[i] = ShapeUtil::GetSubshape(crs->shape(), {i}).element_type(); + if (operand_types[i] == F32 || output_types[i] == F32) { + has_f32 = true; + } else if (operand_types[i] == BF16) { + has_bf16 = true; + } + if (output_types[i] == BF16) { + has_bf16 = true; + has_bf16_output = true; + } + } + + for (int64 i = 0; i < crs->operand_count(); ++i) { + if (operand_types[i] != BF16) { + continue; + } + if (bfloat16_support_->SupportsBF16Operand(*crs, i) && + (bfloat16_support_->SupportsMixedPrecisions(*crs) || !has_f32)) { + continue; + } + TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(crs, i, F32, computation_)); + has_f32 = true; + } + + if (!has_bf16_output) { + return Status::OK(); + } + + if (bfloat16_support_->SupportsBF16Output(*crs) && + (bfloat16_support_->SupportsMixedPrecisions(*crs) || !has_f32)) { + return Status::OK(); + } + + std::vector output_elements(crs->operand_count()); + auto original_shape = crs->shape(); + for (int64 i = 0; i < crs->operand_count(); ++i) { + auto subshape = ShapeUtil::GetMutableSubshape(crs->mutable_shape(), {i}); + if (output_types[i] != BF16) { + output_elements[i] = computation_->AddInstruction( + HloInstruction::CreateGetTupleElement(*subshape, crs, i)); + continue; + } + subshape->set_element_type(F32); + auto gte = computation_->AddInstruction( + HloInstruction::CreateGetTupleElement(*subshape, crs, i)); + output_elements[i] = + computation_->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(*subshape, BF16), gte)); + } + auto tuple = computation_->AddInstruction( + HloInstruction::CreateTuple(output_elements)); + + std::vector materialized_users = crs->users(); + // Use the crs' shape temporarily, in order to pass checks in + // ReplaceUseWith. + *tuple->mutable_shape() = crs->shape(); + for (auto* user : materialized_users) { + TF_RETURN_IF_ERROR(crs->ReplaceUseWith(user, tuple)); + } + *tuple->mutable_shape() = original_shape; + return Status::OK(); +} + +Status BFloat16NormalizationVisitor::HandleInstruction(HloInstruction* hlo) { + std::vector bf16_operands; + std::vector f32_operands; + bool has_f32 = false; + bool has_bf16 = false; + + for (int64 i = 0; i < hlo->operand_count(); ++i) { + if (hlo->operand(i)->shape().element_type() == F32) { + f32_operands.push_back(i); + has_f32 = true; + } else if (hlo->operand(i)->shape().element_type() == BF16) { + bf16_operands.push_back(i); + has_bf16 = true; + } + } + + if (hlo->shape().element_type() == F32) { + has_f32 = true; + } else if (hlo->shape().element_type() == BF16) { + has_bf16 = true; + } + + std::vector bf16_called_comps; + for (auto* comp : hlo->called_computations()) { + bool comp_has_bf16 = false; + if (comp->root_instruction()->shape().element_type() == F32) { + has_f32 = true; + } else if (comp->root_instruction()->shape().element_type() == BF16) { + has_bf16 = true; + comp_has_bf16 = true; + } + for (auto* param : comp->parameter_instructions()) { + if (param->shape().element_type() == F32) { + has_f32 = true; + } else if (param->shape().element_type() == BF16) { + has_bf16 = true; + comp_has_bf16 = true; + } + } + if (comp_has_bf16) { + bf16_called_comps.push_back(comp); + } + } + + if (!bfloat16_support_->SupportsMixedPrecisions(*hlo) && has_bf16 && + has_f32) { + // Resolve unsupported mixed precision. + // + // See if we can change everything to BF16. + if (hlo->called_computations().empty() && + hlo->shape().element_type() == BF16) { + bool can_use_bf16 = true; + for (int i : f32_operands) { + if (bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(*hlo, + i) && + bfloat16_support_->SupportsBF16Operand(*hlo, i)) { + continue; + } + can_use_bf16 = false; + break; + } + if (can_use_bf16) { + for (int i : f32_operands) { + TF_RETURN_IF_ERROR( + InsertConvertBeforeOperand(hlo, i, BF16, computation_)); + } + return Status::OK(); + } + } + if (hlo->shape().element_type() == BF16) { + TF_RETURN_IF_ERROR( + ChangeOutputTypeThenInsertConvertBack(hlo, F32, computation_)); + } + for (int i : bf16_operands) { + TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_)); + } + return ConvertCalledComputations(hlo, bf16_called_comps); + } + + for (int i : bf16_operands) { + if (!bfloat16_support_->SupportsBF16Operand(*hlo, i)) { + TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_)); + } + } + + if (hlo->shape().element_type() == BF16 && + !bfloat16_support_->SupportsBF16Output(*hlo)) { + TF_RETURN_IF_ERROR( + ChangeOutputTypeThenInsertConvertBack(hlo, F32, computation_)); + } + + return Status::OK(); +} + +Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) { + // Do not change instructions related to entry and exit of a computation, + // tuples, fusion, convert, and control flow. + if (hlo->opcode() == HloOpcode::kTuple || // + hlo->opcode() == HloOpcode::kGetTupleElement || // + hlo->opcode() == HloOpcode::kInfeed || // + hlo->opcode() == HloOpcode::kOutfeed || // + hlo->opcode() == HloOpcode::kConstant || // + hlo->opcode() == HloOpcode::kParameter || // + hlo->opcode() == HloOpcode::kFusion || // + hlo->opcode() == HloOpcode::kConvert || // + hlo->opcode() == HloOpcode::kCall || // + hlo->opcode() == HloOpcode::kCustomCall || // + hlo->opcode() == HloOpcode::kWhile || // + hlo->opcode() == HloOpcode::kConditional) { + return Status::OK(); + } + return HandleInstruction(hlo); +} + +StatusOr BFloat16Normalization::Run(HloModule* module) { + XLA_VLOG_LINES( + 2, "BFloat16Normalization::Run(), before:\n" + module->ToString()); + bool changed = false; + for (auto* comp : module->MakeComputationPostOrder()) { + if (BFloat16NormalizationVisitor::Run(comp, bfloat16_support_)) { + changed = true; + } + } + XLA_VLOG_LINES(2, + "BFloat16Normalization::Run(), after:\n" + module->ToString()); + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.h b/tensorflow/compiler/xla/service/bfloat16_normalization.h new file mode 100644 index 0000000000..2a60fe0af3 --- /dev/null +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.h @@ -0,0 +1,92 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_NORMALIZATION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_NORMALIZATION_H_ + +#include "tensorflow/compiler/xla/service/bfloat16_support.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// A pass which adds F32 <-> BF16 conversions for HLO instructions that do not +// support BF16 input/output or mixed precision, according to the passed-in +// backend-specific BF16 support rules. +class BFloat16Normalization : public HloPassInterface { + public: + explicit BFloat16Normalization(const BFloat16Support* bfloat16_support) + : bfloat16_support_(bfloat16_support) {} + + ~BFloat16Normalization() override = default; + tensorflow::StringPiece name() const override { return "bf16-normalization"; } + + // Run BF16 normalization on the given computation. Returns whether the + // computation was changed. + StatusOr Run(HloModule* module) override; + + private: + const BFloat16Support* bfloat16_support_; +}; + +// A pass that unconditionally removes the mixed F32/BF16 uses in HLO +// instructions (excluding convert) by adding F32 <-> BF16 conversions. Unlike +// BFloat16Normalization, this pass does not use a backend-specific +// BFloat16Support, and does not change HLOs that have BF16 data if they do not +// use mixed precision; it removes mixed precision even if the backend supports +// it. This pass is used to make the HLO module valid for other HLO passes which +// do not support mixed precision. +class BFloat16MixedPrecisionRemoval : public HloPassInterface { + public: + BFloat16MixedPrecisionRemoval() {} + + ~BFloat16MixedPrecisionRemoval() override = default; + + tensorflow::StringPiece name() const override { + return "bf16-mixed-precision-removal"; + } + + // Run mixed precision removal on the given computation. Returns whether the + // computation was changed. + StatusOr Run(HloModule* module) override { + BFloat16Normalization normalization(&no_mixed_precision_support_); + return normalization.Run(module); + } + + private: + class BFloat16SupportForMixedPrecisionRemoval : public BFloat16Support { + public: + BFloat16SupportForMixedPrecisionRemoval() {} + + ~BFloat16SupportForMixedPrecisionRemoval() override = default; + + bool SupportsBF16Operand(const HloInstruction& hlo, + int64 operand_index) const override { + return true; + } + + bool SupportsBF16Output(const HloInstruction& hlo) const override { + return true; + } + + bool SupportsMixedPrecisions(const HloInstruction& hlo) const override { + return false; + } + } no_mixed_precision_support_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_NORMALIZATION_H_ diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc new file mode 100644 index 0000000000..66c3085842 --- /dev/null +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -0,0 +1,248 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/bfloat16_normalization.h" +#include "tensorflow/compiler/xla/service/bfloat16_support.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +class TestBFloat16Support : public BFloat16Support { + public: + TestBFloat16Support() {} + ~TestBFloat16Support() override {} + + bool SupportsBF16Operand(const HloInstruction& hlo, + int64 operand_index) const override { + if (hlo.opcode() == HloOpcode::kAdd || + hlo.opcode() == HloOpcode::kSubtract || + hlo.opcode() == HloOpcode::kReduce || + hlo.opcode() == HloOpcode::kTuple || + hlo.opcode() == HloOpcode::kGetTupleElement) { + return true; + } + return false; + } + + bool SupportsBF16Output(const HloInstruction& hlo) const override { + if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kReduce || + hlo.opcode() == HloOpcode::kSubtract || + hlo.opcode() == HloOpcode::kTuple || + hlo.opcode() == HloOpcode::kGetTupleElement) { + return true; + } + return false; + } + + bool SupportsMixedPrecisions(const HloInstruction& hlo) const override { + if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kTuple || + hlo.opcode() == HloOpcode::kGetTupleElement) { + return true; + } + return false; + } +}; + +class BFloat16NormalizationTest : public HloTestBase { + protected: + bool Normalize(HloModule* module) { + TestBFloat16Support bfloat16_support_; + BFloat16Normalization normalization(&bfloat16_support_); + StatusOr result = normalization.Run(module); + EXPECT_IS_OK(result.status()); + return result.ValueOrDie(); + } +}; + +TEST_F(BFloat16NormalizationTest, NoopIfSupported) { + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, bf16_shape, "b")); + HloInstruction* c = builder.AddInstruction( + HloInstruction::CreateParameter(2, f32_shape, "c")); + + HloInstruction* add0 = builder.AddInstruction( + HloInstruction::CreateBinary(bf16_shape, HloOpcode::kAdd, a, b)); + + HloInstruction* add1 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, add0, c)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_FALSE(Normalize(module.get())); + + EXPECT_EQ(computation->root_instruction(), add1); + EXPECT_EQ(add0->shape().element_type(), BF16); + EXPECT_EQ(add1->shape().element_type(), F32); +} + +TEST_F(BFloat16NormalizationTest, ResolveIfUnsupportedBF16) { + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, bf16_shape, "b")); + HloInstruction* c = builder.AddInstruction( + HloInstruction::CreateParameter(2, f32_shape, "c")); + + HloInstruction* mul0 = builder.AddInstruction( + HloInstruction::CreateBinary(bf16_shape, HloOpcode::kMultiply, a, b)); + + HloInstruction* mul1 = builder.AddInstruction( + HloInstruction::CreateBinary(bf16_shape, HloOpcode::kMultiply, mul0, c)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(Normalize(module.get())); + + EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); + EXPECT_EQ(computation->root_instruction()->operand(0), mul1); + EXPECT_EQ(mul0->shape().element_type(), F32); + EXPECT_EQ(mul1->shape().element_type(), F32); + EXPECT_EQ(mul1->operand(0)->opcode(), HloOpcode::kConvert); +} + +TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionSubtraction) { + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, bf16_shape, "b")); + HloInstruction* c = builder.AddInstruction( + HloInstruction::CreateParameter(2, f32_shape, "c")); + + HloInstruction* sub0 = builder.AddInstruction( + HloInstruction::CreateBinary(bf16_shape, HloOpcode::kSubtract, a, b)); + + HloInstruction* sub1 = builder.AddInstruction( + HloInstruction::CreateBinary(bf16_shape, HloOpcode::kSubtract, sub0, c)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(Normalize(module.get())); + + EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); + EXPECT_EQ(computation->root_instruction()->operand(0), sub1); + EXPECT_EQ(sub0->shape().element_type(), F32); + EXPECT_EQ(sub1->shape().element_type(), F32); + EXPECT_EQ(sub1->operand(0)->opcode(), HloOpcode::kConvert); +} + +TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) { + Shape f32_input_shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape f32_output_shape = ShapeUtil::MakeShape(F32, {4}); + + Shape bf16_scalar_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + + auto reduce_comp_builder = HloComputation::Builder("reduce_comp"); + auto reduce_comp_param0 = reduce_comp_builder.AddInstruction( + HloInstruction::CreateParameter(0, bf16_scalar_shape, "param0")); + auto reduce_comp_param1 = reduce_comp_builder.AddInstruction( + HloInstruction::CreateParameter(1, bf16_scalar_shape, "param1")); + reduce_comp_builder.AddInstruction( + HloInstruction::CreateBinary(bf16_scalar_shape, HloOpcode::kAdd, + reduce_comp_param0, reduce_comp_param1)); + + auto module = CreateNewModule(); + auto reduce_computation = + module->AddEmbeddedComputation(reduce_comp_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + HloInstruction* input = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_input_shape, "a")); + HloInstruction* init = builder.AddInstruction( + HloInstruction::CreateParameter(1, bf16_scalar_shape, "init")); + HloInstruction* reduce = builder.AddInstruction(HloInstruction::CreateReduce( + f32_output_shape, input, init, {0}, reduce_computation)); + + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(Normalize(module.get())); + + EXPECT_EQ(computation->root_instruction(), reduce); + EXPECT_EQ(reduce->called_computations().size(), 1); + EXPECT_EQ(reduce->called_computations()[0]->num_parameters(), 2); + EXPECT_EQ(reduce->called_computations()[0] + ->parameter_instruction(0) + ->shape() + .element_type(), + F32); + EXPECT_EQ(reduce->called_computations()[0] + ->parameter_instruction(1) + ->shape() + .element_type(), + F32); + EXPECT_EQ(reduce->called_computations()[0] + ->root_instruction() + ->shape() + .element_type(), + F32); + EXPECT_EQ(reduce->shape().element_type(), F32); + EXPECT_EQ(reduce->operand(0), input); + EXPECT_EQ(input->shape().element_type(), F32); + EXPECT_EQ(reduce->operand(1)->opcode(), HloOpcode::kConvert); + EXPECT_EQ(reduce->operand(1)->shape().element_type(), F32); +} + +TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, bf16_shape, "b")); + + HloInstruction* crs = + builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( + ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b})); + HloInstruction* gte = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(Normalize(module.get())); + + EXPECT_EQ(computation->root_instruction(), gte); + EXPECT_EQ(gte->shape().element_type(), BF16); + EXPECT_EQ(crs->operand(1)->shape().element_type(), F32); + EXPECT_EQ(ShapeUtil::GetSubshape(crs->shape(), {1}).element_type(), F32); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_support.cc b/tensorflow/compiler/xla/service/bfloat16_support.cc new file mode 100644 index 0000000000..3fd9e24601 --- /dev/null +++ b/tensorflow/compiler/xla/service/bfloat16_support.cc @@ -0,0 +1,111 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/bfloat16_support.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" + +namespace xla { + +bool BFloat16Support::SupportsBF16Operand(const HloInstruction& hlo, + int64 operand_index) const { + switch (hlo.opcode()) { + case HloOpcode::kCall: + case HloOpcode::kConditional: + case HloOpcode::kCustomCall: + case HloOpcode::kGetTupleElement: + case HloOpcode::kTuple: + case HloOpcode::kWhile: + return true; + case HloOpcode::kConvert: + CHECK_EQ(operand_index, 0); + return hlo.operand(0)->shape().element_type() == BF16; + default: + break; + } + return false; +} + +bool BFloat16Support::SupportsBF16Output(const HloInstruction& hlo) const { + switch (hlo.opcode()) { + case HloOpcode::kCall: + case HloOpcode::kConditional: + case HloOpcode::kCustomCall: + case HloOpcode::kGetTupleElement: + case HloOpcode::kTuple: + case HloOpcode::kWhile: + return true; + case HloOpcode::kConvert: + return hlo.shape().element_type() == BF16; + default: + break; + } + return false; +} + +bool BFloat16Support::SupportsMixedPrecisions(const HloInstruction& hlo) const { + switch (hlo.opcode()) { + case HloOpcode::kCall: + case HloOpcode::kConditional: + case HloOpcode::kConvert: + case HloOpcode::kCustomCall: + case HloOpcode::kGetTupleElement: + case HloOpcode::kTuple: + case HloOpcode::kWhile: + return true; + default: + break; + } + return false; +} + +/* static */ +bool BFloat16Support::EffectiveOperandPrecisionIsOutputPrecision( + const HloInstruction& hlo, int64 operand_index) { + switch (hlo.opcode()) { + case HloOpcode::kAbs: + case HloOpcode::kBroadcast: + case HloOpcode::kClamp: + case HloOpcode::kConcatenate: + case HloOpcode::kCopy: + case HloOpcode::kGetTupleElement: + case HloOpcode::kMaximum: + case HloOpcode::kMinimum: + case HloOpcode::kPad: + case HloOpcode::kReshape: + case HloOpcode::kReverse: + case HloOpcode::kSlice: + case HloOpcode::kSort: + case HloOpcode::kTranspose: + case HloOpcode::kTuple: + return true; + case HloOpcode::kDynamicSlice: + return operand_index == 0; + case HloOpcode::kDynamicUpdateSlice: + return operand_index == 0 || operand_index == 1; + case HloOpcode::kSelect: + return operand_index == 1 || operand_index == 2; + default: + break; + } + return false; +} + +bool BFloat16Support::EffectiveOperandPrecisionIsBF16( + const HloInstruction& hlo, int64 operand_index) const { + return false; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_support.h b/tensorflow/compiler/xla/service/bfloat16_support.h new file mode 100644 index 0000000000..29f662d22b --- /dev/null +++ b/tensorflow/compiler/xla/service/bfloat16_support.h @@ -0,0 +1,60 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_SUPPORT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_SUPPORT_H_ + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" + +namespace xla { + +class BFloat16Support { + public: + BFloat16Support() {} + virtual ~BFloat16Support() {} + + // Returns whether the backend supports BF16 operand for the HLO instruction + // at the given index. + virtual bool SupportsBF16Operand(const HloInstruction& hlo, + int64 operand_index) const; + + // Returns whether the backend supports BF16 output for the HLO instruction. + virtual bool SupportsBF16Output(const HloInstruction& hlo) const; + + // Returns whether the backend support mixed precision: the operands, output, + // and parameters/output of the called computations can have different + // precisions (BF16 and F32). + virtual bool SupportsMixedPrecisions(const HloInstruction& hlo) const; + + // Returns whether the given HLO inherits its BF16 operand precision at the + // given index, so even if the output is F32, elements in the output that + // depend on the BF16 operand will still have BF16 effective precision even if + // they have F32 format. Similarly, this also means if the output is BF16 then + // increasing the operand precision from BF16 to F32 will not change the + // output. This typically includes HLOs that pass elements from the operand to + // the output without arithmetic operations. + static bool EffectiveOperandPrecisionIsOutputPrecision( + const HloInstruction& hlo, int64 operand_index); + + // Returns if the backend only uses BF16 precision for the operand at the + // specified index, even if the operand is F32. + virtual bool EffectiveOperandPrecisionIsBF16(const HloInstruction& hlo, + int64 operand_index) const; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_SUPPORT_H_ diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 0e4437b73b..0981f1f4fe 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -1805,7 +1805,8 @@ void HloInstruction::RemoveUser(HloInstruction* user) { Status HloInstruction::ReplaceUseWith(HloInstruction* user, HloInstruction* new_producer) { - TF_RET_CHECK(ShapeUtil::Compatible(shape(), new_producer->shape())) + TF_RET_CHECK( + ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape())) << "this shape: " << ShapeUtil::HumanString(shape()) << ", replacement shape: " << ShapeUtil::HumanString(new_producer->shape()); @@ -1828,8 +1829,8 @@ Status HloInstruction::ReplaceOperandWith(int64 operand_num, TF_RET_CHECK(operand_num >= 0); TF_RET_CHECK(operand_num < operand_count()); HloInstruction* old_operand = mutable_operand(operand_num); - TF_RET_CHECK( - ShapeUtil::Compatible(old_operand->shape(), new_operand->shape())) + TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(old_operand->shape(), + new_operand->shape())) << old_operand->shape().ShortDebugString() << " is not compatible with " << new_operand->shape().ShortDebugString(); operands_[operand_num] = new_operand; diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 04d4656546..e2b3bb9d71 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/lib/core/errors.h" @@ -164,6 +166,8 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) { // HLO broadcast has no exact analog at the proto level so there is no // ShapeInference method. Check the output shape explicitly. const Shape& operand_shape = broadcast->operand(0)->shape(); + // Check for mixed precision. + TF_RETURN_IF_ERROR(CheckShape(broadcast, broadcast->shape())); TF_RET_CHECK(ShapeUtil::Rank(operand_shape) == broadcast->dimensions().size()); for (int64 operand_dimension = 0; @@ -178,6 +182,8 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) { } Status ShapeVerifier::HandleReshape(HloInstruction* reshape) { + // Check for mixed precision. + TF_RETURN_IF_ERROR(CheckShape(reshape, reshape->shape())); TF_RET_CHECK(ShapeUtil::ElementsIn(reshape->shape()) == ShapeUtil::ElementsIn(reshape->operand(0)->shape())); return tensorflow::Status::OK(); @@ -359,13 +365,122 @@ Status ShapeVerifier::HandleBatchNormGrad(HloInstruction* batch_norm_grad) { batch_norm_grad->feature_index())); } +namespace { + +// Checks that the instruction does not have mixed precision floating point +// inputs. +Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { + switch (instruction->opcode()) { + // White list the following opcodes for mixed-precision check, because they + // involve data pass through or grouping via tuples, where the precisions + // of buffers can be different. + case HloOpcode::kCall: + case HloOpcode::kConditional: + case HloOpcode::kConstant: + case HloOpcode::kCrossReplicaSum: + case HloOpcode::kCustomCall: + case HloOpcode::kFusion: + case HloOpcode::kGetTupleElement: + case HloOpcode::kInfeed: + case HloOpcode::kOutfeed: + case HloOpcode::kParameter: + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + case HloOpcode::kReducePrecision: + case HloOpcode::kSelect: + case HloOpcode::kSend: + case HloOpcode::kSendDone: + case HloOpcode::kTuple: + case HloOpcode::kWhile: + break; + default: { + PrimitiveType fp_type = PRIMITIVE_TYPE_INVALID; + for (auto operand : instruction->operands()) { + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + operand->shape(), + [&](const Shape& subshape, const ShapeIndex& index) { + if (!ShapeUtil::ElementIsFloating(subshape)) { + return Status::OK(); + } + if (fp_type == PRIMITIVE_TYPE_INVALID) { + fp_type = subshape.element_type(); + } else if (fp_type != subshape.element_type()) { + return FailedPrecondition( + "Seen floating point types of different precisions in " + "%s, but mixed precision is disallowed.", + instruction->ToString().c_str()); + } + return Status::OK(); + })); + } + } + } + return Status::OK(); +} + +} // namespace + Status ShapeVerifier::CheckShape(const HloInstruction* instruction, - const Shape& expected_shape) { - if (!ShapeUtil::Compatible(instruction->shape(), expected_shape)) { + const Shape& inferred_shape) { + // If allow_mixed_precision_ is false, check if there are operands with + // different precisions. We need this check because ShapeInference allows + // mixed precision inputs. + if (!allow_mixed_precision_) { + TF_RETURN_IF_ERROR(CheckMixedPrecisionOperands(instruction)); + } + + // Check if the output shape matches the expected shape. + bool compatible; + // We treat BF16 and F32 as compatible types if mixed precision is allowed, + // but only when the instruction defines the BF16/F32 buffer. + switch (instruction->opcode()) { + case HloOpcode::kSelect: + if (ShapeUtil::IsTuple(inferred_shape) || !allow_mixed_precision_) { + // Select only defines the top-level buffer, which in this case is the + // tuple, so we cannot allow mixed precision. + compatible = + ShapeUtil::Compatible(instruction->shape(), inferred_shape); + } else { + compatible = ShapeUtil::CompatibleIgnoringFpPrecision( + instruction->shape(), inferred_shape); + } + break; + case HloOpcode::kGetTupleElement: + case HloOpcode::kTuple: + // Tuple and GetTupleElement do not define BF16/F32 buffers, so mixed + // precision is disallowed. + case HloOpcode::kConstant: + case HloOpcode::kBitcast: + case HloOpcode::kBitcastConvert: + case HloOpcode::kCall: + case HloOpcode::kConditional: + case HloOpcode::kConvert: + case HloOpcode::kCustomCall: + case HloOpcode::kInfeed: + case HloOpcode::kOutfeed: + case HloOpcode::kParameter: + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + case HloOpcode::kSend: + case HloOpcode::kSendDone: + case HloOpcode::kWhile: + // The above opcodes should match the expected shapes exactly. + compatible = ShapeUtil::Compatible(instruction->shape(), inferred_shape); + break; + default: + if (allow_mixed_precision_) { + compatible = ShapeUtil::CompatibleIgnoringFpPrecision( + instruction->shape(), inferred_shape); + } else { + compatible = + ShapeUtil::Compatible(instruction->shape(), inferred_shape); + } + } + if (!compatible) { return InvalidArgument( "Expected instruction to have shape compatible with %s, actual " "shape is %s:\n%s", - ShapeUtil::HumanString(expected_shape).c_str(), + ShapeUtil::HumanString(inferred_shape).c_str(), ShapeUtil::HumanString(instruction->shape()).c_str(), instruction->ToString().c_str()); } @@ -373,14 +488,14 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction, } Status ShapeVerifier::CheckShape(const HloInstruction* instruction, - const StatusOr& expected_shape_status) { - if (!expected_shape_status.ok()) { - Status s = expected_shape_status.status(); + const StatusOr& inferred_shape_status) { + if (!inferred_shape_status.ok()) { + Status s = inferred_shape_status.status(); tensorflow::errors::AppendToMessage(&s, ", for instruction ", instruction->ToString()); return s; } - return CheckShape(instruction, expected_shape_status.ValueOrDie()); + return CheckShape(instruction, inferred_shape_status.ValueOrDie()); } Status ShapeVerifier::CheckUnaryShape(const HloInstruction* instruction) { diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 26d53dec1e..7eccf834bb 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -27,6 +27,10 @@ namespace xla { // TODO(b/26024837): Check output shape for all instruction types. class ShapeVerifier : public DfsHloVisitor { public: + explicit ShapeVerifier() : allow_mixed_precision_(false) {} + explicit ShapeVerifier(bool allow_mixed_precision) + : allow_mixed_precision_(allow_mixed_precision) {} + Status HandleElementwiseUnary(HloInstruction* hlo) override; Status HandleElementwiseBinary(HloInstruction* hlo) override; Status HandleClamp(HloInstruction* clamp) override; @@ -81,14 +85,14 @@ class ShapeVerifier : public DfsHloVisitor { } protected: - // Check the instruction's shape against the given expected shape and return - // an appropriate error if there is a mismatch. + // Check the instruction's shape against the shape given by ShapeInference + // and return an appropriate error if there is a mismatch. Status CheckShape(const HloInstruction* instruction, - const Shape& expected_shape); + const Shape& inferred_shape); // Overload which takes a StatusOr to reduce boilerplate in the caller. Status CheckShape(const HloInstruction* instruction, - const StatusOr& expected_shape_status); + const StatusOr& inferred_shape_status); // Check a unary (binary, etc) instruction's shape against the inferred shape. Status CheckUnaryShape(const HloInstruction* instruction); @@ -99,19 +103,32 @@ class ShapeVerifier : public DfsHloVisitor { // Checks if the given two instructions shares the same channel id. Status CheckSameChannel(const HloInstruction* instr1, const HloInstruction* instr2); + + private: + // Whether the inputs and output of an instruction can contain both F32s and + // BF16s. Tuples that include both F32s and BF16s are allowed regardless of + // this flag. + bool allow_mixed_precision_; }; // HLO pass that verifies invariants of HLO instructions for each computation in // the module. class HloVerifier : public HloPassInterface { public: + using ShapeVerifierFactory = std::function()>; + // Uses standard shape inference. explicit HloVerifier() - : shape_verifier_factory_([] { return MakeUnique(); }) {} + : shape_verifier_factory_( + [] { return MakeUnique(false); }) {} + + explicit HloVerifier(bool allow_mixed_precision) + : shape_verifier_factory_([allow_mixed_precision] { + return MakeUnique(allow_mixed_precision); + }) {} // Uses custom shape verification. - explicit HloVerifier( - std::function()> shape_verifier_factory) + explicit HloVerifier(ShapeVerifierFactory shape_verifier_factory) : shape_verifier_factory_(std::move(shape_verifier_factory)) {} ~HloVerifier() override = default; @@ -129,7 +146,7 @@ class HloVerifier : public HloPassInterface { // expectations. This is a factory function because ShapeVerifier, Note that // ShapeVerifier, being a DfsHloVisitor, is stateful. We want a clean object // for each run of the verifier. - std::function()> shape_verifier_factory_; + ShapeVerifierFactory shape_verifier_factory_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 4ba6da6ccc..004889b5f2 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -209,7 +209,8 @@ tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape, } // Check that init_value's shape is suitable for reducer_shape. - if (!ShapeUtil::Compatible(accumulator_shape, init_value_shape)) { + if (!ShapeUtil::CompatibleIgnoringFpPrecision(accumulator_shape, + init_value_shape)) { return InvalidArgument( "Reduction function's accumulator shape differs from the " "init_value shape: %s vs %s", @@ -220,8 +221,8 @@ tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape, // Check that the inputs can be passed in as the second argument. const Shape& input_element_shape = ShapeUtil::MakeShape(input_element_type, {}); - if (!ShapeUtil::Compatible(input_element_shape, - reducer_shape.parameters(1))) { + if (!ShapeUtil::CompatibleIgnoringFpPrecision(input_element_shape, + reducer_shape.parameters(1))) { return InvalidArgument( "Reduction function's second parameter shape differs from the " "input type element type: %s vs %s", @@ -231,7 +232,8 @@ tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape, // Currently the accumulator and inputs must be the same type, // though that restriction could be relaxed. - if (!ShapeUtil::Compatible(accumulator_shape, reducer_shape.parameters(1))) { + if (!ShapeUtil::CompatibleIgnoringFpPrecision(accumulator_shape, + reducer_shape.parameters(1))) { return InvalidArgument( "Reduction function's second parameter shape currently must " "match the result shape. Got %s vs %s", @@ -394,11 +396,13 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, dimension); } const Shape* arg_shape = nullptr; + PrimitiveType element_type = PRIMITIVE_TYPE_INVALID; for (const Shape* shape : arg_shapes) { TF_RETURN_IF_ERROR( ExpectNotTupleOrOpaque(*shape, "operand of concatenation")); if (!arg_shape) { arg_shape = shape; + element_type = arg_shape->element_type(); continue; } if (ShapeUtil::Rank(*arg_shape) != ShapeUtil::Rank(*shape)) { @@ -409,7 +413,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, ShapeUtil::HumanString(*arg_shape).c_str(), ShapeUtil::Rank(*shape), ShapeUtil::HumanString(*shape).c_str()); } - if (arg_shape->element_type() != shape->element_type()) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shape, *shape)) { return InvalidArgument( "cannot concatenate arrays with different element types: %s vs %s", PrimitiveType_Name(arg_shape->element_type()).c_str(), @@ -431,6 +435,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, ShapeUtil::HumanString(*shape).c_str(), dimension); } } + element_type = ShapeUtil::HigherPrecisionElementType(*shape, *arg_shape); } std::vector new_dimensions(arg_shape->dimensions().begin(), @@ -438,7 +443,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, for (size_t i = 1; i < arg_shapes.size(); ++i) { new_dimensions[dimension] += arg_shapes[i]->dimensions(dimension); } - return ShapeUtil::MakeShape(arg_shape->element_type(), new_dimensions); + return ShapeUtil::MakeShape(element_type, new_dimensions); } /* static */ StatusOr ShapeInference::InferConvertShape( @@ -536,7 +541,8 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, ShapeUtil::HumanString(operand_shape).c_str(), padding_config.ShortDebugString().c_str()); } - if (operand_shape.element_type() != padding_value_shape.element_type()) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape, + padding_value_shape)) { return InvalidArgument( "the element types of the operands to pad do not match"); } @@ -548,7 +554,9 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, std::max(operand_shape.dimensions(i) - 1, 0LL) * padding_config.dimensions(i).interior_padding(); } - return ShapeUtil::MakeShape(operand_shape.element_type(), dimensions); + return ShapeUtil::MakeShape( + ShapeUtil::HigherPrecisionElementType(operand_shape, padding_value_shape), + dimensions); } // Current DotDimensionNumbers Requirements: @@ -673,7 +681,7 @@ Status ValidateDotDimensionNumbers( }; // Check if both element types are the same. - if (lhs.element_type() != rhs.element_type()) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return fail("element types do not match"); } @@ -736,7 +744,8 @@ Status ValidateDotDimensionNumbers( dimensions.push_back(rhs.dimensions(i)); } } - Shape result = ShapeUtil::MakeShape(lhs.element_type(), dimensions); + Shape result = ShapeUtil::MakeShape( + ShapeUtil::HigherPrecisionElementType(lhs, rhs), dimensions); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(result)); VLOG(2) << "inferred dot shape: " << ShapeUtil::HumanString(result); @@ -767,7 +776,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( ShapeUtil::HumanString(rhs).c_str()); } } - return ShapeUtil::MakeShape(lhs.element_type(), output_dimensions); + return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs), + output_dimensions); } /* static */ StatusOr ShapeInference::InferInDimBroadcastShape( @@ -829,6 +839,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( // specified in broadcast_dimensions are then changed to match the // corresponding dimension size in smaller_shape. Shape output_shape(larger_shape); + output_shape.set_element_type( + ShapeUtil::HigherPrecisionElementType(larger_shape, smaller_shape)); for (int i = 0; i < smaller_shape.dimensions_size(); ++i) { int64 dimension_to_match = broadcast_dimensions.at(i); @@ -878,7 +890,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( TF_RETURN_IF_ERROR( ExpectNotTupleOrOpaque(rhs, "rhs of elementwise binary operation")); - if (!ShapeUtil::SameElementType(lhs, rhs)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return InvalidArgument( "binary op %s with different element types: %s and %s", BinaryOperation_Name(operation).c_str(), @@ -897,10 +909,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } } - if (ShapeUtil::Compatible(lhs, rhs)) { + if (ShapeUtil::CompatibleIgnoringFpPrecision(lhs, rhs)) { // If the shapes are the same other than layout, the output shape is the // same (elementwise op). - return lhs; + return ShapeUtil::ChangeElementType( + lhs, ShapeUtil::HigherPrecisionElementType(lhs, rhs)); } if (ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)) { @@ -973,7 +986,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( TF_ASSIGN_OR_RETURN(const Shape& shape, InferElementwiseBinaryOpShape(operation, lhs, rhs, broadcast_dimensions)); - if (lhs.element_type() == F32) { + if (lhs.element_type() == F32 && rhs.element_type() == F32) { return ShapeUtil::ChangeElementType(shape, C64); } else { return Unimplemented("complex component type not supported"); @@ -1078,12 +1091,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( TF_RETURN_IF_ERROR( ExpectNotTupleOrOpaque(*arg_shapes[i], "operand of map")); - if (ShapeUtil::Compatible(*arg_shapes[i], *arg_shape)) { + if (ShapeUtil::CompatibleIgnoringFpPrecision(*arg_shapes[i], *arg_shape)) { continue; } if (!ShapeUtil::IsTuple(*arg_shapes[i]) && !ShapeUtil::IsTuple(*arg_shape) && - ShapeUtil::SameElementType(*arg_shapes[i], *arg_shape)) { + ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shapes[i], + *arg_shape)) { if (ShapeUtil::IsScalar(*arg_shapes[i])) { continue; } @@ -1148,7 +1162,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( i, ShapeUtil::HumanString(parameter_shape).c_str()); } - if (parameter_shape.element_type() != arg_shape->element_type()) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(parameter_shape, + *arg_shape)) { return InvalidArgument( "mapped computation's parameter type has to match argument element " "type; got parameter %d shape: %s, argument shape: %s", @@ -1221,7 +1236,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(operand_shape.element_type()).c_str()); } - if (!ShapeUtil::SameElementType(offset_shape, operand_shape)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape, + operand_shape)) { return InvalidArgument( "The inputs should have the same element type for batch-norm-training, " "but the shape of offset factor is %s " @@ -1230,7 +1246,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(operand_shape.element_type()).c_str()); } - if (!ShapeUtil::SameElementType(scale_shape, operand_shape)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape, + operand_shape)) { return InvalidArgument( "The inputs should have the same element type for batch-norm-training, " "but the shape of scale factor is %s " @@ -1329,7 +1346,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(operand_shape.element_type()).c_str()); } - if (!ShapeUtil::SameElementType(offset_shape, operand_shape)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape, + operand_shape)) { return InvalidArgument( "The inputs should have the same element type for " "batch-norm-inference, " @@ -1339,7 +1357,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(operand_shape.element_type()).c_str()); } - if (!ShapeUtil::SameElementType(scale_shape, operand_shape)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape, + operand_shape)) { return InvalidArgument( "The inputs should have the same element type for " "batch-norm-inference, " @@ -1349,7 +1368,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(operand_shape.element_type()).c_str()); } - if (!ShapeUtil::SameElementType(mean_shape, operand_shape)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape, + operand_shape)) { return InvalidArgument( "The inputs should have the same element type for " "batch-norm-inference, " @@ -1359,7 +1379,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(operand_shape.element_type()).c_str()); } - if (!ShapeUtil::SameElementType(variance_shape, operand_shape)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(variance_shape, + operand_shape)) { return InvalidArgument( "The inputs should have the same element type for " "batch-norm-inference, " @@ -1481,7 +1502,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(output_grad_shape.element_type()).c_str()); } - if (!ShapeUtil::SameElementType(output_grad_shape, operand_shape)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(output_grad_shape, + operand_shape)) { return InvalidArgument( "The inputs should have the same element type for batch-norm-grad, " "but the element type of output_grad is %s " @@ -1490,7 +1512,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(operand_shape.element_type()).c_str()); } - if (!ShapeUtil::SameElementType(scale_shape, operand_shape)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape, + operand_shape)) { return InvalidArgument( "The inputs should have the same element type for batch-norm-grad, " "but the element type of scale factor is %s " @@ -1499,7 +1522,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(operand_shape.element_type()).c_str()); } - if (!ShapeUtil::SameElementType(mean_shape, operand_shape)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape, + operand_shape)) { return InvalidArgument( "The inputs should have the same element type for batch-norm-grad, " "but the element type of mean is %s " @@ -1508,7 +1532,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(operand_shape.element_type()).c_str()); } - if (!ShapeUtil::SameElementType(var_shape, operand_shape)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(var_shape, + operand_shape)) { return InvalidArgument( "The inputs should have the same element type for batch-norm-grad, " "but the element type of mean is %s " @@ -1569,7 +1594,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of convolution")); TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of convolution")); - if (!ShapeUtil::SameElementType(lhs, rhs)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return InvalidArgument( "Convolution with different element types: %s and %s", ShapeUtil::HumanString(lhs).c_str(), @@ -1714,8 +1739,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( dimensions[dnums.output_spatial_dimensions(i)] = window_output_shape.dimensions(i); } - - return ShapeUtil::MakeShape(lhs.element_type(), dimensions); + return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs), + dimensions); } /* static */ StatusOr ShapeInference::InferFftShape( @@ -1877,16 +1902,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } const Shape& operand_element_shape = ShapeUtil::MakeShape(operand_shape.element_type(), {}); - if (!ShapeUtil::Compatible(operand_element_shape, - select_shape.parameters(0))) { + if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape, + select_shape.parameters(0))) { return InvalidArgument( "select function's first parameter shape currently must " "match the operand element shape. Got %s vs %s", ShapeUtil::HumanString(select_shape.parameters(0)).c_str(), ShapeUtil::HumanString(operand_element_shape).c_str()); } - if (!ShapeUtil::Compatible(operand_element_shape, - select_shape.parameters(1))) { + if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape, + select_shape.parameters(1))) { return InvalidArgument( "select function's second parameter shape currently must " "match the operand element shape. Got %s vs %s", @@ -1903,7 +1928,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( InferWindowOutputShape(operand_shape, window, operand_shape.element_type(), /*allow_negative_padding=*/false)); - if (!ShapeUtil::Compatible(source_shape, window_result_shape)) { + if (!ShapeUtil::CompatibleIgnoringFpPrecision(source_shape, + window_result_shape)) { return InvalidArgument( "source shape does not match the shape of window-reduced operand: " "source(%s), window-reduced operand(%s)", @@ -2086,7 +2112,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( ShapeUtil::Rank(update_shape), ShapeUtil::Rank(operand_shape)); } - if (operand_shape.element_type() != update_shape.element_type()) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape, + update_shape)) { return InvalidArgument( "dynamic update slice update element type does not match argument. " "operand.element_type: %s vs update.element_type: %s", @@ -2322,24 +2349,26 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(min, "clamp min")); TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "clamp operand")); TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(max, "clamp max")); - if (!ShapeUtil::SameElementType(min, operand) || - !ShapeUtil::SameElementType(max, operand)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(min, operand) || + !ShapeUtil::SameElementTypeIgnoringFpPrecision(max, operand)) { return InvalidArgument("clamp op with different operand types: %s, %s, %s", ShapeUtil::HumanString(min).c_str(), ShapeUtil::HumanString(operand).c_str(), ShapeUtil::HumanString(max).c_str()); } - if (((ShapeUtil::Compatible(min, operand) || ShapeUtil::IsScalar(min)) && - (ShapeUtil::Compatible(max, operand) || ShapeUtil::IsScalar(max)))) { + if (((ShapeUtil::CompatibleIgnoringFpPrecision(min, operand) || + ShapeUtil::IsScalar(min)) && + (ShapeUtil::CompatibleIgnoringFpPrecision(max, operand) || + ShapeUtil::IsScalar(max)))) { return operand; } if (ShapeUtil::IsScalar(operand)) { - if (ShapeUtil::Compatible(min, max)) { - return min; + if (ShapeUtil::CompatibleIgnoringFpPrecision(min, max)) { + return ShapeUtil::ChangeElementType(min, operand.element_type()); } else if (ShapeUtil::IsScalar(min)) { - return max; + return ShapeUtil::ChangeElementType(max, operand.element_type()); } else if (ShapeUtil::IsScalar(max)) { - return min; + return ShapeUtil::ChangeElementType(min, operand.element_type()); } } return Unimplemented( @@ -2352,7 +2381,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( // broadcast from all operands, not just the predicate. /* static */ StatusOr ShapeInference::InferSelectShape( const Shape& pred, const Shape& on_true, const Shape& on_false) { - if (!ShapeUtil::Compatible(on_true, on_false)) { + bool compatible; + if (ShapeUtil::IsTuple(on_true)) { + // Select only defines the top-level buffer, so if it's a tuple, the two + // input must match exactly. + compatible = ShapeUtil::Compatible(on_true, on_false); + } else { + compatible = ShapeUtil::CompatibleIgnoringFpPrecision(on_true, on_false); + } + if (!compatible) { return InvalidArgument( "operands to select must be the same shape; got %s and %s", ShapeUtil::HumanString(on_true).c_str(), @@ -2367,7 +2404,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( // By this stage we know that pred's element type is PRED. Therefore, this // check restricts pred to be a PRED scalar, or a PRED array with the same // dimensions as on_true and on_false. - return on_true; + return ShapeUtil::ChangeElementType( + on_true, ShapeUtil::HigherPrecisionElementType(on_true, on_false)); } else { return Unimplemented( "select operation with non-scalar predicate with dimensionality " diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index d63e16ce2b..604e0173e7 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -630,6 +630,19 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { return SameDimensions(lhs, rhs); } +/* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs, + const Shape& rhs) { + if (lhs.element_type() == TUPLE) { + return rhs.element_type() == TUPLE && + ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), + CompatibleIgnoringFpPrecision); + } + if (SameElementTypeIgnoringFpPrecision(lhs, rhs)) { + return CompatibleIgnoringElementType(lhs, rhs); + } + return false; +} + /* static */ int64 ShapeUtil::GetDimension(const Shape& shape, int64 dimension_number) { return shape.dimensions(GetDimensionNumber(shape, dimension_number)); diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 453d4ec047..d8a00880e9 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -23,6 +23,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -211,6 +212,31 @@ class ShapeUtil { return lhs.element_type() == rhs.element_type(); } + // As SameElementType, but allows floating point types to have different + // precisions. + static bool SameElementTypeIgnoringFpPrecision(const Shape& a, + const Shape& b) { + if (ElementIsFloating(a) && ElementIsFloating(b)) { + return true; + } + return ShapeUtil::SameElementType(a, b); + } + + // Returns the higher-precision element type if a and b are both floating + // point types; otherwise, checks that that they have the same element type + // and returns it. + static PrimitiveType HigherPrecisionElementType(const Shape& a, + const Shape& b) { + if (SameElementType(a, b)) { + return a.element_type(); + } + CHECK(SameElementTypeIgnoringFpPrecision(a, b)); + return primitive_util::BitWidth(a.element_type()) < + primitive_util::BitWidth(b.element_type()) + ? b.element_type() + : a.element_type(); + } + // Returns true if the rank, dimension sizes, and element type are // identical. Layout is ignored. Tuple elements are compared recursively for // compatibility. @@ -221,6 +247,10 @@ class ShapeUtil { // compatibility. static bool CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs); + // As Compatible, but allow one of lhs and rhs to be BF16 while the other + // being F32. Tuple elements are compared recursively for compatibility. + static bool CompatibleIgnoringFpPrecision(const Shape& lhs, const Shape& rhs); + // Returns whether the lhs and rhs shapes are identical protobufs. static bool Equal(const Shape& lhs, const Shape& rhs); diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 81ba7afb95..4db97d45b2 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -170,6 +170,18 @@ TEST(ShapeUtilTest, CompatibleNotIdenticalShapes) { EXPECT_TRUE(ShapeUtil::Compatible(shape_1, shape_2)); } +TEST(ShapeUtilTest, CompatibleIgnoringFpPrecision) { + Shape shape1 = ShapeUtil::MakeShape(BF16, {3, 2}); + Shape shape2 = ShapeUtil::MakeShape(F32, {3, 2}); + ASSERT_TRUE(ShapeUtil::CompatibleIgnoringFpPrecision(shape1, shape2)); +} + +TEST(ShapeUtilTest, IncompatibleIgnoringFpPrecision) { + Shape shape1 = ShapeUtil::MakeShape(BF16, {3, 2}); + Shape shape2 = ShapeUtil::MakeShape(F32, {2, 2}); + ASSERT_FALSE(ShapeUtil::CompatibleIgnoringFpPrecision(shape1, shape2)); +} + TEST(ShapeUtilTest, IncompatibleDifferentElementShapes) { Shape shape_1 = ShapeUtil::MakeShape(F32, {3, 2}); Shape shape_2 = ShapeUtil::MakeShape(PRED, {3, 2}); @@ -184,6 +196,14 @@ TEST(ShapeUtilTest, CompatibleTuples) { EXPECT_TRUE(ShapeUtil::Compatible(tuple1, tuple2)); } +TEST(ShapeUtilTest, CompatibleTuplesIgnoringFpPrecision) { + Shape tuple1 = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(BF16, {3, 2}), ShapeUtil::MakeShape(F32, {4, 5})}); + Shape tuple2 = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F64, {3, 2}), ShapeUtil::MakeShape(BF16, {4, 5})}); + EXPECT_TRUE(ShapeUtil::CompatibleIgnoringFpPrecision(tuple1, tuple2)); +} + TEST(ShapeUtilTest, IncompatibleTuplesWithSwappedElements) { Shape tuple1 = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(F32, {3, 2})}); @@ -193,6 +213,14 @@ TEST(ShapeUtilTest, IncompatibleTuplesWithSwappedElements) { EXPECT_FALSE(ShapeUtil::CompatibleIgnoringElementType(tuple1, tuple2)); } +TEST(ShapeUtilTest, IncompatibleTuplesIgnoringFpPrecision) { + Shape tuple1 = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(BF16, {4, 5}), ShapeUtil::MakeShape(F32, {3, 2})}); + Shape tuple2 = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(BF16, {4, 5})}); + EXPECT_FALSE(ShapeUtil::CompatibleIgnoringFpPrecision(tuple1, tuple2)); +} + TEST(ShapeUtilTest, IncompatibleTuplesWithDifferentPrimitiveType) { Shape tuple1 = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(F32, {3, 2})});