]),
)
+cc_library(
+ name = "bfloat16_support",
+ srcs = ["bfloat16_support.cc"],
+ hdrs = ["bfloat16_support.h"],
+ deps = [
+ ":hlo",
+ ],
+)
+
+cc_library(
+ name = "bfloat16_conversion_folding",
+ srcs = ["bfloat16_conversion_folding.cc"],
+ hdrs = ["bfloat16_conversion_folding.h"],
+ deps = [
+ ":bfloat16_support",
+ ":hlo",
+ ":hlo_pass",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:lib",
+ ],
+)
+
+tf_cc_test(
+ name = "bfloat16_conversion_folding_test",
+ srcs = ["bfloat16_conversion_folding_test.cc"],
+ deps = [
+ ":bfloat16_conversion_folding",
+ ":bfloat16_support",
+ ":hlo",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "bfloat16_normalization",
+ srcs = ["bfloat16_normalization.cc"],
+ hdrs = ["bfloat16_normalization.h"],
+ deps = [
+ ":bfloat16_support",
+ ":hlo",
+ ":hlo_pass",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:lib",
+ ],
+)
+
+tf_cc_test(
+ name = "bfloat16_normalization_test",
+ srcs = ["bfloat16_normalization_test.cc"],
+ deps = [
+ ":bfloat16_normalization",
+ ":bfloat16_support",
+ ":hlo",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:lib",
+ ],
+)
+
cc_library(
name = "shape_inference",
srcs = ["shape_inference.cc"],
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/bfloat16_conversion_folding.h"
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+class BFloat16ConversionFoldingVisitor : public DfsHloVisitorWithDefault {
+ public:
+ explicit BFloat16ConversionFoldingVisitor(
+ HloComputation* computation, const BFloat16Support* bfloat16_support)
+ : computation_(computation), bfloat16_support_(bfloat16_support) {}
+
+ Status DefaultAction(HloInstruction* hlo) override;
+
+ static bool Run(HloComputation* computation,
+ const BFloat16Support* bfloat16_support) {
+ BFloat16ConversionFoldingVisitor visitor(computation, bfloat16_support);
+ TF_CHECK_OK(computation->Accept(&visitor));
+ return visitor.changed_;
+ }
+
+ private:
+ // Checks if the HLO has a BF16 -> F32 conversion as input, or a F32 -> BF16
+ // conversion as output, and folds them to the HLO itself if feasible.
+ Status TryFoldBF16Conversions(HloInstruction* hlo);
+
+ // Folds the F32 -> BF16 conversions from the HLO's output.
+ //
+ // Precondition: all of the HLO's users are F32 -> BF16 conversions.
+ Status FoldOutputConversions(HloInstruction* hlo);
+
+ // Folds the BF16 -> F32 conversion operand to the HLO.
+ //
+ // Precondition: the operand is a BF16 -> F32 conversion.
+ Status FoldOperandConversion(HloInstruction* hlo, int64 operand_index);
+
+ HloComputation* computation_;
+ const BFloat16Support* bfloat16_support_;
+ bool changed_ = false;
+};
+
+Status BFloat16ConversionFoldingVisitor::FoldOutputConversions(
+ HloInstruction* hlo) {
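+ // Change hlo's output type to BF16 in place, then make the users of each
+ // F32 -> BF16 convert use hlo directly; the converts become dead and are
+ // expected to be removed by a later DCE pass.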
+ std::vector<HloInstruction*> materialized_users = hlo->users();
+ hlo->mutable_shape()->set_element_type(BF16);
+ for (auto user : materialized_users) {
+ CHECK_EQ(user->opcode(), HloOpcode::kConvert);
+ TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(hlo));
+ changed_ = true;
+ }
+ return Status::OK();
+}
+
+Status BFloat16ConversionFoldingVisitor::FoldOperandConversion(
+ HloInstruction* hlo, int64 operand_index) {
+ // The operand is a convert from BF16 to F32.
+ auto operand = hlo->mutable_operand(operand_index);
+ CHECK_EQ(operand->opcode(), HloOpcode::kConvert);
+ TF_RETURN_IF_ERROR(
+ hlo->ReplaceOperandWith(operand_index, operand->mutable_operand(0)));
+ changed_ = true;
+ return Status::OK();
+}
+
+Status BFloat16ConversionFoldingVisitor::TryFoldBF16Conversions(
+ HloInstruction* hlo) {
+ std::vector<int64> bf16_to_f32_operands;
+ bool has_other_f32_operands = false;
+ for (int64 i = 0; i < hlo->operand_count(); ++i) {
+ auto operand = hlo->operand(i);
+ if (operand->shape().element_type() == F32) {
+ if (operand->opcode() == HloOpcode::kConvert &&
+ operand->operand(0)->shape().element_type() == BF16 &&
+ bfloat16_support_->SupportsBF16Operand(*hlo, i)) {
+ // Operand is a convert from BF16 to F32 and we support BF16 input
+ // directly in the current HLO at the operand index.
+ bf16_to_f32_operands.push_back(i);
+ } else {
+ has_other_f32_operands = true;
+ }
+ continue;
+ }
+ }
+
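+ // Fold the output conversions only if this HLO produces F32, the backend
+ // can produce BF16 output for it directly, it is not the computation root
+ // (changing the root would change the computation's result type), and every
+ // user is an F32 -> BF16 convert.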
+ bool fold_output_conversion = hlo->user_count() > 0 &&
+ hlo->shape().element_type() == F32 &&
+ bfloat16_support_->SupportsBF16Output(*hlo) &&
+ hlo != computation_->root_instruction();
+ if (fold_output_conversion) {
+ for (auto user : hlo->users()) {
+ if (user->opcode() == HloOpcode::kConvert &&
+ user->shape().element_type() == BF16) {
+ continue;
+ }
+ // We should not change the output type if any user is not a conversion
+ // from F32 to BF16.
+ fold_output_conversion = false;
+ break;
+ }
+ }
+
+ if (!bfloat16_support_->SupportsMixedPrecisions(*hlo)) {
+ if (has_other_f32_operands ||
+ (!fold_output_conversion && hlo->shape().element_type() == F32)) {
+ // Some of the operands/output will remain F32, but we cannot use mixed
+ // precisions, so we cannot do anything here.
+ return Status::OK();
+ }
+ }
+
+ if (fold_output_conversion) {
+ TF_RETURN_IF_ERROR(FoldOutputConversions(hlo));
+ }
+
+ for (int64 i : bf16_to_f32_operands) {
+ TF_RETURN_IF_ERROR(FoldOperandConversion(hlo, i));
+ }
+ return Status::OK();
+}
+
+Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) {
+ // Do not fold BF16 conversions for instructions related to tuples, entry and
+ // exit of a computation, fusion, convert, and control flow.
+ if (hlo->opcode() == HloOpcode::kTuple || //
+ hlo->opcode() == HloOpcode::kGetTupleElement || //
+ hlo->opcode() == HloOpcode::kInfeed || //
+ hlo->opcode() == HloOpcode::kOutfeed || //
+ hlo->opcode() == HloOpcode::kConstant || //
+ hlo->opcode() == HloOpcode::kParameter || //
+ hlo->opcode() == HloOpcode::kFusion || //
+ hlo->opcode() == HloOpcode::kConvert || //
+ hlo->opcode() == HloOpcode::kCall || //
+ hlo->opcode() == HloOpcode::kCustomCall || //
+ hlo->opcode() == HloOpcode::kWhile || //
+ hlo->opcode() == HloOpcode::kConditional) {
+ return Status::OK();
+ }
+ if (hlo == computation_->root_instruction() &&
+ !bfloat16_support_->SupportsMixedPrecisions(*hlo)) {
+ // If hlo is the root instruction, we cannot change its output, so folding
+ // can only happen when it supports mixed precision so that we can change
+ // its operands.
+ return Status::OK();
+ }
+ return TryFoldBF16Conversions(hlo);
+}
+
+StatusOr<bool> BFloat16ConversionFolding::Run(HloModule* module) {
+ XLA_VLOG_LINES(
+ 2, "BFloat16ConversionFolding::Run(), before:\n" + module->ToString());
+ bool changed = false;
+ for (auto* comp : module->MakeNonfusionComputations()) {
+ if (BFloat16ConversionFoldingVisitor::Run(comp, bfloat16_support_)) {
+ changed = true;
+ }
+ }
+ XLA_VLOG_LINES(
+ 2, "BFloat16ConversionFolding::Run(), after:\n" + module->ToString());
+ return changed;
+}
+
+} // namespace xla
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_CONVERSION_FOLDING_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_CONVERSION_FOLDING_H_
+
+#include "tensorflow/compiler/xla/service/bfloat16_support.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+
+namespace xla {
+
+// A pass which folds F32 <-> BF16 conversions to their operands or users, when
+// it is supported by the backend.
+//
+// This pass follows the passed-in backend-specific BF16 support rules, but can
+// introduce mixed precision in individual HLOs which breaks the assumption of
+// some other HLO passes. So it should be used at the end of the HLO
+// optimization pipeline followed by a DCE pass. If other passes are needed
+// after this pass, run BFloat16MixedPrecisionRemoval first to undo some of the
+// changes made by this pass.
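+//
+// For example, assuming a backend that supports a BF16 output for add, the
+// pattern
+//
+//   add = f32[4] add(f32[4] x, f32[4] y)
+//   cvt = bf16[4] convert(add)
+//
+// can be folded so that the add itself produces BF16,
+//
+//   add = bf16[4] add(f32[4] x, f32[4] y)
+//
+// leaving the convert dead for the subsequent DCE pass.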
+class BFloat16ConversionFolding : public HloPassInterface {
+ public:
+ explicit BFloat16ConversionFolding(const BFloat16Support* bfloat16_support)
+ : bfloat16_support_(bfloat16_support) {}
+
+ ~BFloat16ConversionFolding() override = default;
+ tensorflow::StringPiece name() const override { return "bfloat16-fold"; }
+
+ // Run BF16 conversion folding on the given computation. Returns whether the
+ // computation was changed.
+ StatusOr<bool> Run(HloModule* module) override;
+
+ private:
+ const BFloat16Support* bfloat16_support_;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_CONVERSION_FOLDING_H_
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/bfloat16_conversion_folding.h"
+#include "tensorflow/compiler/xla/service/bfloat16_support.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+
+class TestBFloat16Support : public BFloat16Support {
+ public:
+ TestBFloat16Support() {}
+ ~TestBFloat16Support() override {}
+
+ bool SupportsBF16Operand(const HloInstruction& hlo,
+ int64 operand_index) const override {
+ if (hlo.opcode() == HloOpcode::kAdd ||
+ hlo.opcode() == HloOpcode::kSubtract ||
+ hlo.opcode() == HloOpcode::kTuple ||
+ hlo.opcode() == HloOpcode::kGetTupleElement) {
+ return true;
+ }
+ return false;
+ }
+
+ bool SupportsBF16Output(const HloInstruction& hlo) const override {
+ if (hlo.opcode() == HloOpcode::kAdd ||
+ hlo.opcode() == HloOpcode::kSubtract ||
+ hlo.opcode() == HloOpcode::kTuple ||
+ hlo.opcode() == HloOpcode::kGetTupleElement) {
+ return true;
+ }
+ return false;
+ }
+
+ bool SupportsMixedPrecisions(const HloInstruction& hlo) const override {
+ if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kTuple ||
+ hlo.opcode() == HloOpcode::kGetTupleElement) {
+ return true;
+ }
+ return false;
+ }
+};
+
+class BFloat16ConversionFoldingTest : public HloTestBase {
+ protected:
+ bool FoldConversions(HloModule* module) {
+ TestBFloat16Support bfloat16_support_;
+ BFloat16ConversionFolding fold(&bfloat16_support_);
+ StatusOr<bool> result = fold.Run(module);
+ EXPECT_IS_OK(result.status());
+ return result.ValueOrDie();
+ }
+};
+
+TEST_F(BFloat16ConversionFoldingTest, FoldIfSupported) {
+ auto builder = HloComputation::Builder(TestName());
+ Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
+ Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});
+
+ HloInstruction* a = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, f32_shape, "a"));
+ HloInstruction* b = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, f32_shape, "b"));
+ HloInstruction* c = builder.AddInstruction(
+ HloInstruction::CreateParameter(2, f32_shape, "c"));
+
+ HloInstruction* add0 = builder.AddInstruction(
+ HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, a, b));
+ HloInstruction* convert0 =
+ builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, add0));
+ HloInstruction* convert1 = builder.AddInstruction(
+ HloInstruction::CreateConvert(f32_shape, convert0));
+
+ HloInstruction* add1 = builder.AddInstruction(
+ HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, convert1, c));
+ builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, add1));
+
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_TRUE(FoldConversions(module.get()));
+
+ EXPECT_EQ(computation->root_instruction(), add1);
+ EXPECT_EQ(add0->shape().element_type(), BF16);
+ EXPECT_EQ(add1->shape().element_type(), BF16);
+ EXPECT_EQ(add1->operand(0), add0);
+}
+
+TEST_F(BFloat16ConversionFoldingTest, DoNotFoldIfUnsupported) {
+ auto builder = HloComputation::Builder(TestName());
+ Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
+ Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});
+
+ HloInstruction* a = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, f32_shape, "a"));
+ HloInstruction* b = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, f32_shape, "b"));
+ HloInstruction* c = builder.AddInstruction(
+ HloInstruction::CreateParameter(2, f32_shape, "c"));
+
+ HloInstruction* mul0 = builder.AddInstruction(
+ HloInstruction::CreateBinary(f32_shape, HloOpcode::kMultiply, a, b));
+ HloInstruction* convert0 =
+ builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, mul0));
+ HloInstruction* convert1 = builder.AddInstruction(
+ HloInstruction::CreateConvert(f32_shape, convert0));
+
+ HloInstruction* mul1 = builder.AddInstruction(HloInstruction::CreateBinary(
+ f32_shape, HloOpcode::kMultiply, convert1, c));
+ HloInstruction* convert2 =
+ builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, mul1));
+
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_FALSE(FoldConversions(module.get()));
+
+ EXPECT_EQ(computation->root_instruction(), convert2);
+ EXPECT_EQ(mul0->shape().element_type(), F32);
+ EXPECT_EQ(mul1->shape().element_type(), F32);
+ EXPECT_EQ(mul1->operand(0), convert1);
+}
+
+TEST_F(BFloat16ConversionFoldingTest, DoNotFoldUnsupportedMixedPrecision) {
+ auto builder = HloComputation::Builder(TestName());
+ Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
+ Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});
+
+ HloInstruction* a = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, f32_shape, "a"));
+ HloInstruction* b = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, f32_shape, "b"));
+ HloInstruction* c = builder.AddInstruction(
+ HloInstruction::CreateParameter(2, f32_shape, "c"));
+
+ HloInstruction* sub0 = builder.AddInstruction(
+ HloInstruction::CreateBinary(f32_shape, HloOpcode::kSubtract, a, b));
+ HloInstruction* convert0 =
+ builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, sub0));
+ HloInstruction* convert1 = builder.AddInstruction(
+ HloInstruction::CreateConvert(f32_shape, convert0));
+
+ HloInstruction* sub1 = builder.AddInstruction(HloInstruction::CreateBinary(
+ f32_shape, HloOpcode::kSubtract, convert1, c));
+ HloInstruction* convert2 =
+ builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, sub1));
+
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_FALSE(FoldConversions(module.get()));
+
+ EXPECT_EQ(computation->root_instruction(), convert2);
+ EXPECT_EQ(sub0->shape().element_type(), F32);
+ EXPECT_EQ(sub1->shape().element_type(), F32);
+ EXPECT_EQ(sub1->operand(0), convert1);
+}
+
+TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) {
+ auto builder = HloComputation::Builder(TestName());
+ Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
+ Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});
+
+ HloInstruction* a = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, f32_shape, "a"));
+ HloInstruction* b = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, bf16_shape, "b"));
+ HloInstruction* convert0 =
+ builder.AddInstruction(HloInstruction::CreateConvert(f32_shape, b));
+
+ HloInstruction* tuple =
+ builder.AddInstruction(HloInstruction::CreateTuple({a, convert0}));
+ HloInstruction* gte = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(f32_shape, tuple, 0));
+ HloInstruction* convert1 =
+ builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, gte));
+
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_FALSE(FoldConversions(module.get()));
+
+ EXPECT_EQ(computation->root_instruction(), convert1);
+ EXPECT_EQ(gte->shape().element_type(), F32);
+ EXPECT_EQ(tuple->operand(1), convert0);
+}
+
+} // namespace xla
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/bfloat16_normalization.h"
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault {
+ public:
+ explicit BFloat16NormalizationVisitor(HloComputation* computation,
+ const BFloat16Support* bfloat16_support)
+ : computation_(computation), bfloat16_support_(bfloat16_support) {}
+
+ Status DefaultAction(HloInstruction* hlo) override;
+
+ // Special handling for cross-replica-sum which can have a tuple output.
+ Status HandleCrossReplicaSum(HloInstruction* crs) override;
+
+ static bool Run(HloComputation* computation,
+ const BFloat16Support* bfloat16_support) {
+ BFloat16NormalizationVisitor visitor(computation, bfloat16_support);
+ TF_CHECK_OK(computation->Accept(&visitor));
+ return visitor.changed_;
+ }
+
+ private:
+ // Checks if the HLO uses BF16 in an unsupported way, and if so, inserts
+ // conversions between F32 and BF16 to make it supported.
+ Status HandleInstruction(HloInstruction* hlo);
+
+ // Inserts a conversion HLO that changes the given HLO's output type.
+ Status InsertConvertAfterOutput(HloInstruction* hlo, PrimitiveType to,
+ HloComputation* computation);
+
+ // Changes the output type to the specified type, then inserts a conversion
+ // to the original type.
+ Status ChangeOutputTypeThenInsertConvertBack(HloInstruction* hlo,
+ PrimitiveType to,
+ HloComputation* computation);
+
+ // Inserts a conversion HLO that changes the given HLO's operand type.
+ Status InsertConvertBeforeOperand(HloInstruction* hlo, int64 operand_idx,
+ PrimitiveType to,
+ HloComputation* computation);
+
+ // Inserts conversion HLOs to replace the called computations' BF16
+ // operands/outputs to F32.
+ Status ConvertCalledComputations(
+ HloInstruction* hlo,
+ tensorflow::gtl::ArraySlice<HloComputation*> bf16_called_comps);
+
+ HloComputation* computation_;
+ const BFloat16Support* bfloat16_support_;
+ bool changed_ = false;
+};
+
+Status BFloat16NormalizationVisitor::InsertConvertAfterOutput(
+ HloInstruction* hlo, PrimitiveType to, HloComputation* computation) {
+ bool is_root = computation->root_instruction() == hlo;
+ std::vector<HloInstruction*> materialized_users = hlo->users();
+ // Use hlo's shape temporarily, in order to pass checks in ReplaceUseWith.
+ auto convert = computation->AddInstruction(
+ HloInstruction::CreateConvert(hlo->shape(), hlo));
+ for (auto* user : materialized_users) {
+ TF_RETURN_IF_ERROR(hlo->ReplaceUseWith(user, convert));
+ }
+ if (is_root) {
+ computation->set_root_instruction(convert);
+ }
+ convert->mutable_shape()->set_element_type(to);
+ changed_ = true;
+ return Status::OK();
+}
+
+Status BFloat16NormalizationVisitor::ChangeOutputTypeThenInsertConvertBack(
+ HloInstruction* hlo, PrimitiveType to, HloComputation* computation) {
+ auto original_type = hlo->shape().element_type();
+ hlo->mutable_shape()->set_element_type(to);
+ return InsertConvertAfterOutput(hlo, original_type, computation);
+}
+
+Status BFloat16NormalizationVisitor::InsertConvertBeforeOperand(
+ HloInstruction* hlo, int64 operand_idx, PrimitiveType to,
+ HloComputation* computation) {
+ auto operand = hlo->mutable_operand(operand_idx);
+ auto convert = computation->AddInstruction(HloInstruction::CreateConvert(
+ ShapeUtil::ChangeElementType(operand->shape(), to), operand));
+ TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(operand_idx, convert));
+ changed_ = true;
+ return Status::OK();
+}
+
+Status BFloat16NormalizationVisitor::ConvertCalledComputations(
+ HloInstruction* hlo,
+ tensorflow::gtl::ArraySlice<HloComputation*> bf16_called_comps) {
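+ // Clone each called computation that uses BF16 so that other callers of the
+ // original computation are unaffected, then rewrite the clones to use F32 at
+ // their roots and parameters by inserting conversions.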
+ std::map<HloComputation*, HloComputation*> cloned_computations;
+ for (auto& comp : bf16_called_comps) {
+ auto cloned = comp->parent()->AddEmbeddedComputation(comp->Clone());
+ cloned_computations[comp] = cloned;
+ changed_ = true;
+ }
+ hlo->ReplaceCalledComputations([&](HloComputation* comp) {
+ auto it = cloned_computations.find(comp);
+ if (it != cloned_computations.end()) {
+ return it->second;
+ }
+ return comp;
+ });
+ for (auto& comp_pair : cloned_computations) {
+ auto comp = comp_pair.second;
+ if (comp->root_instruction()->shape().element_type() == BF16) {
+ TF_RETURN_IF_ERROR(
+ InsertConvertAfterOutput(comp->root_instruction(), F32, comp));
+ }
+ for (auto* param : comp->parameter_instructions()) {
+ if (param->shape().element_type() == BF16) {
+ // This changes the parameter to F32 then inserts a convert after it.
+ TF_RETURN_IF_ERROR(
+ ChangeOutputTypeThenInsertConvertBack(param, F32, comp));
+ }
+ }
+ }
+ return Status::OK();
+}
+
+Status BFloat16NormalizationVisitor::HandleCrossReplicaSum(
+ HloInstruction* crs) {
+ if (!ShapeUtil::IsTuple(crs->shape())) {
+ return HandleInstruction(crs);
+ }
+
+ std::vector<PrimitiveType> operand_types(crs->operand_count());
+ std::vector<PrimitiveType> output_types(crs->operand_count());
+ bool has_f32 = false;
+ bool has_bf16 = false;
+ bool has_bf16_output = false;
+ for (int64 i = 0; i < crs->operand_count(); ++i) {
+ operand_types[i] = crs->operand(i)->shape().element_type();
+ output_types[i] = ShapeUtil::GetSubshape(crs->shape(), {i}).element_type();
+ if (operand_types[i] == F32 || output_types[i] == F32) {
+ has_f32 = true;
+ } else if (operand_types[i] == BF16) {
+ has_bf16 = true;
+ }
+ if (output_types[i] == BF16) {
+ has_bf16 = true;
+ has_bf16_output = true;
+ }
+ }
+
+ for (int64 i = 0; i < crs->operand_count(); ++i) {
+ if (operand_types[i] != BF16) {
+ continue;
+ }
+ if (bfloat16_support_->SupportsBF16Operand(*crs, i) &&
+ (bfloat16_support_->SupportsMixedPrecisions(*crs) || !has_f32)) {
+ continue;
+ }
+ TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(crs, i, F32, computation_));
+ has_f32 = true;
+ }
+
+ if (!has_bf16_output) {
+ return Status::OK();
+ }
+
+ if (bfloat16_support_->SupportsBF16Output(*crs) &&
+ (bfloat16_support_->SupportsMixedPrecisions(*crs) || !has_f32)) {
+ return Status::OK();
+ }
+
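+ // Otherwise, switch each BF16 element of the output tuple to F32, rebuild
+ // the original BF16 values using get-tuple-element followed by convert, and
+ // point the existing users at a new tuple of these elements.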
+ std::vector<HloInstruction*> output_elements(crs->operand_count());
+ auto original_shape = crs->shape();
+ for (int64 i = 0; i < crs->operand_count(); ++i) {
+ auto subshape = ShapeUtil::GetMutableSubshape(crs->mutable_shape(), {i});
+ if (output_types[i] != BF16) {
+ output_elements[i] = computation_->AddInstruction(
+ HloInstruction::CreateGetTupleElement(*subshape, crs, i));
+ continue;
+ }
+ subshape->set_element_type(F32);
+ auto gte = computation_->AddInstruction(
+ HloInstruction::CreateGetTupleElement(*subshape, crs, i));
+ output_elements[i] =
+ computation_->AddInstruction(HloInstruction::CreateConvert(
+ ShapeUtil::ChangeElementType(*subshape, BF16), gte));
+ }
+ auto tuple = computation_->AddInstruction(
+ HloInstruction::CreateTuple(output_elements));
+
+ std::vector<HloInstruction*> materialized_users = crs->users();
+ // Use the crs' shape temporarily, in order to pass checks in
+ // ReplaceUseWith.
+ *tuple->mutable_shape() = crs->shape();
+ for (auto* user : materialized_users) {
+ TF_RETURN_IF_ERROR(crs->ReplaceUseWith(user, tuple));
+ }
+ *tuple->mutable_shape() = original_shape;
+ return Status::OK();
+}
+
+Status BFloat16NormalizationVisitor::HandleInstruction(HloInstruction* hlo) {
+ std::vector<int64> bf16_operands;
+ std::vector<int64> f32_operands;
+ bool has_f32 = false;
+ bool has_bf16 = false;
+
+ for (int64 i = 0; i < hlo->operand_count(); ++i) {
+ if (hlo->operand(i)->shape().element_type() == F32) {
+ f32_operands.push_back(i);
+ has_f32 = true;
+ } else if (hlo->operand(i)->shape().element_type() == BF16) {
+ bf16_operands.push_back(i);
+ has_bf16 = true;
+ }
+ }
+
+ if (hlo->shape().element_type() == F32) {
+ has_f32 = true;
+ } else if (hlo->shape().element_type() == BF16) {
+ has_bf16 = true;
+ }
+
+ std::vector<HloComputation*> bf16_called_comps;
+ for (auto* comp : hlo->called_computations()) {
+ bool comp_has_bf16 = false;
+ if (comp->root_instruction()->shape().element_type() == F32) {
+ has_f32 = true;
+ } else if (comp->root_instruction()->shape().element_type() == BF16) {
+ has_bf16 = true;
+ comp_has_bf16 = true;
+ }
+ for (auto* param : comp->parameter_instructions()) {
+ if (param->shape().element_type() == F32) {
+ has_f32 = true;
+ } else if (param->shape().element_type() == BF16) {
+ has_bf16 = true;
+ comp_has_bf16 = true;
+ }
+ }
+ if (comp_has_bf16) {
+ bf16_called_comps.push_back(comp);
+ }
+ }
+
+ if (!bfloat16_support_->SupportsMixedPrecisions(*hlo) && has_bf16 &&
+ has_f32) {
+ // Resolve unsupported mixed precision.
+ //
+ // See if we can change everything to BF16.
+ if (hlo->called_computations().empty() &&
+ hlo->shape().element_type() == BF16) {
+ bool can_use_bf16 = true;
+ for (int i : f32_operands) {
+ if (bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(*hlo,
+ i) &&
+ bfloat16_support_->SupportsBF16Operand(*hlo, i)) {
+ continue;
+ }
+ can_use_bf16 = false;
+ break;
+ }
+ if (can_use_bf16) {
+ for (int i : f32_operands) {
+ TF_RETURN_IF_ERROR(
+ InsertConvertBeforeOperand(hlo, i, BF16, computation_));
+ }
+ return Status::OK();
+ }
+ }
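+ // Otherwise fall back to F32: give the HLO an F32 output (converting back
+ // to BF16 for its users), convert its BF16 operands to F32, and convert any
+ // called computations that use BF16.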
+ if (hlo->shape().element_type() == BF16) {
+ TF_RETURN_IF_ERROR(
+ ChangeOutputTypeThenInsertConvertBack(hlo, F32, computation_));
+ }
+ for (int i : bf16_operands) {
+ TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_));
+ }
+ return ConvertCalledComputations(hlo, bf16_called_comps);
+ }
+
+ for (int i : bf16_operands) {
+ if (!bfloat16_support_->SupportsBF16Operand(*hlo, i)) {
+ TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_));
+ }
+ }
+
+ if (hlo->shape().element_type() == BF16 &&
+ !bfloat16_support_->SupportsBF16Output(*hlo)) {
+ TF_RETURN_IF_ERROR(
+ ChangeOutputTypeThenInsertConvertBack(hlo, F32, computation_));
+ }
+
+ return Status::OK();
+}
+
+Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) {
+ // Do not change instructions related to entry and exit of a computation,
+ // tuples, fusion, convert, and control flow.
+ if (hlo->opcode() == HloOpcode::kTuple || //
+ hlo->opcode() == HloOpcode::kGetTupleElement || //
+ hlo->opcode() == HloOpcode::kInfeed || //
+ hlo->opcode() == HloOpcode::kOutfeed || //
+ hlo->opcode() == HloOpcode::kConstant || //
+ hlo->opcode() == HloOpcode::kParameter || //
+ hlo->opcode() == HloOpcode::kFusion || //
+ hlo->opcode() == HloOpcode::kConvert || //
+ hlo->opcode() == HloOpcode::kCall || //
+ hlo->opcode() == HloOpcode::kCustomCall || //
+ hlo->opcode() == HloOpcode::kWhile || //
+ hlo->opcode() == HloOpcode::kConditional) {
+ return Status::OK();
+ }
+ return HandleInstruction(hlo);
+}
+
+StatusOr<bool> BFloat16Normalization::Run(HloModule* module) {
+ XLA_VLOG_LINES(
+ 2, "BFloat16Normalization::Run(), before:\n" + module->ToString());
+ bool changed = false;
+ for (auto* comp : module->MakeComputationPostOrder()) {
+ if (BFloat16NormalizationVisitor::Run(comp, bfloat16_support_)) {
+ changed = true;
+ }
+ }
+ XLA_VLOG_LINES(2,
+ "BFloat16Normalization::Run(), after:\n" + module->ToString());
+ return changed;
+}
+
+} // namespace xla
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_NORMALIZATION_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_NORMALIZATION_H_
+
+#include "tensorflow/compiler/xla/service/bfloat16_support.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+
+namespace xla {
+
+// A pass which adds F32 <-> BF16 conversions for HLO instructions that do not
+// support BF16 input/output or mixed precision, according to the passed-in
+// backend-specific BF16 support rules.
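+//
+// For example, assuming a backend that does not support BF16 for multiply,
+//
+//   mul = bf16[4] multiply(bf16[4] x, bf16[4] y)
+//
+// is rewritten to an F32 multiply surrounded by conversions:
+//
+//   x_f32   = f32[4] convert(x)
+//   y_f32   = f32[4] convert(y)
+//   mul_f32 = f32[4] multiply(x_f32, y_f32)
+//   mul     = bf16[4] convert(mul_f32)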
+class BFloat16Normalization : public HloPassInterface {
+ public:
+ explicit BFloat16Normalization(const BFloat16Support* bfloat16_support)
+ : bfloat16_support_(bfloat16_support) {}
+
+ ~BFloat16Normalization() override = default;
+ tensorflow::StringPiece name() const override {
+ return "bf16-normalization";
+ }
+
+ // Run BF16 normalization on the given computation. Returns whether the
+ // computation was changed.
+ StatusOr<bool> Run(HloModule* module) override;
+
+ private:
+ const BFloat16Support* bfloat16_support_;
+};
+
+// A pass that unconditionally removes the mixed F32/BF16 uses in HLO
+// instructions (excluding convert) by adding F32 <-> BF16 conversions. Unlike
+// BFloat16Normalization, this pass does not use a backend-specific
+// BFloat16Support, and does not change HLOs that have BF16 data if they do not
+// use mixed precision; it removes mixed precision even if the backend supports
+// it. This pass is used to make the HLO module valid for other HLO passes which
+// do not support mixed precision.
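+//
+// For example, if BFloat16ConversionFolding has introduced mixed-precision
+// HLOs and further passes that assume uniform precision still need to run,
+// this pass can be run first to restore uniform precision by inserting the
+// necessary conversions.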
+class BFloat16MixedPrecisionRemoval : public HloPassInterface {
+ public:
+ BFloat16MixedPrecisionRemoval() {}
+
+ ~BFloat16MixedPrecisionRemoval() override = default;
+
+ tensorflow::StringPiece name() const override {
+ return "bf16-mixed-precision-removal";
+ }
+
+ // Run mixed precision removal on the given computation. Returns whether the
+ // computation was changed.
+ StatusOr<bool> Run(HloModule* module) override {
+ BFloat16Normalization normalization(&no_mixed_precision_support_);
+ return normalization.Run(module);
+ }
+
+ private:
+ class BFloat16SupportForMixedPrecisionRemoval : public BFloat16Support {
+ public:
+ BFloat16SupportForMixedPrecisionRemoval() {}
+
+ ~BFloat16SupportForMixedPrecisionRemoval() override = default;
+
+ bool SupportsBF16Operand(const HloInstruction& hlo,
+ int64 operand_index) const override {
+ return true;
+ }
+
+ bool SupportsBF16Output(const HloInstruction& hlo) const override {
+ return true;
+ }
+
+ bool SupportsMixedPrecisions(const HloInstruction& hlo) const override {
+ return false;
+ }
+ } no_mixed_precision_support_;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_NORMALIZATION_H_
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/bfloat16_normalization.h"
+#include "tensorflow/compiler/xla/service/bfloat16_support.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+
+class TestBFloat16Support : public BFloat16Support {
+ public:
+ TestBFloat16Support() {}
+ ~TestBFloat16Support() override {}
+
+ bool SupportsBF16Operand(const HloInstruction& hlo,
+ int64 operand_index) const override {
+ if (hlo.opcode() == HloOpcode::kAdd ||
+ hlo.opcode() == HloOpcode::kSubtract ||
+ hlo.opcode() == HloOpcode::kReduce ||
+ hlo.opcode() == HloOpcode::kTuple ||
+ hlo.opcode() == HloOpcode::kGetTupleElement) {
+ return true;
+ }
+ return false;
+ }
+
+ bool SupportsBF16Output(const HloInstruction& hlo) const override {
+ if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kReduce ||
+ hlo.opcode() == HloOpcode::kSubtract ||
+ hlo.opcode() == HloOpcode::kTuple ||
+ hlo.opcode() == HloOpcode::kGetTupleElement) {
+ return true;
+ }
+ return false;
+ }
+
+ bool SupportsMixedPrecisions(const HloInstruction& hlo) const override {
+ if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kTuple ||
+ hlo.opcode() == HloOpcode::kGetTupleElement) {
+ return true;
+ }
+ return false;
+ }
+};
+
+class BFloat16NormalizationTest : public HloTestBase {
+ protected:
+ bool Normalize(HloModule* module) {
+ TestBFloat16Support bfloat16_support_;
+ BFloat16Normalization normalization(&bfloat16_support_);
+ StatusOr<bool> result = normalization.Run(module);
+ EXPECT_IS_OK(result.status());
+ return result.ValueOrDie();
+ }
+};
+
+TEST_F(BFloat16NormalizationTest, NoopIfSupported) {
+ auto builder = HloComputation::Builder(TestName());
+ Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
+ Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});
+
+ HloInstruction* a = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, f32_shape, "a"));
+ HloInstruction* b = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, bf16_shape, "b"));
+ HloInstruction* c = builder.AddInstruction(
+ HloInstruction::CreateParameter(2, f32_shape, "c"));
+
+ HloInstruction* add0 = builder.AddInstruction(
+ HloInstruction::CreateBinary(bf16_shape, HloOpcode::kAdd, a, b));
+
+ HloInstruction* add1 = builder.AddInstruction(
+ HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, add0, c));
+
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_FALSE(Normalize(module.get()));
+
+ EXPECT_EQ(computation->root_instruction(), add1);
+ EXPECT_EQ(add0->shape().element_type(), BF16);
+ EXPECT_EQ(add1->shape().element_type(), F32);
+}
+
+TEST_F(BFloat16NormalizationTest, ResolveIfUnsupportedBF16) {
+ auto builder = HloComputation::Builder(TestName());
+ Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
+ Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});
+
+ HloInstruction* a = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, f32_shape, "a"));
+ HloInstruction* b = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, bf16_shape, "b"));
+ HloInstruction* c = builder.AddInstruction(
+ HloInstruction::CreateParameter(2, f32_shape, "c"));
+
+ HloInstruction* mul0 = builder.AddInstruction(
+ HloInstruction::CreateBinary(bf16_shape, HloOpcode::kMultiply, a, b));
+
+ HloInstruction* mul1 = builder.AddInstruction(
+ HloInstruction::CreateBinary(bf16_shape, HloOpcode::kMultiply, mul0, c));
+
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_TRUE(Normalize(module.get()));
+
+ EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
+ EXPECT_EQ(computation->root_instruction()->operand(0), mul1);
+ EXPECT_EQ(mul0->shape().element_type(), F32);
+ EXPECT_EQ(mul1->shape().element_type(), F32);
+ EXPECT_EQ(mul1->operand(0)->opcode(), HloOpcode::kConvert);
+}
+
+TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionSubtraction) {
+ auto builder = HloComputation::Builder(TestName());
+ Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
+ Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});
+
+ HloInstruction* a = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, f32_shape, "a"));
+ HloInstruction* b = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, bf16_shape, "b"));
+ HloInstruction* c = builder.AddInstruction(
+ HloInstruction::CreateParameter(2, f32_shape, "c"));
+
+ HloInstruction* sub0 = builder.AddInstruction(
+ HloInstruction::CreateBinary(bf16_shape, HloOpcode::kSubtract, a, b));
+
+ HloInstruction* sub1 = builder.AddInstruction(
+ HloInstruction::CreateBinary(bf16_shape, HloOpcode::kSubtract, sub0, c));
+
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_TRUE(Normalize(module.get()));
+
+ EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
+ EXPECT_EQ(computation->root_instruction()->operand(0), sub1);
+ EXPECT_EQ(sub0->shape().element_type(), F32);
+ EXPECT_EQ(sub1->shape().element_type(), F32);
+ EXPECT_EQ(sub1->operand(0)->opcode(), HloOpcode::kConvert);
+}
+
+TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) {
+ Shape f32_input_shape = ShapeUtil::MakeShape(F32, {2, 4});
+ Shape f32_output_shape = ShapeUtil::MakeShape(F32, {4});
+
+ Shape bf16_scalar_shape = ShapeUtil::MakeShape(BF16, {});
+
+ auto reduce_comp_builder = HloComputation::Builder("reduce_comp");
+ auto reduce_comp_param0 = reduce_comp_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, bf16_scalar_shape, "param0"));
+ auto reduce_comp_param1 = reduce_comp_builder.AddInstruction(
+ HloInstruction::CreateParameter(1, bf16_scalar_shape, "param1"));
+ reduce_comp_builder.AddInstruction(
+ HloInstruction::CreateBinary(bf16_scalar_shape, HloOpcode::kAdd,
+ reduce_comp_param0, reduce_comp_param1));
+
+ auto module = CreateNewModule();
+ auto reduce_computation =
+ module->AddEmbeddedComputation(reduce_comp_builder.Build());
+
+ auto builder = HloComputation::Builder(TestName());
+ HloInstruction* input = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, f32_input_shape, "a"));
+ HloInstruction* init = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, bf16_scalar_shape, "init"));
+ HloInstruction* reduce = builder.AddInstruction(HloInstruction::CreateReduce(
+ f32_output_shape, input, init, {0}, reduce_computation));
+
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_TRUE(Normalize(module.get()));
+
+ EXPECT_EQ(computation->root_instruction(), reduce);
+ EXPECT_EQ(reduce->called_computations().size(), 1);
+ EXPECT_EQ(reduce->called_computations()[0]->num_parameters(), 2);
+ EXPECT_EQ(reduce->called_computations()[0]
+ ->parameter_instruction(0)
+ ->shape()
+ .element_type(),
+ F32);
+ EXPECT_EQ(reduce->called_computations()[0]
+ ->parameter_instruction(1)
+ ->shape()
+ .element_type(),
+ F32);
+ EXPECT_EQ(reduce->called_computations()[0]
+ ->root_instruction()
+ ->shape()
+ .element_type(),
+ F32);
+ EXPECT_EQ(reduce->shape().element_type(), F32);
+ EXPECT_EQ(reduce->operand(0), input);
+ EXPECT_EQ(input->shape().element_type(), F32);
+ EXPECT_EQ(reduce->operand(1)->opcode(), HloOpcode::kConvert);
+ EXPECT_EQ(reduce->operand(1)->shape().element_type(), F32);
+}
+
+TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) {
+ auto builder = HloComputation::Builder(TestName());
+ Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
+ Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});
+
+ HloInstruction* a = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, f32_shape, "a"));
+ HloInstruction* b = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, bf16_shape, "b"));
+
+ HloInstruction* crs =
+ builder.AddInstruction(HloInstruction::CreateCrossReplicaSum(
+ ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}));
+ HloInstruction* gte = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1));
+
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_TRUE(Normalize(module.get()));
+
+ EXPECT_EQ(computation->root_instruction(), gte);
+ EXPECT_EQ(gte->shape().element_type(), BF16);
+ EXPECT_EQ(crs->operand(1)->shape().element_type(), F32);
+ EXPECT_EQ(ShapeUtil::GetSubshape(crs->shape(), {1}).element_type(), F32);
+}
+
+} // namespace xla
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/bfloat16_support.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+
+namespace xla {
+
+bool BFloat16Support::SupportsBF16Operand(const HloInstruction& hlo,
+ int64 operand_index) const {
+ switch (hlo.opcode()) {
+ case HloOpcode::kCall:
+ case HloOpcode::kConditional:
+ case HloOpcode::kCustomCall:
+ case HloOpcode::kGetTupleElement:
+ case HloOpcode::kTuple:
+ case HloOpcode::kWhile:
+ return true;
+ case HloOpcode::kConvert:
+ CHECK_EQ(operand_index, 0);
+ return hlo.operand(0)->shape().element_type() == BF16;
+ default:
+ break;
+ }
+ return false;
+}
+
+bool BFloat16Support::SupportsBF16Output(const HloInstruction& hlo) const {
+ switch (hlo.opcode()) {
+ case HloOpcode::kCall:
+ case HloOpcode::kConditional:
+ case HloOpcode::kCustomCall:
+ case HloOpcode::kGetTupleElement:
+ case HloOpcode::kTuple:
+ case HloOpcode::kWhile:
+ return true;
+ case HloOpcode::kConvert:
+ return hlo.shape().element_type() == BF16;
+ default:
+ break;
+ }
+ return false;
+}
+
+bool BFloat16Support::SupportsMixedPrecisions(const HloInstruction& hlo) const {
+ switch (hlo.opcode()) {
+ case HloOpcode::kCall:
+ case HloOpcode::kConditional:
+ case HloOpcode::kConvert:
+ case HloOpcode::kCustomCall:
+ case HloOpcode::kGetTupleElement:
+ case HloOpcode::kTuple:
+ case HloOpcode::kWhile:
+ return true;
+ default:
+ break;
+ }
+ return false;
+}
+
+/* static */
+bool BFloat16Support::EffectiveOperandPrecisionIsOutputPrecision(
+ const HloInstruction& hlo, int64 operand_index) {
+ switch (hlo.opcode()) {
+ case HloOpcode::kAbs:
+ case HloOpcode::kBroadcast:
+ case HloOpcode::kClamp:
+ case HloOpcode::kConcatenate:
+ case HloOpcode::kCopy:
+ case HloOpcode::kGetTupleElement:
+ case HloOpcode::kMaximum:
+ case HloOpcode::kMinimum:
+ case HloOpcode::kPad:
+ case HloOpcode::kReshape:
+ case HloOpcode::kReverse:
+ case HloOpcode::kSlice:
+ case HloOpcode::kSort:
+ case HloOpcode::kTranspose:
+ case HloOpcode::kTuple:
+ return true;
+ case HloOpcode::kDynamicSlice:
+ return operand_index == 0;
+ case HloOpcode::kDynamicUpdateSlice:
+ return operand_index == 0 || operand_index == 1;
+ case HloOpcode::kSelect:
+ return operand_index == 1 || operand_index == 2;
+ default:
+ break;
+ }
+ return false;
+}
+
+bool BFloat16Support::EffectiveOperandPrecisionIsBF16(
+ const HloInstruction& hlo, int64 operand_index) const {
+ return false;
+}
+
+} // namespace xla
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_SUPPORT_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_SUPPORT_H_
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+
+namespace xla {
+
+class BFloat16Support {
+ public:
+ BFloat16Support() {}
+ virtual ~BFloat16Support() {}
+
+ // Returns whether the backend supports BF16 operand for the HLO instruction
+ // at the given index.
+ virtual bool SupportsBF16Operand(const HloInstruction& hlo,
+ int64 operand_index) const;
+
+ // Returns whether the backend supports BF16 output for the HLO instruction.
+ virtual bool SupportsBF16Output(const HloInstruction& hlo) const;
+
+ // Returns whether the backend supports mixed precision: the operands, output,
+ // and parameters/output of the called computations can have different
+ // precisions (BF16 and F32).
+ virtual bool SupportsMixedPrecisions(const HloInstruction& hlo) const;
+
+ // Returns whether the given HLO inherits its BF16 operand precision at the
+ // given index, so even if the output is F32, elements in the output that
+ // depend on the BF16 operand will still have BF16 effective precision even if
+ // they have F32 format. Similarly, this also means if the output is BF16 then
+ // increasing the operand precision from BF16 to F32 will not change the
+ // output. This typically includes HLOs that pass elements from the operand to
+ // the output without arithmetic operations.
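+ //
+ // For example, a transpose or a slice only moves elements from the operand
+ // to the output, so if an operand effectively carries only BF16 precision,
+ // so does the corresponding output data.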
+ static bool EffectiveOperandPrecisionIsOutputPrecision(
+ const HloInstruction& hlo, int64 operand_index);
+
+ // Returns if the backend only uses BF16 precision for the operand at the
+ // specified index, even if the operand is F32.
+ virtual bool EffectiveOperandPrecisionIsBF16(const HloInstruction& hlo,
+ int64 operand_index) const;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_SUPPORT_H_
Status HloInstruction::ReplaceUseWith(HloInstruction* user,
HloInstruction* new_producer) {
- TF_RET_CHECK(ShapeUtil::Compatible(shape(), new_producer->shape()))
+ TF_RET_CHECK(
+ ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape()))
<< "this shape: " << ShapeUtil::HumanString(shape())
<< ", replacement shape: "
<< ShapeUtil::HumanString(new_producer->shape());
TF_RET_CHECK(operand_num >= 0);
TF_RET_CHECK(operand_num < operand_count());
HloInstruction* old_operand = mutable_operand(operand_num);
- TF_RET_CHECK(
- ShapeUtil::Compatible(old_operand->shape(), new_operand->shape()))
+ TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(old_operand->shape(),
+ new_operand->shape()))
<< old_operand->shape().ShortDebugString() << " is not compatible with "
<< new_operand->shape().ShortDebugString();
operands_[operand_num] = new_operand;
limitations under the License.
==============================================================================*/
+#include <set>
+
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/lib/core/errors.h"
// HLO broadcast has no exact analog at the proto level so there is no
// ShapeInference method. Check the output shape explicitly.
const Shape& operand_shape = broadcast->operand(0)->shape();
+ // Check for mixed precision.
+ TF_RETURN_IF_ERROR(CheckShape(broadcast, broadcast->shape()));
TF_RET_CHECK(ShapeUtil::Rank(operand_shape) ==
broadcast->dimensions().size());
for (int64 operand_dimension = 0;
}
Status ShapeVerifier::HandleReshape(HloInstruction* reshape) {
+ // Check for mixed precision.
+ TF_RETURN_IF_ERROR(CheckShape(reshape, reshape->shape()));
TF_RET_CHECK(ShapeUtil::ElementsIn(reshape->shape()) ==
ShapeUtil::ElementsIn(reshape->operand(0)->shape()));
return tensorflow::Status::OK();
batch_norm_grad->feature_index()));
}
+namespace {
+
+// Checks that the instruction does not have mixed precision floating point
+// inputs.
+Status CheckMixedPrecisionOperands(const HloInstruction* instruction) {
+ switch (instruction->opcode()) {
+ // Whitelist the following opcodes for the mixed-precision check, because
+ // they involve data passing through or grouping via tuples, where the
+ // precisions of buffers can be different.
+ case HloOpcode::kCall:
+ case HloOpcode::kConditional:
+ case HloOpcode::kConstant:
+ case HloOpcode::kCrossReplicaSum:
+ case HloOpcode::kCustomCall:
+ case HloOpcode::kFusion:
+ case HloOpcode::kGetTupleElement:
+ case HloOpcode::kInfeed:
+ case HloOpcode::kOutfeed:
+ case HloOpcode::kParameter:
+ case HloOpcode::kRecv:
+ case HloOpcode::kRecvDone:
+ case HloOpcode::kReducePrecision:
+ case HloOpcode::kSelect:
+ case HloOpcode::kSend:
+ case HloOpcode::kSendDone:
+ case HloOpcode::kTuple:
+ case HloOpcode::kWhile:
+ break;
+ default: {
+ PrimitiveType fp_type = PRIMITIVE_TYPE_INVALID;
+ for (auto operand : instruction->operands()) {
+ TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
+ operand->shape(),
+ [&](const Shape& subshape, const ShapeIndex& index) {
+ if (!ShapeUtil::ElementIsFloating(subshape)) {
+ return Status::OK();
+ }
+ if (fp_type == PRIMITIVE_TYPE_INVALID) {
+ fp_type = subshape.element_type();
+ } else if (fp_type != subshape.element_type()) {
+ return FailedPrecondition(
+ "Seen floating point types of different precisions in "
+ "%s, but mixed precision is disallowed.",
+ instruction->ToString().c_str());
+ }
+ return Status::OK();
+ }));
+ }
+ }
+ }
+ return Status::OK();
+}
+
+} // namespace
+
Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
- const Shape& expected_shape) {
- if (!ShapeUtil::Compatible(instruction->shape(), expected_shape)) {
+ const Shape& inferred_shape) {
+ // If allow_mixed_precision_ is false, check if there are operands with
+ // different precisions. We need this check because ShapeInference allows
+ // mixed precision inputs.
+ if (!allow_mixed_precision_) {
+ TF_RETURN_IF_ERROR(CheckMixedPrecisionOperands(instruction));
+ }
+
+ // Check if the output shape matches the expected shape.
+ bool compatible;
+ // We treat BF16 and F32 as compatible types if mixed precision is allowed,
+ // but only when the instruction defines the BF16/F32 buffer.
+ switch (instruction->opcode()) {
+ case HloOpcode::kSelect:
+ if (ShapeUtil::IsTuple(inferred_shape) || !allow_mixed_precision_) {
+ // Select only defines the top-level buffer, which in this case is the
+ // tuple, so we cannot allow mixed precision.
+ compatible =
+ ShapeUtil::Compatible(instruction->shape(), inferred_shape);
+ } else {
+ compatible = ShapeUtil::CompatibleIgnoringFpPrecision(
+ instruction->shape(), inferred_shape);
+ }
+ break;
+ case HloOpcode::kGetTupleElement:
+ case HloOpcode::kTuple:
+ // Tuple and GetTupleElement do not define BF16/F32 buffers, so mixed
+ // precision is disallowed.
+ case HloOpcode::kConstant:
+ case HloOpcode::kBitcast:
+ case HloOpcode::kBitcastConvert:
+ case HloOpcode::kCall:
+ case HloOpcode::kConditional:
+ case HloOpcode::kConvert:
+ case HloOpcode::kCustomCall:
+ case HloOpcode::kInfeed:
+ case HloOpcode::kOutfeed:
+ case HloOpcode::kParameter:
+ case HloOpcode::kRecv:
+ case HloOpcode::kRecvDone:
+ case HloOpcode::kSend:
+ case HloOpcode::kSendDone:
+ case HloOpcode::kWhile:
+ // The above opcodes should match the expected shapes exactly.
+ compatible = ShapeUtil::Compatible(instruction->shape(), inferred_shape);
+ break;
+ default:
+ if (allow_mixed_precision_) {
+ compatible = ShapeUtil::CompatibleIgnoringFpPrecision(
+ instruction->shape(), inferred_shape);
+ } else {
+ compatible =
+ ShapeUtil::Compatible(instruction->shape(), inferred_shape);
+ }
+ }
+ if (!compatible) {
return InvalidArgument(
"Expected instruction to have shape compatible with %s, actual "
"shape is %s:\n%s",
- ShapeUtil::HumanString(expected_shape).c_str(),
+ ShapeUtil::HumanString(inferred_shape).c_str(),
ShapeUtil::HumanString(instruction->shape()).c_str(),
instruction->ToString().c_str());
}
}
Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
- const StatusOr<Shape>& expected_shape_status) {
- if (!expected_shape_status.ok()) {
- Status s = expected_shape_status.status();
+ const StatusOr<Shape>& inferred_shape_status) {
+ if (!inferred_shape_status.ok()) {
+ Status s = inferred_shape_status.status();
tensorflow::errors::AppendToMessage(&s, ", for instruction ",
instruction->ToString());
return s;
}
- return CheckShape(instruction, expected_shape_status.ValueOrDie());
+ return CheckShape(instruction, inferred_shape_status.ValueOrDie());
}
Status ShapeVerifier::CheckUnaryShape(const HloInstruction* instruction) {
// TODO(b/26024837): Check output shape for all instruction types.
class ShapeVerifier : public DfsHloVisitor {
public:
+ explicit ShapeVerifier() : allow_mixed_precision_(false) {}
+ explicit ShapeVerifier(bool allow_mixed_precision)
+ : allow_mixed_precision_(allow_mixed_precision) {}
+
Status HandleElementwiseUnary(HloInstruction* hlo) override;
Status HandleElementwiseBinary(HloInstruction* hlo) override;
Status HandleClamp(HloInstruction* clamp) override;
}
protected:
- // Check the instruction's shape against the given expected shape and return
- // an appropriate error if there is a mismatch.
+ // Check the instruction's shape against the shape given by ShapeInference
+ // and return an appropriate error if there is a mismatch.
Status CheckShape(const HloInstruction* instruction,
- const Shape& expected_shape);
+ const Shape& inferred_shape);
// Overload which takes a StatusOr to reduce boilerplate in the caller.
Status CheckShape(const HloInstruction* instruction,
- const StatusOr<Shape>& expected_shape_status);
+ const StatusOr<Shape>& inferred_shape_status);
// Check a unary (binary, etc) instruction's shape against the inferred shape.
Status CheckUnaryShape(const HloInstruction* instruction);
// Checks if the given two instructions share the same channel id.
Status CheckSameChannel(const HloInstruction* instr1,
const HloInstruction* instr2);
+
+ private:
+ // Whether the inputs and output of an instruction can contain both F32s and
+ // BF16s. Tuples that include both F32s and BF16s are allowed regardless of
+ // this flag.
+ bool allow_mixed_precision_;
};
// HLO pass that verifies invariants of HLO instructions for each computation in
// the module.
class HloVerifier : public HloPassInterface {
public:
+ using ShapeVerifierFactory = std::function<std::unique_ptr<ShapeVerifier>()>;
+
// Uses standard shape inference.
explicit HloVerifier()
- : shape_verifier_factory_([] { return MakeUnique<ShapeVerifier>(); }) {}
+ : shape_verifier_factory_(
+ [] { return MakeUnique<ShapeVerifier>(false); }) {}
+
+ explicit HloVerifier(bool allow_mixed_precision)
+ : shape_verifier_factory_([allow_mixed_precision] {
+ return MakeUnique<ShapeVerifier>(allow_mixed_precision);
+ }) {}
// Uses custom shape verification.
- explicit HloVerifier(
- std::function<std::unique_ptr<ShapeVerifier>()> shape_verifier_factory)
+ explicit HloVerifier(ShapeVerifierFactory shape_verifier_factory)
: shape_verifier_factory_(std::move(shape_verifier_factory)) {}
~HloVerifier() override = default;
// expectations. This is a factory function because ShapeVerifier, being a
// DfsHloVisitor, is stateful; we want a clean object for each run of the
// verifier.
- std::function<std::unique_ptr<ShapeVerifier>()> shape_verifier_factory_;
+ ShapeVerifierFactory shape_verifier_factory_;
};
} // namespace xla
}
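A minimal usage sketch of the new constructor (assuming an existing HloModule* module; not part of the patch): backends whose HLO legitimately mixes BF16 and F32 can opt into the relaxed verification instead of the strict default.

  HloVerifier strict_verifier;                                   // allow_mixed_precision_ == false
  HloVerifier relaxed_verifier(/*allow_mixed_precision=*/true);  // BF16/F32 mixing tolerated
  TF_CHECK_OK(relaxed_verifier.Run(module).status());
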
// Check that init_value's shape is suitable for reducer_shape.
- if (!ShapeUtil::Compatible(accumulator_shape, init_value_shape)) {
+ if (!ShapeUtil::CompatibleIgnoringFpPrecision(accumulator_shape,
+ init_value_shape)) {
return InvalidArgument(
"Reduction function's accumulator shape differs from the "
"init_value shape: %s vs %s",
// Check that the inputs can be passed in as the second argument.
const Shape& input_element_shape =
ShapeUtil::MakeShape(input_element_type, {});
- if (!ShapeUtil::Compatible(input_element_shape,
- reducer_shape.parameters(1))) {
+ if (!ShapeUtil::CompatibleIgnoringFpPrecision(input_element_shape,
+ reducer_shape.parameters(1))) {
return InvalidArgument(
"Reduction function's second parameter shape differs from the "
"input type element type: %s vs %s",
// Currently the accumulator and inputs must be the same type,
// though that restriction could be relaxed.
- if (!ShapeUtil::Compatible(accumulator_shape, reducer_shape.parameters(1))) {
+ if (!ShapeUtil::CompatibleIgnoringFpPrecision(accumulator_shape,
+ reducer_shape.parameters(1))) {
return InvalidArgument(
"Reduction function's second parameter shape currently must "
"match the result shape. Got %s vs %s",
dimension);
}
const Shape* arg_shape = nullptr;
+ PrimitiveType element_type = PRIMITIVE_TYPE_INVALID;
for (const Shape* shape : arg_shapes) {
TF_RETURN_IF_ERROR(
ExpectNotTupleOrOpaque(*shape, "operand of concatenation"));
if (!arg_shape) {
arg_shape = shape;
+ element_type = arg_shape->element_type();
continue;
}
if (ShapeUtil::Rank(*arg_shape) != ShapeUtil::Rank(*shape)) {
ShapeUtil::HumanString(*arg_shape).c_str(), ShapeUtil::Rank(*shape),
ShapeUtil::HumanString(*shape).c_str());
}
- if (arg_shape->element_type() != shape->element_type()) {
+ if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shape, *shape)) {
return InvalidArgument(
"cannot concatenate arrays with different element types: %s vs %s",
PrimitiveType_Name(arg_shape->element_type()).c_str(),
ShapeUtil::HumanString(*shape).c_str(), dimension);
}
}
+ element_type = ShapeUtil::HigherPrecisionElementType(*shape, *arg_shape);
}
std::vector<int64> new_dimensions(arg_shape->dimensions().begin(),
for (size_t i = 1; i < arg_shapes.size(); ++i) {
new_dimensions[dimension] += arg_shapes[i]->dimensions(dimension);
}
- return ShapeUtil::MakeShape(arg_shape->element_type(), new_dimensions);
+ return ShapeUtil::MakeShape(element_type, new_dimensions);
}
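A sketch of the relaxed concatenate rule (hypothetical shapes, not from the patch): mixing BF16 and F32 operands is now accepted, and the inferred result takes the higher-precision F32 element type.

  Shape lo = ShapeUtil::MakeShape(BF16, {2, 3});
  Shape hi = ShapeUtil::MakeShape(F32, {4, 3});
  auto concat = ShapeInference::InferConcatOpShape({&lo, &hi}, /*dimension=*/0);
  // concat.ValueOrDie() is F32[6,3]: dimension 0 sums to 6, element type is the wider F32.
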
/* static */ StatusOr<Shape> ShapeInference::InferConvertShape(
ShapeUtil::HumanString(operand_shape).c_str(),
padding_config.ShortDebugString().c_str());
}
- if (operand_shape.element_type() != padding_value_shape.element_type()) {
+ if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape,
+ padding_value_shape)) {
return InvalidArgument(
"the element types of the operands to pad do not match");
}
std::max<int64>(operand_shape.dimensions(i) - 1, 0LL) *
padding_config.dimensions(i).interior_padding();
}
- return ShapeUtil::MakeShape(operand_shape.element_type(), dimensions);
+ return ShapeUtil::MakeShape(
+ ShapeUtil::HigherPrecisionElementType(operand_shape, padding_value_shape),
+ dimensions);
}
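Likewise for pad (hypothetical shapes, not from the patch): a BF16 operand padded with an F32 padding value now infers an F32 result over the padded dimensions.

  Shape operand = ShapeUtil::MakeShape(BF16, {4});
  Shape pad_value = ShapeUtil::MakeShape(F32, {});
  PaddingConfig config;
  auto* dim = config.add_dimensions();
  dim->set_edge_padding_low(1);
  dim->set_edge_padding_high(1);
  dim->set_interior_padding(0);
  auto padded = ShapeInference::InferPadShape(operand, pad_value, config);
  // padded.ValueOrDie() is F32[6]: 1 + 4 + 1 elements, F32 from the higher-precision pad value.
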
// Current DotDimensionNumbers Requirements:
};
// Check if both element types are the same.
- if (lhs.element_type() != rhs.element_type()) {
+ if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
return fail("element types do not match");
}
dimensions.push_back(rhs.dimensions(i));
}
}
- Shape result = ShapeUtil::MakeShape(lhs.element_type(), dimensions);
+ Shape result = ShapeUtil::MakeShape(
+ ShapeUtil::HigherPrecisionElementType(lhs, rhs), dimensions);
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(result));
VLOG(2) << "inferred dot shape: " << ShapeUtil::HumanString(result);
ShapeUtil::HumanString(rhs).c_str());
}
}
- return ShapeUtil::MakeShape(lhs.element_type(), output_dimensions);
+ return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
+ output_dimensions);
}
/* static */ StatusOr<Shape> ShapeInference::InferInDimBroadcastShape(
// specified in broadcast_dimensions are then changed to match the
// corresponding dimension size in smaller_shape.
Shape output_shape(larger_shape);
+ output_shape.set_element_type(
+ ShapeUtil::HigherPrecisionElementType(larger_shape, smaller_shape));
for (int i = 0; i < smaller_shape.dimensions_size(); ++i) {
int64 dimension_to_match = broadcast_dimensions.at(i);
TF_RETURN_IF_ERROR(
ExpectNotTupleOrOpaque(rhs, "rhs of elementwise binary operation"));
- if (!ShapeUtil::SameElementType(lhs, rhs)) {
+ if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
return InvalidArgument(
"binary op %s with different element types: %s and %s",
BinaryOperation_Name(operation).c_str(),
}
}
- if (ShapeUtil::Compatible(lhs, rhs)) {
+ if (ShapeUtil::CompatibleIgnoringFpPrecision(lhs, rhs)) {
// If the shapes are the same other than layout, the output shape is the
// same (elementwise op).
- return lhs;
+ return ShapeUtil::ChangeElementType(
+ lhs, ShapeUtil::HigherPrecisionElementType(lhs, rhs));
}
if (ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)) {
TF_ASSIGN_OR_RETURN(const Shape& shape,
InferElementwiseBinaryOpShape(operation, lhs, rhs,
broadcast_dimensions));
- if (lhs.element_type() == F32) {
+ if (lhs.element_type() == F32 && rhs.element_type() == F32) {
return ShapeUtil::ChangeElementType(shape, C64);
} else {
return Unimplemented("complex component type not supported");
TF_RETURN_IF_ERROR(
ExpectNotTupleOrOpaque(*arg_shapes[i], "operand of map"));
- if (ShapeUtil::Compatible(*arg_shapes[i], *arg_shape)) {
+ if (ShapeUtil::CompatibleIgnoringFpPrecision(*arg_shapes[i], *arg_shape)) {
continue;
}
if (!ShapeUtil::IsTuple(*arg_shapes[i]) &&
!ShapeUtil::IsTuple(*arg_shape) &&
- ShapeUtil::SameElementType(*arg_shapes[i], *arg_shape)) {
+ ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shapes[i],
+ *arg_shape)) {
if (ShapeUtil::IsScalar(*arg_shapes[i])) {
continue;
}
i, ShapeUtil::HumanString(parameter_shape).c_str());
}
- if (parameter_shape.element_type() != arg_shape->element_type()) {
+ if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(parameter_shape,
+ *arg_shape)) {
return InvalidArgument(
"mapped computation's parameter type has to match argument element "
"type; got parameter %d shape: %s, argument shape: %s",
PrimitiveType_Name(operand_shape.element_type()).c_str());
}
- if (!ShapeUtil::SameElementType(offset_shape, operand_shape)) {
+ if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape,
+ operand_shape)) {
return InvalidArgument(
"The inputs should have the same element type for batch-norm-training, "
"but the shape of offset factor is %s "
PrimitiveType_Name(operand_shape.element_type()).c_str());
}
- if (!ShapeUtil::SameElementType(scale_shape, operand_shape)) {
+ if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
+ operand_shape)) {
return InvalidArgument(
"The inputs should have the same element type for batch-norm-training, "
"but the shape of scale factor is %s "
PrimitiveType_Name(operand_shape.element_type()).c_str());
}
- if (!ShapeUtil::SameElementType(offset_shape, operand_shape)) {
+ if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape,
+ operand_shape)) {
return InvalidArgument(
"The inputs should have the same element type for "
"batch-norm-inference, "
PrimitiveType_Name(operand_shape.element_type()).c_str());
}
- if (!ShapeUtil::SameElementType(scale_shape, operand_shape)) {
+ if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
+ operand_shape)) {
return InvalidArgument(
"The inputs should have the same element type for "
"batch-norm-inference, "
PrimitiveType_Name(operand_shape.element_type()).c_str());
}
- if (!ShapeUtil::SameElementType(mean_shape, operand_shape)) {
+ if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape,
+ operand_shape)) {
return InvalidArgument(
"The inputs should have the same element type for "
"batch-norm-inference, "
PrimitiveType_Name(operand_shape.element_type()).c_str());
}
- if (!ShapeUtil::SameElementType(variance_shape, operand_shape)) {
+ if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(variance_shape,
+ operand_shape)) {
return InvalidArgument(
"The inputs should have the same element type for "
"batch-norm-inference, "
PrimitiveType_Name(output_grad_shape.element_type()).c_str());
}
- if (!ShapeUtil::SameElementType(output_grad_shape, operand_shape)) {
+ if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(output_grad_shape,
+ operand_shape)) {
return InvalidArgument(
"The inputs should have the same element type for batch-norm-grad, "
"but the element type of output_grad is %s "
PrimitiveType_Name(operand_shape.element_type()).c_str());
}
- if (!ShapeUtil::SameElementType(scale_shape, operand_shape)) {
+ if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
+ operand_shape)) {
return InvalidArgument(
"The inputs should have the same element type for batch-norm-grad, "
"but the element type of scale factor is %s "
PrimitiveType_Name(operand_shape.element_type()).c_str());
}
- if (!ShapeUtil::SameElementType(mean_shape, operand_shape)) {
+ if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape,
+ operand_shape)) {
return InvalidArgument(
"The inputs should have the same element type for batch-norm-grad, "
"but the element type of mean is %s "
PrimitiveType_Name(operand_shape.element_type()).c_str());
}
- if (!ShapeUtil::SameElementType(var_shape, operand_shape)) {
+ if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(var_shape,
+ operand_shape)) {
return InvalidArgument(
"The inputs should have the same element type for batch-norm-grad, "
"but the element type of mean is %s "
TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of convolution"));
TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of convolution"));
- if (!ShapeUtil::SameElementType(lhs, rhs)) {
+ if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
return InvalidArgument(
"Convolution with different element types: %s and %s",
ShapeUtil::HumanString(lhs).c_str(),
dimensions[dnums.output_spatial_dimensions(i)] =
window_output_shape.dimensions(i);
}
-
- return ShapeUtil::MakeShape(lhs.element_type(), dimensions);
+ return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
+ dimensions);
}
/* static */ StatusOr<Shape> ShapeInference::InferFftShape(
}
const Shape& operand_element_shape =
ShapeUtil::MakeShape(operand_shape.element_type(), {});
- if (!ShapeUtil::Compatible(operand_element_shape,
- select_shape.parameters(0))) {
+ if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape,
+ select_shape.parameters(0))) {
return InvalidArgument(
"select function's first parameter shape currently must "
"match the operand element shape. Got %s vs %s",
ShapeUtil::HumanString(select_shape.parameters(0)).c_str(),
ShapeUtil::HumanString(operand_element_shape).c_str());
}
- if (!ShapeUtil::Compatible(operand_element_shape,
- select_shape.parameters(1))) {
+ if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape,
+ select_shape.parameters(1))) {
return InvalidArgument(
"select function's second parameter shape currently must "
"match the operand element shape. Got %s vs %s",
InferWindowOutputShape(operand_shape, window,
operand_shape.element_type(),
/*allow_negative_padding=*/false));
- if (!ShapeUtil::Compatible(source_shape, window_result_shape)) {
+ if (!ShapeUtil::CompatibleIgnoringFpPrecision(source_shape,
+ window_result_shape)) {
return InvalidArgument(
"source shape does not match the shape of window-reduced operand: "
"source(%s), window-reduced operand(%s)",
ShapeUtil::Rank(update_shape), ShapeUtil::Rank(operand_shape));
}
- if (operand_shape.element_type() != update_shape.element_type()) {
+ if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape,
+ update_shape)) {
return InvalidArgument(
"dynamic update slice update element type does not match argument. "
"operand.element_type: %s vs update.element_type: %s",
TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(min, "clamp min"));
TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "clamp operand"));
TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(max, "clamp max"));
- if (!ShapeUtil::SameElementType(min, operand) ||
- !ShapeUtil::SameElementType(max, operand)) {
+ if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(min, operand) ||
+ !ShapeUtil::SameElementTypeIgnoringFpPrecision(max, operand)) {
return InvalidArgument("clamp op with different operand types: %s, %s, %s",
ShapeUtil::HumanString(min).c_str(),
ShapeUtil::HumanString(operand).c_str(),
ShapeUtil::HumanString(max).c_str());
}
- if (((ShapeUtil::Compatible(min, operand) || ShapeUtil::IsScalar(min)) &&
- (ShapeUtil::Compatible(max, operand) || ShapeUtil::IsScalar(max)))) {
+ if (((ShapeUtil::CompatibleIgnoringFpPrecision(min, operand) ||
+ ShapeUtil::IsScalar(min)) &&
+ (ShapeUtil::CompatibleIgnoringFpPrecision(max, operand) ||
+ ShapeUtil::IsScalar(max)))) {
return operand;
}
if (ShapeUtil::IsScalar(operand)) {
- if (ShapeUtil::Compatible(min, max)) {
- return min;
+ if (ShapeUtil::CompatibleIgnoringFpPrecision(min, max)) {
+ return ShapeUtil::ChangeElementType(min, operand.element_type());
} else if (ShapeUtil::IsScalar(min)) {
- return max;
+ return ShapeUtil::ChangeElementType(max, operand.element_type());
} else if (ShapeUtil::IsScalar(max)) {
- return min;
+ return ShapeUtil::ChangeElementType(min, operand.element_type());
}
}
return Unimplemented(
// broadcast from all operands, not just the predicate.
/* static */ StatusOr<Shape> ShapeInference::InferSelectShape(
const Shape& pred, const Shape& on_true, const Shape& on_false) {
- if (!ShapeUtil::Compatible(on_true, on_false)) {
+ bool compatible;
+ if (ShapeUtil::IsTuple(on_true)) {
+ // Select only defines the top-level buffer, so if it's a tuple, the two
+ // inputs must match exactly.
+ compatible = ShapeUtil::Compatible(on_true, on_false);
+ } else {
+ compatible = ShapeUtil::CompatibleIgnoringFpPrecision(on_true, on_false);
+ }
+ if (!compatible) {
return InvalidArgument(
"operands to select must be the same shape; got %s and %s",
ShapeUtil::HumanString(on_true).c_str(),
// By this stage we know that pred's element type is PRED. Therefore, this
// check restricts pred to be a PRED scalar, or a PRED array with the same
// dimensions as on_true and on_false.
- return on_true;
+ return ShapeUtil::ChangeElementType(
+ on_true, ShapeUtil::HigherPrecisionElementType(on_true, on_false));
} else {
return Unimplemented(
"select operation with non-scalar predicate with dimensionality "
return SameDimensions(lhs, rhs);
}
+/* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs,
+ const Shape& rhs) {
+ if (lhs.element_type() == TUPLE) {
+ return rhs.element_type() == TUPLE &&
+ ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
+ CompatibleIgnoringFpPrecision);
+ }
+ if (SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
+ return CompatibleIgnoringElementType(lhs, rhs);
+ }
+ return false;
+}
+
/* static */ int64 ShapeUtil::GetDimension(const Shape& shape,
int64 dimension_number) {
return shape.dimensions(GetDimensionNumber(shape, dimension_number));
#include <string>
#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
return lhs.element_type() == rhs.element_type();
}
+ // As SameElementType, but allows floating point types to have different
+ // precisions.
+ static bool SameElementTypeIgnoringFpPrecision(const Shape& a,
+ const Shape& b) {
+ if (ElementIsFloating(a) && ElementIsFloating(b)) {
+ return true;
+ }
+ return ShapeUtil::SameElementType(a, b);
+ }
+
+ // Returns the higher-precision element type if a and b are both floating
+ // point types; otherwise, checks that they have the same element type
+ // and returns it.
+ static PrimitiveType HigherPrecisionElementType(const Shape& a,
+ const Shape& b) {
+ if (SameElementType(a, b)) {
+ return a.element_type();
+ }
+ CHECK(SameElementTypeIgnoringFpPrecision(a, b));
+ return primitive_util::BitWidth(a.element_type()) <
+ primitive_util::BitWidth(b.element_type())
+ ? b.element_type()
+ : a.element_type();
+ }
+
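A quick illustration of the selection rule (a sketch, not in the patch): BF16 and F32 are both floating point, so the wider 32-bit type wins; identical element types pass through unchanged.

  Shape bf16 = ShapeUtil::MakeShape(BF16, {4});
  Shape f32 = ShapeUtil::MakeShape(F32, {4});
  CHECK_EQ(ShapeUtil::HigherPrecisionElementType(bf16, f32), F32);
  CHECK_EQ(ShapeUtil::HigherPrecisionElementType(bf16, bf16), BF16);
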
// Returns true if the rank, dimension sizes, and element type are
// identical. Layout is ignored. Tuple elements are compared recursively for
// compatibility.
// compatibility.
static bool CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs);
+ // As Compatible, but allows one of lhs and rhs to be BF16 while the other
+ // is F32. Tuple elements are compared recursively for compatibility.
+ static bool CompatibleIgnoringFpPrecision(const Shape& lhs, const Shape& rhs);
+
// Returns whether the lhs and rhs shapes are identical protobufs.
static bool Equal(const Shape& lhs, const Shape& rhs);
EXPECT_TRUE(ShapeUtil::Compatible(shape_1, shape_2));
}
+TEST(ShapeUtilTest, CompatibleIgnoringFpPrecision) {
+ Shape shape1 = ShapeUtil::MakeShape(BF16, {3, 2});
+ Shape shape2 = ShapeUtil::MakeShape(F32, {3, 2});
+ ASSERT_TRUE(ShapeUtil::CompatibleIgnoringFpPrecision(shape1, shape2));
+}
+
+TEST(ShapeUtilTest, IncompatibleIgnoringFpPrecision) {
+ Shape shape1 = ShapeUtil::MakeShape(BF16, {3, 2});
+ Shape shape2 = ShapeUtil::MakeShape(F32, {2, 2});
+ ASSERT_FALSE(ShapeUtil::CompatibleIgnoringFpPrecision(shape1, shape2));
+}
+
TEST(ShapeUtilTest, IncompatibleDifferentElementShapes) {
Shape shape_1 = ShapeUtil::MakeShape(F32, {3, 2});
Shape shape_2 = ShapeUtil::MakeShape(PRED, {3, 2});
EXPECT_TRUE(ShapeUtil::Compatible(tuple1, tuple2));
}
+TEST(ShapeUtilTest, CompatibleTuplesIgnoringFpPrecision) {
+ Shape tuple1 = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(BF16, {3, 2}), ShapeUtil::MakeShape(F32, {4, 5})});
+ Shape tuple2 = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(F64, {3, 2}), ShapeUtil::MakeShape(BF16, {4, 5})});
+ EXPECT_TRUE(ShapeUtil::CompatibleIgnoringFpPrecision(tuple1, tuple2));
+}
+
TEST(ShapeUtilTest, IncompatibleTuplesWithSwappedElements) {
Shape tuple1 = ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(F32, {3, 2})});
EXPECT_FALSE(ShapeUtil::CompatibleIgnoringElementType(tuple1, tuple2));
}
+TEST(ShapeUtilTest, IncompatibleTuplesIgnoringFpPrecision) {
+ Shape tuple1 = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(BF16, {4, 5}), ShapeUtil::MakeShape(F32, {3, 2})});
+ Shape tuple2 = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(BF16, {4, 5})});
+ EXPECT_FALSE(ShapeUtil::CompatibleIgnoringFpPrecision(tuple1, tuple2));
+}
+
TEST(ShapeUtilTest, IncompatibleTuplesWithDifferentPrimitiveType) {
Shape tuple1 = ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(F32, {3, 2})});