[XLA] An HLO pass that folds BF16 F32 conversions: if an HLO already supports BF16...
authorYuanzhong Xu <yuanzx@google.com>
Mon, 12 Feb 2018 19:26:22 +0000 (11:26 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 12 Feb 2018 19:30:18 +0000 (11:30 -0800)
Also updates HloVerifier to allow mixed precision if requested. If an HLO has both both F32 and BF16 inputs, ShapeInference will use F32 as the output type.

PiperOrigin-RevId: 185407143

16 files changed:
tensorflow/compiler/xla/service/BUILD
tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc [new file with mode: 0644]
tensorflow/compiler/xla/service/bfloat16_conversion_folding.h [new file with mode: 0644]
tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc [new file with mode: 0644]
tensorflow/compiler/xla/service/bfloat16_normalization.cc [new file with mode: 0644]
tensorflow/compiler/xla/service/bfloat16_normalization.h [new file with mode: 0644]
tensorflow/compiler/xla/service/bfloat16_normalization_test.cc [new file with mode: 0644]
tensorflow/compiler/xla/service/bfloat16_support.cc [new file with mode: 0644]
tensorflow/compiler/xla/service/bfloat16_support.h [new file with mode: 0644]
tensorflow/compiler/xla/service/hlo_instruction.cc
tensorflow/compiler/xla/service/hlo_verifier.cc
tensorflow/compiler/xla/service/hlo_verifier.h
tensorflow/compiler/xla/service/shape_inference.cc
tensorflow/compiler/xla/shape_util.cc
tensorflow/compiler/xla/shape_util.h
tensorflow/compiler/xla/shape_util_test.cc

index 93cc5ab..9f5f2f9 100644 (file)
@@ -44,6 +44,81 @@ filegroup(
 )
 
 cc_library(
+    name = "bfloat16_support",
+    srcs = ["bfloat16_support.cc"],
+    hdrs = ["bfloat16_support.h"],
+    deps = [
+        ":hlo",
+    ],
+)
+
+cc_library(
+    name = "bfloat16_conversion_folding",
+    srcs = ["bfloat16_conversion_folding.cc"],
+    hdrs = ["bfloat16_conversion_folding.h"],
+    deps = [
+        ":bfloat16_support",
+        ":hlo",
+        ":hlo_pass",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/core:lib",
+    ],
+)
+
+tf_cc_test(
+    name = "bfloat16_conversion_folding_test",
+    srcs = ["bfloat16_conversion_folding_test.cc"],
+    deps = [
+        ":bfloat16_conversion_folding",
+        ":bfloat16_support",
+        ":hlo",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:test",
+        "//tensorflow/compiler/xla:test_helpers",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/core:lib",
+    ],
+)
+
+cc_library(
+    name = "bfloat16_normalization",
+    srcs = ["bfloat16_normalization.cc"],
+    hdrs = ["bfloat16_normalization.h"],
+    deps = [
+        ":bfloat16_support",
+        ":hlo",
+        ":hlo_pass",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/core:lib",
+    ],
+)
+
+tf_cc_test(
+    name = "bfloat16_normalization_test",
+    srcs = ["bfloat16_normalization_test.cc"],
+    deps = [
+        ":bfloat16_normalization",
+        ":bfloat16_support",
+        ":hlo",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:test",
+        "//tensorflow/compiler/xla:test_helpers",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/core:lib",
+    ],
+)
+
+cc_library(
     name = "shape_inference",
     srcs = ["shape_inference.cc"],
     hdrs = ["shape_inference.h"],
diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc
new file mode 100644 (file)
index 0000000..cde990e
--- /dev/null
@@ -0,0 +1,184 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/bfloat16_conversion_folding.h"
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+class BFloat16ConversionFoldingVisitor : public DfsHloVisitorWithDefault {
+ public:
+  explicit BFloat16ConversionFoldingVisitor(
+      HloComputation* computation, const BFloat16Support* bfloat16_support)
+      : computation_(computation), bfloat16_support_(bfloat16_support) {}
+
+  Status DefaultAction(HloInstruction* hlo) override;
+
+  static bool Run(HloComputation* computation,
+                  const BFloat16Support* bfloat16_support) {
+    BFloat16ConversionFoldingVisitor visitor(computation, bfloat16_support);
+    TF_CHECK_OK(computation->Accept(&visitor));
+    return visitor.changed_;
+  }
+
+ private:
+  // Checks if the HLO has a BF16 -> F32 conversion as input, or a F32 -> BF16
+  // conversion as output, and folds them to the HLO itself if feasible.
+  Status TryFoldBF16Conversions(HloInstruction* hlo);
+
+  // Folds the F32 -> BF16 conversions from the HLO's output.
+  //
+  // Precondition: all of the HLO's users are F32 -> BF16 conversions.
+  Status FoldOutputConversions(HloInstruction* hlo);
+
+  // Folds the BF16 -> F32 conversion operand to the HLO.
+  //
+  // Precondition: the operand is a F32 -> BF16 conversion.
+  Status FoldOperandConversion(HloInstruction* hlo, int64 operand_index);
+
+  HloComputation* computation_;
+  const BFloat16Support* bfloat16_support_;
+  bool changed_ = false;
+};
+
+Status BFloat16ConversionFoldingVisitor::FoldOutputConversions(
+    HloInstruction* hlo) {
+  std::vector<HloInstruction*> materialized_users = hlo->users();
+  hlo->mutable_shape()->set_element_type(BF16);
+  for (auto user : materialized_users) {
+    CHECK_EQ(user->opcode(), HloOpcode::kConvert);
+    TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(hlo));
+    changed_ = true;
+  }
+  return Status::OK();
+}
+
+Status BFloat16ConversionFoldingVisitor::FoldOperandConversion(
+    HloInstruction* hlo, int64 operand_index) {
+  // The operand is a convert from BF16 to F32.
+  auto operand = hlo->mutable_operand(operand_index);
+  CHECK_EQ(operand->opcode(), HloOpcode::kConvert);
+  TF_RETURN_IF_ERROR(
+      hlo->ReplaceOperandWith(operand_index, operand->mutable_operand(0)));
+  changed_ = true;
+  return Status::OK();
+}
+
+Status BFloat16ConversionFoldingVisitor::TryFoldBF16Conversions(
+    HloInstruction* hlo) {
+  std::vector<int64> bf16_to_f32_operands;
+  bool has_other_f32_operands = false;
+  for (int64 i = 0; i < hlo->operands().size(); ++i) {
+    auto operand = hlo->operand(i);
+    if (operand->shape().element_type() == F32) {
+      if (operand->opcode() == HloOpcode::kConvert &&
+          operand->operand(0)->shape().element_type() == BF16 &&
+          bfloat16_support_->SupportsBF16Operand(*hlo, i)) {
+        // Operand is a convert from BF16 to F32 and we support BF16 input
+        // directly in the current HLO at the operand index.
+        bf16_to_f32_operands.push_back(i);
+      } else {
+        has_other_f32_operands = true;
+      }
+      continue;
+    }
+  }
+
+  bool fold_output_conversion = hlo->user_count() > 0 &&
+                                hlo->shape().element_type() == F32 &&
+                                bfloat16_support_->SupportsBF16Output(*hlo) &&
+                                hlo != computation_->root_instruction();
+  if (fold_output_conversion) {
+    for (auto user : hlo->users()) {
+      if (user->opcode() == HloOpcode::kConvert &&
+          user->shape().element_type() == BF16) {
+        continue;
+      }
+      // We should not change the output type if any user is not a conversion
+      // from F32 to BF16.
+      fold_output_conversion = false;
+      break;
+    }
+  }
+
+  if (!bfloat16_support_->SupportsMixedPrecisions(*hlo)) {
+    if (has_other_f32_operands ||
+        (!fold_output_conversion && hlo->shape().element_type() == F32)) {
+      // Some of the operands/output will remain F32, but we cannot use mixed
+      // precisions, so we cannot do anything here.
+      return Status::OK();
+    }
+  }
+
+  if (fold_output_conversion) {
+    TF_RETURN_IF_ERROR(FoldOutputConversions(hlo));
+  }
+
+  for (int64 i : bf16_to_f32_operands) {
+    TF_RETURN_IF_ERROR(FoldOperandConversion(hlo, i));
+  }
+  return Status::OK();
+}
+
+Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) {
+  // Do not fold BF16 conversions for instructions related to tuples, entry and
+  // exit of a computation, fusion, convert, and control flow.
+  if (hlo->opcode() == HloOpcode::kTuple ||            //
+      hlo->opcode() == HloOpcode::kGetTupleElement ||  //
+      hlo->opcode() == HloOpcode::kInfeed ||           //
+      hlo->opcode() == HloOpcode::kOutfeed ||          //
+      hlo->opcode() == HloOpcode::kConstant ||         //
+      hlo->opcode() == HloOpcode::kParameter ||        //
+      hlo->opcode() == HloOpcode::kFusion ||           //
+      hlo->opcode() == HloOpcode::kConvert ||          //
+      hlo->opcode() == HloOpcode::kCall ||             //
+      hlo->opcode() == HloOpcode::kCustomCall ||       //
+      hlo->opcode() == HloOpcode::kWhile ||            //
+      hlo->opcode() == HloOpcode::kConditional) {
+    return Status::OK();
+  }
+  if (hlo == computation_->root_instruction() &&
+      !bfloat16_support_->SupportsMixedPrecisions(*hlo)) {
+    // If hlo is the root instruction, we cannot change its output, so folding
+    // can only happen when it supports mixed precision so that we can change
+    // its operands.
+    return Status::OK();
+  }
+  return TryFoldBF16Conversions(hlo);
+}
+
+StatusOr<bool> BFloat16ConversionFolding::Run(HloModule* module) {
+  XLA_VLOG_LINES(
+      2, "BFloat16ConversionFolding::Run(), before:\n" + module->ToString());
+  bool changed = false;
+  for (auto* comp : module->MakeNonfusionComputations()) {
+    if (BFloat16ConversionFoldingVisitor::Run(comp, bfloat16_support_)) {
+      changed = true;
+    }
+  }
+  XLA_VLOG_LINES(
+      2, "BFloat16ConversionFolding::Run(), after:\n" + module->ToString());
+  return changed;
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h
new file mode 100644 (file)
index 0000000..c939838
--- /dev/null
@@ -0,0 +1,52 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_CONVERSION_FOLDING_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_CONVERSION_FOLDING_H_
+
+#include "tensorflow/compiler/xla/service/bfloat16_support.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+
+namespace xla {
+
+// A pass which folds F32 <-> BF16 conversions to their operands or users, when
+// it is supported by the backend.
+//
+// This pass follows the passed-in backend-specific BF16 support rules, but can
+// introduce mixed precision in individual HLOs which breaks the assumption of
+// some other HLO passes. So it should be used at the end of the HLO
+// optimization pipeline followed by a DCE pass. If other passes are needed
+// after this pass, run BFloat16MixedPrecisionRemoval first to undo some of the
+// changed made by this pass.
+class BFloat16ConversionFolding : public HloPassInterface {
+ public:
+  explicit BFloat16ConversionFolding(const BFloat16Support* bfloat16_support)
+      : bfloat16_support_(bfloat16_support) {}
+
+  ~BFloat16ConversionFolding() override = default;
+  tensorflow::StringPiece name() const override { return "bfloat16-fold"; }
+
+  // Run BF16 conversion folding on the given computation. Returns whether the
+  // computation was changed.
+  StatusOr<bool> Run(HloModule* module) override;
+
+ private:
+  const BFloat16Support* bfloat16_support_;
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_CONVERSION_FOLDING_H_
diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
new file mode 100644 (file)
index 0000000..cb37759
--- /dev/null
@@ -0,0 +1,209 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/bfloat16_conversion_folding.h"
+#include "tensorflow/compiler/xla/service/bfloat16_support.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+
+class TestBFloat16Support : public BFloat16Support {
+ public:
+  TestBFloat16Support() {}
+  ~TestBFloat16Support() override {}
+
+  bool SupportsBF16Operand(const HloInstruction& hlo,
+                           int64 operand_index) const override {
+    if (hlo.opcode() == HloOpcode::kAdd ||
+        hlo.opcode() == HloOpcode::kSubtract ||
+        hlo.opcode() == HloOpcode::kTuple ||
+        hlo.opcode() == HloOpcode::kGetTupleElement) {
+      return true;
+    }
+    return false;
+  }
+
+  bool SupportsBF16Output(const HloInstruction& hlo) const override {
+    if (hlo.opcode() == HloOpcode::kAdd ||
+        hlo.opcode() == HloOpcode::kSubtract ||
+        hlo.opcode() == HloOpcode::kTuple ||
+        hlo.opcode() == HloOpcode::kGetTupleElement) {
+      return true;
+    }
+    return false;
+  }
+
+  bool SupportsMixedPrecisions(const HloInstruction& hlo) const override {
+    if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kTuple ||
+        hlo.opcode() == HloOpcode::kGetTupleElement) {
+      return true;
+    }
+    return false;
+  }
+};
+
+class BFloat16ConversionFoldingTest : public HloTestBase {
+ protected:
+  bool FoldConversions(HloModule* module) {
+    TestBFloat16Support bfloat16_support_;
+    BFloat16ConversionFolding fold(&bfloat16_support_);
+    StatusOr<bool> result = fold.Run(module);
+    EXPECT_IS_OK(result.status());
+    return result.ValueOrDie();
+  }
+};
+
+TEST_F(BFloat16ConversionFoldingTest, FoldIfSupported) {
+  auto builder = HloComputation::Builder(TestName());
+  Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
+  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});
+
+  HloInstruction* a = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, f32_shape, "a"));
+  HloInstruction* b = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, f32_shape, "b"));
+  HloInstruction* c = builder.AddInstruction(
+      HloInstruction::CreateParameter(2, f32_shape, "c"));
+
+  HloInstruction* add0 = builder.AddInstruction(
+      HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, a, b));
+  HloInstruction* convert0 =
+      builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, add0));
+  HloInstruction* convert1 = builder.AddInstruction(
+      HloInstruction::CreateConvert(f32_shape, convert0));
+
+  HloInstruction* add1 = builder.AddInstruction(
+      HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, convert1, c));
+  builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, add1));
+
+  auto module = CreateNewModule();
+  auto computation = module->AddEntryComputation(builder.Build());
+
+  EXPECT_TRUE(FoldConversions(module.get()));
+
+  EXPECT_EQ(computation->root_instruction(), add1);
+  EXPECT_EQ(add0->shape().element_type(), BF16);
+  EXPECT_EQ(add1->shape().element_type(), BF16);
+  EXPECT_EQ(add1->operand(0), add0);
+}
+
+TEST_F(BFloat16ConversionFoldingTest, DoNotFoldIfUnsupported) {
+  auto builder = HloComputation::Builder(TestName());
+  Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
+  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});
+
+  HloInstruction* a = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, f32_shape, "a"));
+  HloInstruction* b = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, f32_shape, "b"));
+  HloInstruction* c = builder.AddInstruction(
+      HloInstruction::CreateParameter(2, f32_shape, "c"));
+
+  HloInstruction* mul0 = builder.AddInstruction(
+      HloInstruction::CreateBinary(f32_shape, HloOpcode::kMultiply, a, b));
+  HloInstruction* convert0 =
+      builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, mul0));
+  HloInstruction* convert1 = builder.AddInstruction(
+      HloInstruction::CreateConvert(f32_shape, convert0));
+
+  HloInstruction* mul1 = builder.AddInstruction(HloInstruction::CreateBinary(
+      f32_shape, HloOpcode::kMultiply, convert1, c));
+  HloInstruction* convert2 =
+      builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, mul1));
+
+  auto module = CreateNewModule();
+  auto computation = module->AddEntryComputation(builder.Build());
+
+  EXPECT_FALSE(FoldConversions(module.get()));
+
+  EXPECT_EQ(computation->root_instruction(), convert2);
+  EXPECT_EQ(mul0->shape().element_type(), F32);
+  EXPECT_EQ(mul1->shape().element_type(), F32);
+  EXPECT_EQ(mul1->operand(0), convert1);
+}
+
+TEST_F(BFloat16ConversionFoldingTest, DoNotFoldUnsupportedMixedPrecision) {
+  auto builder = HloComputation::Builder(TestName());
+  Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
+  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});
+
+  HloInstruction* a = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, f32_shape, "a"));
+  HloInstruction* b = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, f32_shape, "b"));
+  HloInstruction* c = builder.AddInstruction(
+      HloInstruction::CreateParameter(2, f32_shape, "c"));
+
+  HloInstruction* sub0 = builder.AddInstruction(
+      HloInstruction::CreateBinary(f32_shape, HloOpcode::kSubtract, a, b));
+  HloInstruction* convert0 =
+      builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, sub0));
+  HloInstruction* convert1 = builder.AddInstruction(
+      HloInstruction::CreateConvert(f32_shape, convert0));
+
+  HloInstruction* sub1 = builder.AddInstruction(HloInstruction::CreateBinary(
+      f32_shape, HloOpcode::kSubtract, convert1, c));
+  HloInstruction* convert2 =
+      builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, sub1));
+
+  auto module = CreateNewModule();
+  auto computation = module->AddEntryComputation(builder.Build());
+
+  EXPECT_FALSE(FoldConversions(module.get()));
+
+  EXPECT_EQ(computation->root_instruction(), convert2);
+  EXPECT_EQ(sub0->shape().element_type(), F32);
+  EXPECT_EQ(sub1->shape().element_type(), F32);
+  EXPECT_EQ(sub1->operand(0), convert1);
+}
+
+TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) {
+  auto builder = HloComputation::Builder(TestName());
+  Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
+  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});
+
+  HloInstruction* a = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, f32_shape, "a"));
+  HloInstruction* b = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, bf16_shape, "b"));
+  HloInstruction* convert0 =
+      builder.AddInstruction(HloInstruction::CreateConvert(f32_shape, b));
+
+  HloInstruction* tuple =
+      builder.AddInstruction(HloInstruction::CreateTuple({a, convert0}));
+  HloInstruction* gte = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(f32_shape, tuple, 0));
+  HloInstruction* convert1 =
+      builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, gte));
+
+  auto module = CreateNewModule();
+  auto computation = module->AddEntryComputation(builder.Build());
+
+  EXPECT_FALSE(FoldConversions(module.get()));
+
+  EXPECT_EQ(computation->root_instruction(), convert1);
+  EXPECT_EQ(gte->shape().element_type(), F32);
+  EXPECT_EQ(tuple->operand(1), convert0);
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.cc b/tensorflow/compiler/xla/service/bfloat16_normalization.cc
new file mode 100644 (file)
index 0000000..b032c04
--- /dev/null
@@ -0,0 +1,351 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/bfloat16_normalization.h"
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault {
+ public:
+  explicit BFloat16NormalizationVisitor(HloComputation* computation,
+                                        const BFloat16Support* bfloat16_support)
+      : computation_(computation), bfloat16_support_(bfloat16_support) {}
+
+  Status DefaultAction(HloInstruction* hlo) override;
+
+  // Special handling for cross-replica-sum which can have a tuple output.
+  Status HandleCrossReplicaSum(HloInstruction* crs) override;
+
+  static bool Run(HloComputation* computation,
+                  const BFloat16Support* bfloat16_support) {
+    BFloat16NormalizationVisitor visitor(computation, bfloat16_support);
+    TF_CHECK_OK(computation->Accept(&visitor));
+    return visitor.changed_;
+  }
+
+ private:
+  // Checks if the HLO uses BF16 in an unsupported way, and if so, inserts
+  // conversions between F32 and BF16 to make it supported.
+  Status HandleInstruction(HloInstruction* hlo);
+
+  // Inserts a conversion HLO that changes the given HLO's output type.
+  Status InsertConvertAfterOutput(HloInstruction* hlo, PrimitiveType to,
+                                  HloComputation* computation);
+
+  // Changes the output type to the specified type, then inserts a conversion
+  // to the original type.
+  Status ChangeOutputTypeThenInsertConvertBack(HloInstruction* hlo,
+                                               PrimitiveType to,
+                                               HloComputation* computation);
+
+  // Inserts a conversion HLO that changes the given HLO's operand type.
+  Status InsertConvertBeforeOperand(HloInstruction* hlo, int64 operand_idx,
+                                    PrimitiveType to,
+                                    HloComputation* computation);
+
+  // Inserts conversion HLOs to replace the called computations' BF16
+  // operands/outputs to F32.
+  Status ConvertCalledComputations(
+      HloInstruction* hlo,
+      tensorflow::gtl::ArraySlice<HloComputation*> bf16_called_comps);
+
+  HloComputation* computation_;
+  const BFloat16Support* bfloat16_support_;
+  bool changed_ = false;
+};
+
+Status BFloat16NormalizationVisitor::InsertConvertAfterOutput(
+    HloInstruction* hlo, PrimitiveType to, HloComputation* computation) {
+  bool is_root = computation->root_instruction() == hlo;
+  std::vector<HloInstruction*> materialized_users = hlo->users();
+  // Use inst's shape temporarily, in order to pass checks in ReplaceUseWith.
+  auto convert = computation->AddInstruction(
+      HloInstruction::CreateConvert(hlo->shape(), hlo));
+  for (auto* user : materialized_users) {
+    TF_RETURN_IF_ERROR(hlo->ReplaceUseWith(user, convert));
+  }
+  if (is_root) {
+    computation->set_root_instruction(convert);
+  }
+  convert->mutable_shape()->set_element_type(to);
+  changed_ = true;
+  return Status::OK();
+}
+
+Status BFloat16NormalizationVisitor::ChangeOutputTypeThenInsertConvertBack(
+    HloInstruction* hlo, PrimitiveType to, HloComputation* computation) {
+  auto original_type = hlo->shape().element_type();
+  hlo->mutable_shape()->set_element_type(to);
+  return InsertConvertAfterOutput(hlo, original_type, computation);
+}
+
+Status BFloat16NormalizationVisitor::InsertConvertBeforeOperand(
+    HloInstruction* hlo, int64 operand_idx, PrimitiveType to,
+    HloComputation* computation) {
+  auto operand = hlo->mutable_operand(operand_idx);
+  auto convert = computation->AddInstruction(HloInstruction::CreateConvert(
+      ShapeUtil::ChangeElementType(operand->shape(), to), operand));
+  TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(operand_idx, convert));
+  changed_ = true;
+  return Status::OK();
+}
+
+Status BFloat16NormalizationVisitor::ConvertCalledComputations(
+    HloInstruction* hlo,
+    tensorflow::gtl::ArraySlice<HloComputation*> bf16_called_comps) {
+  std::map<HloComputation*, HloComputation*> cloned_computations;
+  for (auto& comp : bf16_called_comps) {
+    auto cloned = comp->parent()->AddEmbeddedComputation(comp->Clone());
+    cloned_computations[comp] = cloned;
+    changed_ = true;
+  }
+  hlo->ReplaceCalledComputations([&](HloComputation* comp) {
+    auto it = cloned_computations.find(comp);
+    if (it != cloned_computations.end()) {
+      return it->second;
+    }
+    return comp;
+  });
+  for (auto& comp_pair : cloned_computations) {
+    auto comp = comp_pair.second;
+    if (comp->root_instruction()->shape().element_type() == BF16) {
+      TF_RETURN_IF_ERROR(
+          InsertConvertAfterOutput(comp->root_instruction(), F32, comp));
+    }
+    for (auto* param : comp->parameter_instructions()) {
+      if (param->shape().element_type() == BF16) {
+        // This changes the parameter to F32 then inserts a convert after it.
+        TF_RETURN_IF_ERROR(
+            ChangeOutputTypeThenInsertConvertBack(param, F32, comp));
+      }
+    }
+  }
+  return Status::OK();
+}
+
+Status BFloat16NormalizationVisitor::HandleCrossReplicaSum(
+    HloInstruction* crs) {
+  if (!ShapeUtil::IsTuple(crs->shape())) {
+    return HandleInstruction(crs);
+  }
+
+  std::vector<PrimitiveType> operand_types(crs->operand_count());
+  std::vector<PrimitiveType> output_types(crs->operand_count());
+  bool has_f32 = false;
+  bool has_bf16 = false;
+  bool has_bf16_output = false;
+  for (int64 i = 0; i < crs->operand_count(); ++i) {
+    operand_types[i] = crs->operand(i)->shape().element_type();
+    output_types[i] = ShapeUtil::GetSubshape(crs->shape(), {i}).element_type();
+    if (operand_types[i] == F32 || output_types[i] == F32) {
+      has_f32 = true;
+    } else if (operand_types[i] == BF16) {
+      has_bf16 = true;
+    }
+    if (output_types[i] == BF16) {
+      has_bf16 = true;
+      has_bf16_output = true;
+    }
+  }
+
+  for (int64 i = 0; i < crs->operand_count(); ++i) {
+    if (operand_types[i] != BF16) {
+      continue;
+    }
+    if (bfloat16_support_->SupportsBF16Operand(*crs, i) &&
+        (bfloat16_support_->SupportsMixedPrecisions(*crs) || !has_f32)) {
+      continue;
+    }
+    TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(crs, i, F32, computation_));
+    has_f32 = true;
+  }
+
+  if (!has_bf16_output) {
+    return Status::OK();
+  }
+
+  if (bfloat16_support_->SupportsBF16Output(*crs) &&
+      (bfloat16_support_->SupportsMixedPrecisions(*crs) || !has_f32)) {
+    return Status::OK();
+  }
+
+  std::vector<HloInstruction*> output_elements(crs->operand_count());
+  auto original_shape = crs->shape();
+  for (int64 i = 0; i < crs->operand_count(); ++i) {
+    auto subshape = ShapeUtil::GetMutableSubshape(crs->mutable_shape(), {i});
+    if (output_types[i] != BF16) {
+      output_elements[i] = computation_->AddInstruction(
+          HloInstruction::CreateGetTupleElement(*subshape, crs, i));
+      continue;
+    }
+    subshape->set_element_type(F32);
+    auto gte = computation_->AddInstruction(
+        HloInstruction::CreateGetTupleElement(*subshape, crs, i));
+    output_elements[i] =
+        computation_->AddInstruction(HloInstruction::CreateConvert(
+            ShapeUtil::ChangeElementType(*subshape, BF16), gte));
+  }
+  auto tuple = computation_->AddInstruction(
+      HloInstruction::CreateTuple(output_elements));
+
+  std::vector<HloInstruction*> materialized_users = crs->users();
+  // Use the crs' shape temporarily, in order to pass checks in
+  // ReplaceUseWith.
+  *tuple->mutable_shape() = crs->shape();
+  for (auto* user : materialized_users) {
+    TF_RETURN_IF_ERROR(crs->ReplaceUseWith(user, tuple));
+  }
+  *tuple->mutable_shape() = original_shape;
+  return Status::OK();
+}
+
+Status BFloat16NormalizationVisitor::HandleInstruction(HloInstruction* hlo) {
+  std::vector<int64> bf16_operands;
+  std::vector<int64> f32_operands;
+  bool has_f32 = false;
+  bool has_bf16 = false;
+
+  for (int64 i = 0; i < hlo->operand_count(); ++i) {
+    if (hlo->operand(i)->shape().element_type() == F32) {
+      f32_operands.push_back(i);
+      has_f32 = true;
+    } else if (hlo->operand(i)->shape().element_type() == BF16) {
+      bf16_operands.push_back(i);
+      has_bf16 = true;
+    }
+  }
+
+  if (hlo->shape().element_type() == F32) {
+    has_f32 = true;
+  } else if (hlo->shape().element_type() == BF16) {
+    has_bf16 = true;
+  }
+
+  std::vector<HloComputation*> bf16_called_comps;
+  for (auto* comp : hlo->called_computations()) {
+    bool comp_has_bf16 = false;
+    if (comp->root_instruction()->shape().element_type() == F32) {
+      has_f32 = true;
+    } else if (comp->root_instruction()->shape().element_type() == BF16) {
+      has_bf16 = true;
+      comp_has_bf16 = true;
+    }
+    for (auto* param : comp->parameter_instructions()) {
+      if (param->shape().element_type() == F32) {
+        has_f32 = true;
+      } else if (param->shape().element_type() == BF16) {
+        has_bf16 = true;
+        comp_has_bf16 = true;
+      }
+    }
+    if (comp_has_bf16) {
+      bf16_called_comps.push_back(comp);
+    }
+  }
+
+  if (!bfloat16_support_->SupportsMixedPrecisions(*hlo) && has_bf16 &&
+      has_f32) {
+    // Resolve unsupported mixed precision.
+    //
+    // See if we can change everything to BF16.
+    if (hlo->called_computations().empty() &&
+        hlo->shape().element_type() == BF16) {
+      bool can_use_bf16 = true;
+      for (int i : f32_operands) {
+        if (bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(*hlo,
+                                                                          i) &&
+            bfloat16_support_->SupportsBF16Operand(*hlo, i)) {
+          continue;
+        }
+        can_use_bf16 = false;
+        break;
+      }
+      if (can_use_bf16) {
+        for (int i : f32_operands) {
+          TF_RETURN_IF_ERROR(
+              InsertConvertBeforeOperand(hlo, i, BF16, computation_));
+        }
+        return Status::OK();
+      }
+    }
+    if (hlo->shape().element_type() == BF16) {
+      TF_RETURN_IF_ERROR(
+          ChangeOutputTypeThenInsertConvertBack(hlo, F32, computation_));
+    }
+    for (int i : bf16_operands) {
+      TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_));
+    }
+    return ConvertCalledComputations(hlo, bf16_called_comps);
+  }
+
+  for (int i : bf16_operands) {
+    if (!bfloat16_support_->SupportsBF16Operand(*hlo, i)) {
+      TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_));
+    }
+  }
+
+  if (hlo->shape().element_type() == BF16 &&
+      !bfloat16_support_->SupportsBF16Output(*hlo)) {
+    TF_RETURN_IF_ERROR(
+        ChangeOutputTypeThenInsertConvertBack(hlo, F32, computation_));
+  }
+
+  return Status::OK();
+}
+
+Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) {
+  // Do not change instructions related to entry and exit of a computation,
+  // tuples, fusion, convert, and control flow.
+  if (hlo->opcode() == HloOpcode::kTuple ||            //
+      hlo->opcode() == HloOpcode::kGetTupleElement ||  //
+      hlo->opcode() == HloOpcode::kInfeed ||           //
+      hlo->opcode() == HloOpcode::kOutfeed ||          //
+      hlo->opcode() == HloOpcode::kConstant ||         //
+      hlo->opcode() == HloOpcode::kParameter ||        //
+      hlo->opcode() == HloOpcode::kFusion ||           //
+      hlo->opcode() == HloOpcode::kConvert ||          //
+      hlo->opcode() == HloOpcode::kCall ||             //
+      hlo->opcode() == HloOpcode::kCustomCall ||       //
+      hlo->opcode() == HloOpcode::kWhile ||            //
+      hlo->opcode() == HloOpcode::kConditional) {
+    return Status::OK();
+  }
+  return HandleInstruction(hlo);
+}
+
+StatusOr<bool> BFloat16Normalization::Run(HloModule* module) {
+  XLA_VLOG_LINES(
+      2, "BFloat16Normalization::Run(), before:\n" + module->ToString());
+  bool changed = false;
+  for (auto* comp : module->MakeComputationPostOrder()) {
+    if (BFloat16NormalizationVisitor::Run(comp, bfloat16_support_)) {
+      changed = true;
+    }
+  }
+  XLA_VLOG_LINES(2,
+                 "BFloat16Normalization::Run(), after:\n" + module->ToString());
+  return changed;
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.h b/tensorflow/compiler/xla/service/bfloat16_normalization.h
new file mode 100644 (file)
index 0000000..2a60fe0
--- /dev/null
@@ -0,0 +1,92 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_NORMALIZATION_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_NORMALIZATION_H_
+
+#include "tensorflow/compiler/xla/service/bfloat16_support.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+
+namespace xla {
+
+// A pass which adds F32 <-> BF16 conversions for HLO instructions that do not
+// support BF16 input/output or mixed precision, according to the passed-in
+// backend-specific BF16 support rules.
+class BFloat16Normalization : public HloPassInterface {
+ public:
+  explicit BFloat16Normalization(const BFloat16Support* bfloat16_support)
+      : bfloat16_support_(bfloat16_support) {}
+
+  ~BFloat16Normalization() override = default;
+  tensorflow::StringPiece name() const override { return "bf16-normalization"; }
+
+  // Run BF16 normalization on the given computation. Returns whether the
+  // computation was changed.
+  StatusOr<bool> Run(HloModule* module) override;
+
+ private:
+  const BFloat16Support* bfloat16_support_;
+};
+
+// A pass that unconditionally removes the mixed F32/BF16 uses in HLO
+// instructions (excluding convert) by adding F32 <-> BF16 conversions. Unlike
+// BFloat16Normalization, this pass does not use a backend-specific
+// BFloat16Support, and does not change HLOs that have BF16 data if they do not
+// use mixed precision; it removes mixed precision even if the backend supports
+// it. This pass is used to make the HLO module valid for other HLO passes which
+// do not support mixed precision.
+class BFloat16MixedPrecisionRemoval : public HloPassInterface {
+ public:
+  BFloat16MixedPrecisionRemoval() {}
+
+  ~BFloat16MixedPrecisionRemoval() override = default;
+
+  tensorflow::StringPiece name() const override {
+    return "bf16-mixed-precision-removal";
+  }
+
+  // Run mixed precision removal on the given computation. Returns whether the
+  // computation was changed.
+  StatusOr<bool> Run(HloModule* module) override {
+    BFloat16Normalization normalization(&no_mixed_precision_support_);
+    return normalization.Run(module);
+  }
+
+ private:
+  class BFloat16SupportForMixedPrecisionRemoval : public BFloat16Support {
+   public:
+    BFloat16SupportForMixedPrecisionRemoval() {}
+
+    ~BFloat16SupportForMixedPrecisionRemoval() override = default;
+
+    bool SupportsBF16Operand(const HloInstruction& hlo,
+                             int64 operand_index) const override {
+      return true;
+    }
+
+    bool SupportsBF16Output(const HloInstruction& hlo) const override {
+      return true;
+    }
+
+    bool SupportsMixedPrecisions(const HloInstruction& hlo) const override {
+      return false;
+    }
+  } no_mixed_precision_support_;
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_NORMALIZATION_H_
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
new file mode 100644 (file)
index 0000000..66c3085
--- /dev/null
@@ -0,0 +1,248 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/bfloat16_normalization.h"
+#include "tensorflow/compiler/xla/service/bfloat16_support.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+
+class TestBFloat16Support : public BFloat16Support {
+ public:
+  TestBFloat16Support() {}
+  ~TestBFloat16Support() override {}
+
+  bool SupportsBF16Operand(const HloInstruction& hlo,
+                           int64 operand_index) const override {
+    if (hlo.opcode() == HloOpcode::kAdd ||
+        hlo.opcode() == HloOpcode::kSubtract ||
+        hlo.opcode() == HloOpcode::kReduce ||
+        hlo.opcode() == HloOpcode::kTuple ||
+        hlo.opcode() == HloOpcode::kGetTupleElement) {
+      return true;
+    }
+    return false;
+  }
+
+  bool SupportsBF16Output(const HloInstruction& hlo) const override {
+    if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kReduce ||
+        hlo.opcode() == HloOpcode::kSubtract ||
+        hlo.opcode() == HloOpcode::kTuple ||
+        hlo.opcode() == HloOpcode::kGetTupleElement) {
+      return true;
+    }
+    return false;
+  }
+
+  bool SupportsMixedPrecisions(const HloInstruction& hlo) const override {
+    if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kTuple ||
+        hlo.opcode() == HloOpcode::kGetTupleElement) {
+      return true;
+    }
+    return false;
+  }
+};
+
+class BFloat16NormalizationTest : public HloTestBase {
+ protected:
+  bool Normalize(HloModule* module) {
+    TestBFloat16Support bfloat16_support_;
+    BFloat16Normalization normalization(&bfloat16_support_);
+    StatusOr<bool> result = normalization.Run(module);
+    EXPECT_IS_OK(result.status());
+    return result.ValueOrDie();
+  }
+};
+
+TEST_F(BFloat16NormalizationTest, NoopIfSupported) {
+  auto builder = HloComputation::Builder(TestName());
+  Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
+  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});
+
+  HloInstruction* a = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, f32_shape, "a"));
+  HloInstruction* b = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, bf16_shape, "b"));
+  HloInstruction* c = builder.AddInstruction(
+      HloInstruction::CreateParameter(2, f32_shape, "c"));
+
+  HloInstruction* add0 = builder.AddInstruction(
+      HloInstruction::CreateBinary(bf16_shape, HloOpcode::kAdd, a, b));
+
+  HloInstruction* add1 = builder.AddInstruction(
+      HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, add0, c));
+
+  auto module = CreateNewModule();
+  auto computation = module->AddEntryComputation(builder.Build());
+
+  EXPECT_FALSE(Normalize(module.get()));
+
+  EXPECT_EQ(computation->root_instruction(), add1);
+  EXPECT_EQ(add0->shape().element_type(), BF16);
+  EXPECT_EQ(add1->shape().element_type(), F32);
+}
+
+TEST_F(BFloat16NormalizationTest, ResolveIfUnsupportedBF16) {
+  auto builder = HloComputation::Builder(TestName());
+  Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
+  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});
+
+  HloInstruction* a = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, f32_shape, "a"));
+  HloInstruction* b = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, bf16_shape, "b"));
+  HloInstruction* c = builder.AddInstruction(
+      HloInstruction::CreateParameter(2, f32_shape, "c"));
+
+  HloInstruction* mul0 = builder.AddInstruction(
+      HloInstruction::CreateBinary(bf16_shape, HloOpcode::kMultiply, a, b));
+
+  HloInstruction* mul1 = builder.AddInstruction(
+      HloInstruction::CreateBinary(bf16_shape, HloOpcode::kMultiply, mul0, c));
+
+  auto module = CreateNewModule();
+  auto computation = module->AddEntryComputation(builder.Build());
+
+  EXPECT_TRUE(Normalize(module.get()));
+
+  EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
+  EXPECT_EQ(computation->root_instruction()->operand(0), mul1);
+  EXPECT_EQ(mul0->shape().element_type(), F32);
+  EXPECT_EQ(mul1->shape().element_type(), F32);
+  EXPECT_EQ(mul1->operand(0)->opcode(), HloOpcode::kConvert);
+}
+
+TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionSubtraction) {
+  auto builder = HloComputation::Builder(TestName());
+  Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
+  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});
+
+  HloInstruction* a = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, f32_shape, "a"));
+  HloInstruction* b = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, bf16_shape, "b"));
+  HloInstruction* c = builder.AddInstruction(
+      HloInstruction::CreateParameter(2, f32_shape, "c"));
+
+  HloInstruction* sub0 = builder.AddInstruction(
+      HloInstruction::CreateBinary(bf16_shape, HloOpcode::kSubtract, a, b));
+
+  HloInstruction* sub1 = builder.AddInstruction(
+      HloInstruction::CreateBinary(bf16_shape, HloOpcode::kSubtract, sub0, c));
+
+  auto module = CreateNewModule();
+  auto computation = module->AddEntryComputation(builder.Build());
+
+  EXPECT_TRUE(Normalize(module.get()));
+
+  EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
+  EXPECT_EQ(computation->root_instruction()->operand(0), sub1);
+  EXPECT_EQ(sub0->shape().element_type(), F32);
+  EXPECT_EQ(sub1->shape().element_type(), F32);
+  EXPECT_EQ(sub1->operand(0)->opcode(), HloOpcode::kConvert);
+}
+
+TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) {
+  Shape f32_input_shape = ShapeUtil::MakeShape(F32, {2, 4});
+  Shape f32_output_shape = ShapeUtil::MakeShape(F32, {4});
+
+  Shape bf16_scalar_shape = ShapeUtil::MakeShape(BF16, {2, 4});
+
+  auto reduce_comp_builder = HloComputation::Builder("reduce_comp");
+  auto reduce_comp_param0 = reduce_comp_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, bf16_scalar_shape, "param0"));
+  auto reduce_comp_param1 = reduce_comp_builder.AddInstruction(
+      HloInstruction::CreateParameter(1, bf16_scalar_shape, "param1"));
+  reduce_comp_builder.AddInstruction(
+      HloInstruction::CreateBinary(bf16_scalar_shape, HloOpcode::kAdd,
+                                   reduce_comp_param0, reduce_comp_param1));
+
+  auto module = CreateNewModule();
+  auto reduce_computation =
+      module->AddEmbeddedComputation(reduce_comp_builder.Build());
+
+  auto builder = HloComputation::Builder(TestName());
+  HloInstruction* input = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, f32_input_shape, "a"));
+  HloInstruction* init = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, bf16_scalar_shape, "init"));
+  HloInstruction* reduce = builder.AddInstruction(HloInstruction::CreateReduce(
+      f32_output_shape, input, init, {0}, reduce_computation));
+
+  auto computation = module->AddEntryComputation(builder.Build());
+
+  EXPECT_TRUE(Normalize(module.get()));
+
+  EXPECT_EQ(computation->root_instruction(), reduce);
+  EXPECT_EQ(reduce->called_computations().size(), 1);
+  EXPECT_EQ(reduce->called_computations()[0]->num_parameters(), 2);
+  EXPECT_EQ(reduce->called_computations()[0]
+                ->parameter_instruction(0)
+                ->shape()
+                .element_type(),
+            F32);
+  EXPECT_EQ(reduce->called_computations()[0]
+                ->parameter_instruction(1)
+                ->shape()
+                .element_type(),
+            F32);
+  EXPECT_EQ(reduce->called_computations()[0]
+                ->root_instruction()
+                ->shape()
+                .element_type(),
+            F32);
+  EXPECT_EQ(reduce->shape().element_type(), F32);
+  EXPECT_EQ(reduce->operand(0), input);
+  EXPECT_EQ(input->shape().element_type(), F32);
+  EXPECT_EQ(reduce->operand(1)->opcode(), HloOpcode::kConvert);
+  EXPECT_EQ(reduce->operand(1)->shape().element_type(), F32);
+}
+
+TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) {
+  auto builder = HloComputation::Builder(TestName());
+  Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
+  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});
+
+  HloInstruction* a = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, f32_shape, "a"));
+  HloInstruction* b = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, bf16_shape, "b"));
+
+  HloInstruction* crs =
+      builder.AddInstruction(HloInstruction::CreateCrossReplicaSum(
+          ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}));
+  HloInstruction* gte = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1));
+
+  auto module = CreateNewModule();
+  auto computation = module->AddEntryComputation(builder.Build());
+
+  EXPECT_TRUE(Normalize(module.get()));
+
+  EXPECT_EQ(computation->root_instruction(), gte);
+  EXPECT_EQ(gte->shape().element_type(), BF16);
+  EXPECT_EQ(crs->operand(1)->shape().element_type(), F32);
+  EXPECT_EQ(ShapeUtil::GetSubshape(crs->shape(), {1}).element_type(), F32);
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/bfloat16_support.cc b/tensorflow/compiler/xla/service/bfloat16_support.cc
new file mode 100644 (file)
index 0000000..3fd9e24
--- /dev/null
@@ -0,0 +1,111 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/bfloat16_support.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+
+namespace xla {
+
+bool BFloat16Support::SupportsBF16Operand(const HloInstruction& hlo,
+                                          int64 operand_index) const {
+  switch (hlo.opcode()) {
+    case HloOpcode::kCall:
+    case HloOpcode::kConditional:
+    case HloOpcode::kCustomCall:
+    case HloOpcode::kGetTupleElement:
+    case HloOpcode::kTuple:
+    case HloOpcode::kWhile:
+      return true;
+    case HloOpcode::kConvert:
+      CHECK_EQ(operand_index, 0);
+      return hlo.operand(0)->shape().element_type() == BF16;
+    default:
+      break;
+  }
+  return false;
+}
+
+bool BFloat16Support::SupportsBF16Output(const HloInstruction& hlo) const {
+  switch (hlo.opcode()) {
+    case HloOpcode::kCall:
+    case HloOpcode::kConditional:
+    case HloOpcode::kCustomCall:
+    case HloOpcode::kGetTupleElement:
+    case HloOpcode::kTuple:
+    case HloOpcode::kWhile:
+      return true;
+    case HloOpcode::kConvert:
+      return hlo.shape().element_type() == BF16;
+    default:
+      break;
+  }
+  return false;
+}
+
+bool BFloat16Support::SupportsMixedPrecisions(const HloInstruction& hlo) const {
+  switch (hlo.opcode()) {
+    case HloOpcode::kCall:
+    case HloOpcode::kConditional:
+    case HloOpcode::kConvert:
+    case HloOpcode::kCustomCall:
+    case HloOpcode::kGetTupleElement:
+    case HloOpcode::kTuple:
+    case HloOpcode::kWhile:
+      return true;
+    default:
+      break;
+  }
+  return false;
+}
+
+/* static */
+bool BFloat16Support::EffectiveOperandPrecisionIsOutputPrecision(
+    const HloInstruction& hlo, int64 operand_index) {
+  switch (hlo.opcode()) {
+    case HloOpcode::kAbs:
+    case HloOpcode::kBroadcast:
+    case HloOpcode::kClamp:
+    case HloOpcode::kConcatenate:
+    case HloOpcode::kCopy:
+    case HloOpcode::kGetTupleElement:
+    case HloOpcode::kMaximum:
+    case HloOpcode::kMinimum:
+    case HloOpcode::kPad:
+    case HloOpcode::kReshape:
+    case HloOpcode::kReverse:
+    case HloOpcode::kSlice:
+    case HloOpcode::kSort:
+    case HloOpcode::kTranspose:
+    case HloOpcode::kTuple:
+      return true;
+    case HloOpcode::kDynamicSlice:
+      return operand_index == 0;
+    case HloOpcode::kDynamicUpdateSlice:
+      return operand_index == 0 || operand_index == 1;
+    case HloOpcode::kSelect:
+      return operand_index == 1 || operand_index == 2;
+    default:
+      break;
+  }
+  return false;
+}
+
+bool BFloat16Support::EffectiveOperandPrecisionIsBF16(
+    const HloInstruction& hlo, int64 operand_index) const {
+  return false;
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/bfloat16_support.h b/tensorflow/compiler/xla/service/bfloat16_support.h
new file mode 100644 (file)
index 0000000..29f662d
--- /dev/null
@@ -0,0 +1,60 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_SUPPORT_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_SUPPORT_H_
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+
+namespace xla {
+
+class BFloat16Support {
+ public:
+  BFloat16Support() {}
+  virtual ~BFloat16Support() {}
+
+  // Returns whether the backend supports BF16 operand for the HLO instruction
+  // at the given index.
+  virtual bool SupportsBF16Operand(const HloInstruction& hlo,
+                                   int64 operand_index) const;
+
+  // Returns whether the backend supports BF16 output for the HLO instruction.
+  virtual bool SupportsBF16Output(const HloInstruction& hlo) const;
+
+  // Returns whether the backend support mixed precision: the operands, output,
+  // and parameters/output of the called computations can have different
+  // precisions (BF16 and F32).
+  virtual bool SupportsMixedPrecisions(const HloInstruction& hlo) const;
+
+  // Returns whether the given HLO inherits its BF16 operand precision at the
+  // given index, so even if the output is F32, elements in the output that
+  // depend on the BF16 operand will still have BF16 effective precision even if
+  // they have F32 format. Similarly, this also means if the output is BF16 then
+  // increasing the operand precision from BF16 to F32 will not change the
+  // output. This typically includes HLOs that pass elements from the operand to
+  // the output without arithmetic operations.
+  static bool EffectiveOperandPrecisionIsOutputPrecision(
+      const HloInstruction& hlo, int64 operand_index);
+
+  // Returns if the backend only uses BF16 precision for the operand at the
+  // specified index, even if the operand is F32.
+  virtual bool EffectiveOperandPrecisionIsBF16(const HloInstruction& hlo,
+                                               int64 operand_index) const;
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_SUPPORT_H_
index 0e4437b..0981f1f 100644 (file)
@@ -1805,7 +1805,8 @@ void HloInstruction::RemoveUser(HloInstruction* user) {
 
 Status HloInstruction::ReplaceUseWith(HloInstruction* user,
                                       HloInstruction* new_producer) {
-  TF_RET_CHECK(ShapeUtil::Compatible(shape(), new_producer->shape()))
+  TF_RET_CHECK(
+      ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape()))
       << "this shape: " << ShapeUtil::HumanString(shape())
       << ", replacement shape: "
       << ShapeUtil::HumanString(new_producer->shape());
@@ -1828,8 +1829,8 @@ Status HloInstruction::ReplaceOperandWith(int64 operand_num,
   TF_RET_CHECK(operand_num >= 0);
   TF_RET_CHECK(operand_num < operand_count());
   HloInstruction* old_operand = mutable_operand(operand_num);
-  TF_RET_CHECK(
-      ShapeUtil::Compatible(old_operand->shape(), new_operand->shape()))
+  TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(old_operand->shape(),
+                                                        new_operand->shape()))
       << old_operand->shape().ShortDebugString() << " is not compatible with "
       << new_operand->shape().ShortDebugString();
   operands_[operand_num] = new_operand;
index 04d4656..e2b3bb9 100644 (file)
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include <set>
+
 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/core/lib/core/errors.h"
@@ -164,6 +166,8 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) {
   // HLO broadcast has no exact analog at the proto level so there is no
   // ShapeInference method. Check the output shape explicitly.
   const Shape& operand_shape = broadcast->operand(0)->shape();
+  // Check for mixed precision.
+  TF_RETURN_IF_ERROR(CheckShape(broadcast, broadcast->shape()));
   TF_RET_CHECK(ShapeUtil::Rank(operand_shape) ==
                broadcast->dimensions().size());
   for (int64 operand_dimension = 0;
@@ -178,6 +182,8 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) {
 }
 
 Status ShapeVerifier::HandleReshape(HloInstruction* reshape) {
+  // Check for mixed precision.
+  TF_RETURN_IF_ERROR(CheckShape(reshape, reshape->shape()));
   TF_RET_CHECK(ShapeUtil::ElementsIn(reshape->shape()) ==
                ShapeUtil::ElementsIn(reshape->operand(0)->shape()));
   return tensorflow::Status::OK();
@@ -359,13 +365,122 @@ Status ShapeVerifier::HandleBatchNormGrad(HloInstruction* batch_norm_grad) {
                                          batch_norm_grad->feature_index()));
 }
 
+namespace {
+
+// Checks that the instruction does not have mixed precision floating point
+// inputs.
+Status CheckMixedPrecisionOperands(const HloInstruction* instruction) {
+  switch (instruction->opcode()) {
+    // White list the following opcodes for mixed-precision check, because they
+    // involve data pass through or grouping via tuples, where the precisions
+    // of buffers can be different.
+    case HloOpcode::kCall:
+    case HloOpcode::kConditional:
+    case HloOpcode::kConstant:
+    case HloOpcode::kCrossReplicaSum:
+    case HloOpcode::kCustomCall:
+    case HloOpcode::kFusion:
+    case HloOpcode::kGetTupleElement:
+    case HloOpcode::kInfeed:
+    case HloOpcode::kOutfeed:
+    case HloOpcode::kParameter:
+    case HloOpcode::kRecv:
+    case HloOpcode::kRecvDone:
+    case HloOpcode::kReducePrecision:
+    case HloOpcode::kSelect:
+    case HloOpcode::kSend:
+    case HloOpcode::kSendDone:
+    case HloOpcode::kTuple:
+    case HloOpcode::kWhile:
+      break;
+    default: {
+      PrimitiveType fp_type = PRIMITIVE_TYPE_INVALID;
+      for (auto operand : instruction->operands()) {
+        TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
+            operand->shape(),
+            [&](const Shape& subshape, const ShapeIndex& index) {
+              if (!ShapeUtil::ElementIsFloating(subshape)) {
+                return Status::OK();
+              }
+              if (fp_type == PRIMITIVE_TYPE_INVALID) {
+                fp_type = subshape.element_type();
+              } else if (fp_type != subshape.element_type()) {
+                return FailedPrecondition(
+                    "Seen floating point types of different precisions in "
+                    "%s, but mixed precision is disallowed.",
+                    instruction->ToString().c_str());
+              }
+              return Status::OK();
+            }));
+      }
+    }
+  }
+  return Status::OK();
+}
+
+}  // namespace
+
 Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
-                                 const Shape& expected_shape) {
-  if (!ShapeUtil::Compatible(instruction->shape(), expected_shape)) {
+                                 const Shape& inferred_shape) {
+  // If allow_mixed_precision_ is false, check if there are operands with
+  // different precisions. We need this check because ShapeInference allows
+  // mixed precision inputs.
+  if (!allow_mixed_precision_) {
+    TF_RETURN_IF_ERROR(CheckMixedPrecisionOperands(instruction));
+  }
+
+  // Check if the output shape matches the expected shape.
+  bool compatible;
+  // We treat BF16 and F32 as compatible types if mixed precision is allowed,
+  // but only when the instruction defines the BF16/F32 buffer.
+  switch (instruction->opcode()) {
+    case HloOpcode::kSelect:
+      if (ShapeUtil::IsTuple(inferred_shape) || !allow_mixed_precision_) {
+        // Select only defines the top-level buffer, which in this case is the
+        // tuple, so we cannot allow mixed precision.
+        compatible =
+            ShapeUtil::Compatible(instruction->shape(), inferred_shape);
+      } else {
+        compatible = ShapeUtil::CompatibleIgnoringFpPrecision(
+            instruction->shape(), inferred_shape);
+      }
+      break;
+    case HloOpcode::kGetTupleElement:
+    case HloOpcode::kTuple:
+      // Tuple and GetTupleElement do not define BF16/F32 buffers, so mixed
+      // precision is disallowed.
+    case HloOpcode::kConstant:
+    case HloOpcode::kBitcast:
+    case HloOpcode::kBitcastConvert:
+    case HloOpcode::kCall:
+    case HloOpcode::kConditional:
+    case HloOpcode::kConvert:
+    case HloOpcode::kCustomCall:
+    case HloOpcode::kInfeed:
+    case HloOpcode::kOutfeed:
+    case HloOpcode::kParameter:
+    case HloOpcode::kRecv:
+    case HloOpcode::kRecvDone:
+    case HloOpcode::kSend:
+    case HloOpcode::kSendDone:
+    case HloOpcode::kWhile:
+      // The above opcodes should match the expected shapes exactly.
+      compatible = ShapeUtil::Compatible(instruction->shape(), inferred_shape);
+      break;
+    default:
+      if (allow_mixed_precision_) {
+        compatible = ShapeUtil::CompatibleIgnoringFpPrecision(
+            instruction->shape(), inferred_shape);
+      } else {
+        compatible =
+            ShapeUtil::Compatible(instruction->shape(), inferred_shape);
+      }
+  }
+  if (!compatible) {
     return InvalidArgument(
         "Expected instruction to have shape compatible with %s, actual "
         "shape is %s:\n%s",
-        ShapeUtil::HumanString(expected_shape).c_str(),
+        ShapeUtil::HumanString(inferred_shape).c_str(),
         ShapeUtil::HumanString(instruction->shape()).c_str(),
         instruction->ToString().c_str());
   }
@@ -373,14 +488,14 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
 }
 
 Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
-                                 const StatusOr<Shape>& expected_shape_status) {
-  if (!expected_shape_status.ok()) {
-    Status s = expected_shape_status.status();
+                                 const StatusOr<Shape>& inferred_shape_status) {
+  if (!inferred_shape_status.ok()) {
+    Status s = inferred_shape_status.status();
     tensorflow::errors::AppendToMessage(&s, ", for instruction ",
                                         instruction->ToString());
     return s;
   }
-  return CheckShape(instruction, expected_shape_status.ValueOrDie());
+  return CheckShape(instruction, inferred_shape_status.ValueOrDie());
 }
 
 Status ShapeVerifier::CheckUnaryShape(const HloInstruction* instruction) {
index 26d53de..7eccf83 100644 (file)
@@ -27,6 +27,10 @@ namespace xla {
 // TODO(b/26024837): Check output shape for all instruction types.
 class ShapeVerifier : public DfsHloVisitor {
  public:
+  explicit ShapeVerifier() : allow_mixed_precision_(false) {}
+  explicit ShapeVerifier(bool allow_mixed_precision)
+      : allow_mixed_precision_(allow_mixed_precision) {}
+
   Status HandleElementwiseUnary(HloInstruction* hlo) override;
   Status HandleElementwiseBinary(HloInstruction* hlo) override;
   Status HandleClamp(HloInstruction* clamp) override;
@@ -81,14 +85,14 @@ class ShapeVerifier : public DfsHloVisitor {
   }
 
  protected:
-  // Check the instruction's shape against the given expected shape and return
-  // an appropriate error if there is a mismatch.
+  // Check the instruction's shape against the shape given by ShapeInference
+  // and return an appropriate error if there is a mismatch.
   Status CheckShape(const HloInstruction* instruction,
-                    const Shape& expected_shape);
+                    const Shape& inferred_shape);
 
   // Overload which takes a StatusOr to reduce boilerplate in the caller.
   Status CheckShape(const HloInstruction* instruction,
-                    const StatusOr<Shape>& expected_shape_status);
+                    const StatusOr<Shape>& inferred_shape_status);
 
   // Check a unary (binary, etc) instruction's shape against the inferred shape.
   Status CheckUnaryShape(const HloInstruction* instruction);
@@ -99,19 +103,32 @@ class ShapeVerifier : public DfsHloVisitor {
   // Checks if the given two instructions shares the same channel id.
   Status CheckSameChannel(const HloInstruction* instr1,
                           const HloInstruction* instr2);
+
+ private:
+  // Whether the inputs and output of an instruction can contain both F32s and
+  // BF16s. Tuples that include both F32s and BF16s are allowed regardless of
+  // this flag.
+  bool allow_mixed_precision_;
 };
 
 // HLO pass that verifies invariants of HLO instructions for each computation in
 // the module.
 class HloVerifier : public HloPassInterface {
  public:
+  using ShapeVerifierFactory = std::function<std::unique_ptr<ShapeVerifier>()>;
+
   // Uses standard shape inference.
   explicit HloVerifier()
-      : shape_verifier_factory_([] { return MakeUnique<ShapeVerifier>(); }) {}
+      : shape_verifier_factory_(
+            [] { return MakeUnique<ShapeVerifier>(false); }) {}
+
+  explicit HloVerifier(bool allow_mixed_precision)
+      : shape_verifier_factory_([allow_mixed_precision] {
+          return MakeUnique<ShapeVerifier>(allow_mixed_precision);
+        }) {}
 
   // Uses custom shape verification.
-  explicit HloVerifier(
-      std::function<std::unique_ptr<ShapeVerifier>()> shape_verifier_factory)
+  explicit HloVerifier(ShapeVerifierFactory shape_verifier_factory)
       : shape_verifier_factory_(std::move(shape_verifier_factory)) {}
 
   ~HloVerifier() override = default;
@@ -129,7 +146,7 @@ class HloVerifier : public HloPassInterface {
   // expectations.  This is a factory function because ShapeVerifier,  Note that
   // ShapeVerifier, being a DfsHloVisitor, is stateful.  We want a clean object
   // for each run of the verifier.
-  std::function<std::unique_ptr<ShapeVerifier>()> shape_verifier_factory_;
+  ShapeVerifierFactory shape_verifier_factory_;
 };
 
 }  // namespace xla
index 4ba6da6..004889b 100644 (file)
@@ -209,7 +209,8 @@ tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape,
   }
 
   // Check that init_value's shape is suitable for reducer_shape.
-  if (!ShapeUtil::Compatible(accumulator_shape, init_value_shape)) {
+  if (!ShapeUtil::CompatibleIgnoringFpPrecision(accumulator_shape,
+                                                init_value_shape)) {
     return InvalidArgument(
         "Reduction function's accumulator shape differs from the "
         "init_value shape: %s vs %s",
@@ -220,8 +221,8 @@ tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape,
   // Check that the inputs can be passed in as the second argument.
   const Shape& input_element_shape =
       ShapeUtil::MakeShape(input_element_type, {});
-  if (!ShapeUtil::Compatible(input_element_shape,
-                             reducer_shape.parameters(1))) {
+  if (!ShapeUtil::CompatibleIgnoringFpPrecision(input_element_shape,
+                                                reducer_shape.parameters(1))) {
     return InvalidArgument(
         "Reduction function's second parameter shape differs from the "
         "input type element type: %s vs %s",
@@ -231,7 +232,8 @@ tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape,
 
   // Currently the accumulator and inputs must be the same type,
   // though that restriction could be relaxed.
-  if (!ShapeUtil::Compatible(accumulator_shape, reducer_shape.parameters(1))) {
+  if (!ShapeUtil::CompatibleIgnoringFpPrecision(accumulator_shape,
+                                                reducer_shape.parameters(1))) {
     return InvalidArgument(
         "Reduction function's second parameter shape currently must "
         "match the result shape. Got %s vs %s",
@@ -394,11 +396,13 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
                            dimension);
   }
   const Shape* arg_shape = nullptr;
+  PrimitiveType element_type = PRIMITIVE_TYPE_INVALID;
   for (const Shape* shape : arg_shapes) {
     TF_RETURN_IF_ERROR(
         ExpectNotTupleOrOpaque(*shape, "operand of concatenation"));
     if (!arg_shape) {
       arg_shape = shape;
+      element_type = arg_shape->element_type();
       continue;
     }
     if (ShapeUtil::Rank(*arg_shape) != ShapeUtil::Rank(*shape)) {
@@ -409,7 +413,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
           ShapeUtil::HumanString(*arg_shape).c_str(), ShapeUtil::Rank(*shape),
           ShapeUtil::HumanString(*shape).c_str());
     }
-    if (arg_shape->element_type() != shape->element_type()) {
+    if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shape, *shape)) {
       return InvalidArgument(
           "cannot concatenate arrays with different element types: %s vs %s",
           PrimitiveType_Name(arg_shape->element_type()).c_str(),
@@ -431,6 +435,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
             ShapeUtil::HumanString(*shape).c_str(), dimension);
       }
     }
+    element_type = ShapeUtil::HigherPrecisionElementType(*shape, *arg_shape);
   }
 
   std::vector<int64> new_dimensions(arg_shape->dimensions().begin(),
@@ -438,7 +443,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
   for (size_t i = 1; i < arg_shapes.size(); ++i) {
     new_dimensions[dimension] += arg_shapes[i]->dimensions(dimension);
   }
-  return ShapeUtil::MakeShape(arg_shape->element_type(), new_dimensions);
+  return ShapeUtil::MakeShape(element_type, new_dimensions);
 }
 
 /* static */ StatusOr<Shape> ShapeInference::InferConvertShape(
@@ -536,7 +541,8 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
         ShapeUtil::HumanString(operand_shape).c_str(),
         padding_config.ShortDebugString().c_str());
   }
-  if (operand_shape.element_type() != padding_value_shape.element_type()) {
+  if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape,
+                                                     padding_value_shape)) {
     return InvalidArgument(
         "the element types of the operands to pad do not match");
   }
@@ -548,7 +554,9 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
                     std::max<int64>(operand_shape.dimensions(i) - 1, 0LL) *
                         padding_config.dimensions(i).interior_padding();
   }
-  return ShapeUtil::MakeShape(operand_shape.element_type(), dimensions);
+  return ShapeUtil::MakeShape(
+      ShapeUtil::HigherPrecisionElementType(operand_shape, padding_value_shape),
+      dimensions);
 }
 
 // Current DotDimensionNumbers Requirements:
@@ -673,7 +681,7 @@ Status ValidateDotDimensionNumbers(
   };
 
   // Check if both element types are the same.
-  if (lhs.element_type() != rhs.element_type()) {
+  if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
     return fail("element types do not match");
   }
 
@@ -736,7 +744,8 @@ Status ValidateDotDimensionNumbers(
       dimensions.push_back(rhs.dimensions(i));
     }
   }
-  Shape result = ShapeUtil::MakeShape(lhs.element_type(), dimensions);
+  Shape result = ShapeUtil::MakeShape(
+      ShapeUtil::HigherPrecisionElementType(lhs, rhs), dimensions);
 
   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(result));
   VLOG(2) << "inferred dot shape: " << ShapeUtil::HumanString(result);
@@ -767,7 +776,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
                              ShapeUtil::HumanString(rhs).c_str());
     }
   }
-  return ShapeUtil::MakeShape(lhs.element_type(), output_dimensions);
+  return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
+                              output_dimensions);
 }
 
 /* static */ StatusOr<Shape> ShapeInference::InferInDimBroadcastShape(
@@ -829,6 +839,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
   // specified in broadcast_dimensions are then changed to match the
   // corresponding dimension size in smaller_shape.
   Shape output_shape(larger_shape);
+  output_shape.set_element_type(
+      ShapeUtil::HigherPrecisionElementType(larger_shape, smaller_shape));
 
   for (int i = 0; i < smaller_shape.dimensions_size(); ++i) {
     int64 dimension_to_match = broadcast_dimensions.at(i);
@@ -878,7 +890,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
   TF_RETURN_IF_ERROR(
       ExpectNotTupleOrOpaque(rhs, "rhs of elementwise binary operation"));
 
-  if (!ShapeUtil::SameElementType(lhs, rhs)) {
+  if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
     return InvalidArgument(
         "binary op %s with different element types: %s and %s",
         BinaryOperation_Name(operation).c_str(),
@@ -897,10 +909,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
     }
   }
 
-  if (ShapeUtil::Compatible(lhs, rhs)) {
+  if (ShapeUtil::CompatibleIgnoringFpPrecision(lhs, rhs)) {
     // If the shapes are the same other than layout, the output shape is the
     // same (elementwise op).
-    return lhs;
+    return ShapeUtil::ChangeElementType(
+        lhs, ShapeUtil::HigherPrecisionElementType(lhs, rhs));
   }
 
   if (ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)) {
@@ -973,7 +986,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
       TF_ASSIGN_OR_RETURN(const Shape& shape,
                           InferElementwiseBinaryOpShape(operation, lhs, rhs,
                                                         broadcast_dimensions));
-      if (lhs.element_type() == F32) {
+      if (lhs.element_type() == F32 && rhs.element_type() == F32) {
         return ShapeUtil::ChangeElementType(shape, C64);
       } else {
         return Unimplemented("complex component type not supported");
@@ -1078,12 +1091,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
     TF_RETURN_IF_ERROR(
         ExpectNotTupleOrOpaque(*arg_shapes[i], "operand of map"));
 
-    if (ShapeUtil::Compatible(*arg_shapes[i], *arg_shape)) {
+    if (ShapeUtil::CompatibleIgnoringFpPrecision(*arg_shapes[i], *arg_shape)) {
       continue;
     }
     if (!ShapeUtil::IsTuple(*arg_shapes[i]) &&
         !ShapeUtil::IsTuple(*arg_shape) &&
-        ShapeUtil::SameElementType(*arg_shapes[i], *arg_shape)) {
+        ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shapes[i],
+                                                      *arg_shape)) {
       if (ShapeUtil::IsScalar(*arg_shapes[i])) {
         continue;
       }
@@ -1148,7 +1162,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
           i, ShapeUtil::HumanString(parameter_shape).c_str());
     }
 
-    if (parameter_shape.element_type() != arg_shape->element_type()) {
+    if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(parameter_shape,
+                                                       *arg_shape)) {
       return InvalidArgument(
           "mapped computation's parameter type has to match argument element "
           "type; got parameter %d shape: %s, argument shape: %s",
@@ -1221,7 +1236,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
         PrimitiveType_Name(operand_shape.element_type()).c_str());
   }
 
-  if (!ShapeUtil::SameElementType(offset_shape, operand_shape)) {
+  if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape,
+                                                     operand_shape)) {
     return InvalidArgument(
         "The inputs should have the same element type for batch-norm-training, "
         "but the shape of offset factor is %s "
@@ -1230,7 +1246,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
         PrimitiveType_Name(operand_shape.element_type()).c_str());
   }
 
-  if (!ShapeUtil::SameElementType(scale_shape, operand_shape)) {
+  if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
+                                                     operand_shape)) {
     return InvalidArgument(
         "The inputs should have the same element type for batch-norm-training, "
         "but the shape of scale factor is %s "
@@ -1329,7 +1346,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
         PrimitiveType_Name(operand_shape.element_type()).c_str());
   }
 
-  if (!ShapeUtil::SameElementType(offset_shape, operand_shape)) {
+  if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape,
+                                                     operand_shape)) {
     return InvalidArgument(
         "The inputs should have the same element type for "
         "batch-norm-inference, "
@@ -1339,7 +1357,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
         PrimitiveType_Name(operand_shape.element_type()).c_str());
   }
 
-  if (!ShapeUtil::SameElementType(scale_shape, operand_shape)) {
+  if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
+                                                     operand_shape)) {
     return InvalidArgument(
         "The inputs should have the same element type for "
         "batch-norm-inference, "
@@ -1349,7 +1368,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
         PrimitiveType_Name(operand_shape.element_type()).c_str());
   }
 
-  if (!ShapeUtil::SameElementType(mean_shape, operand_shape)) {
+  if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape,
+                                                     operand_shape)) {
     return InvalidArgument(
         "The inputs should have the same element type for "
         "batch-norm-inference, "
@@ -1359,7 +1379,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
         PrimitiveType_Name(operand_shape.element_type()).c_str());
   }
 
-  if (!ShapeUtil::SameElementType(variance_shape, operand_shape)) {
+  if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(variance_shape,
+                                                     operand_shape)) {
     return InvalidArgument(
         "The inputs should have the same element type for "
         "batch-norm-inference, "
@@ -1481,7 +1502,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
         PrimitiveType_Name(output_grad_shape.element_type()).c_str());
   }
 
-  if (!ShapeUtil::SameElementType(output_grad_shape, operand_shape)) {
+  if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(output_grad_shape,
+                                                     operand_shape)) {
     return InvalidArgument(
         "The inputs should have the same element type for batch-norm-grad, "
         "but the element type of output_grad is %s "
@@ -1490,7 +1512,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
         PrimitiveType_Name(operand_shape.element_type()).c_str());
   }
 
-  if (!ShapeUtil::SameElementType(scale_shape, operand_shape)) {
+  if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
+                                                     operand_shape)) {
     return InvalidArgument(
         "The inputs should have the same element type for batch-norm-grad, "
         "but the element type of scale factor is %s "
@@ -1499,7 +1522,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
         PrimitiveType_Name(operand_shape.element_type()).c_str());
   }
 
-  if (!ShapeUtil::SameElementType(mean_shape, operand_shape)) {
+  if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape,
+                                                     operand_shape)) {
     return InvalidArgument(
         "The inputs should have the same element type for batch-norm-grad, "
         "but the element type of mean is %s "
@@ -1508,7 +1532,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
         PrimitiveType_Name(operand_shape.element_type()).c_str());
   }
 
-  if (!ShapeUtil::SameElementType(var_shape, operand_shape)) {
+  if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(var_shape,
+                                                     operand_shape)) {
     return InvalidArgument(
         "The inputs should have the same element type for batch-norm-grad, "
         "but the element type of mean is %s "
@@ -1569,7 +1594,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
   TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of convolution"));
   TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of convolution"));
 
-  if (!ShapeUtil::SameElementType(lhs, rhs)) {
+  if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
     return InvalidArgument(
         "Convolution with different element types: %s and %s",
         ShapeUtil::HumanString(lhs).c_str(),
@@ -1714,8 +1739,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
     dimensions[dnums.output_spatial_dimensions(i)] =
         window_output_shape.dimensions(i);
   }
-
-  return ShapeUtil::MakeShape(lhs.element_type(), dimensions);
+  return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
+                              dimensions);
 }
 
 /* static */ StatusOr<Shape> ShapeInference::InferFftShape(
@@ -1877,16 +1902,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
   }
   const Shape& operand_element_shape =
       ShapeUtil::MakeShape(operand_shape.element_type(), {});
-  if (!ShapeUtil::Compatible(operand_element_shape,
-                             select_shape.parameters(0))) {
+  if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape,
+                                                select_shape.parameters(0))) {
     return InvalidArgument(
         "select function's first parameter shape currently must "
         "match the operand element shape. Got %s vs %s",
         ShapeUtil::HumanString(select_shape.parameters(0)).c_str(),
         ShapeUtil::HumanString(operand_element_shape).c_str());
   }
-  if (!ShapeUtil::Compatible(operand_element_shape,
-                             select_shape.parameters(1))) {
+  if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape,
+                                                select_shape.parameters(1))) {
     return InvalidArgument(
         "select function's second parameter shape currently must "
         "match the operand element shape. Got %s vs %s",
@@ -1903,7 +1928,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
                       InferWindowOutputShape(operand_shape, window,
                                              operand_shape.element_type(),
                                              /*allow_negative_padding=*/false));
-  if (!ShapeUtil::Compatible(source_shape, window_result_shape)) {
+  if (!ShapeUtil::CompatibleIgnoringFpPrecision(source_shape,
+                                                window_result_shape)) {
     return InvalidArgument(
         "source shape does not match the shape of window-reduced operand: "
         "source(%s), window-reduced operand(%s)",
@@ -2086,7 +2112,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
         ShapeUtil::Rank(update_shape), ShapeUtil::Rank(operand_shape));
   }
 
-  if (operand_shape.element_type() != update_shape.element_type()) {
+  if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape,
+                                                     update_shape)) {
     return InvalidArgument(
         "dynamic update slice update element type does not match argument. "
         "operand.element_type: %s vs update.element_type: %s",
@@ -2322,24 +2349,26 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
   TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(min, "clamp min"));
   TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "clamp operand"));
   TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(max, "clamp max"));
-  if (!ShapeUtil::SameElementType(min, operand) ||
-      !ShapeUtil::SameElementType(max, operand)) {
+  if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(min, operand) ||
+      !ShapeUtil::SameElementTypeIgnoringFpPrecision(max, operand)) {
     return InvalidArgument("clamp op with different operand types: %s, %s, %s",
                            ShapeUtil::HumanString(min).c_str(),
                            ShapeUtil::HumanString(operand).c_str(),
                            ShapeUtil::HumanString(max).c_str());
   }
-  if (((ShapeUtil::Compatible(min, operand) || ShapeUtil::IsScalar(min)) &&
-       (ShapeUtil::Compatible(max, operand) || ShapeUtil::IsScalar(max)))) {
+  if (((ShapeUtil::CompatibleIgnoringFpPrecision(min, operand) ||
+        ShapeUtil::IsScalar(min)) &&
+       (ShapeUtil::CompatibleIgnoringFpPrecision(max, operand) ||
+        ShapeUtil::IsScalar(max)))) {
     return operand;
   }
   if (ShapeUtil::IsScalar(operand)) {
-    if (ShapeUtil::Compatible(min, max)) {
-      return min;
+    if (ShapeUtil::CompatibleIgnoringFpPrecision(min, max)) {
+      return ShapeUtil::ChangeElementType(min, operand.element_type());
     } else if (ShapeUtil::IsScalar(min)) {
-      return max;
+      return ShapeUtil::ChangeElementType(max, operand.element_type());
     } else if (ShapeUtil::IsScalar(max)) {
-      return min;
+      return ShapeUtil::ChangeElementType(min, operand.element_type());
     }
   }
   return Unimplemented(
@@ -2352,7 +2381,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
 // broadcast from all operands, not just the predicate.
 /* static */ StatusOr<Shape> ShapeInference::InferSelectShape(
     const Shape& pred, const Shape& on_true, const Shape& on_false) {
-  if (!ShapeUtil::Compatible(on_true, on_false)) {
+  bool compatible;
+  if (ShapeUtil::IsTuple(on_true)) {
+    // Select only defines the top-level buffer, so if it's a tuple, the two
+    // input must match exactly.
+    compatible = ShapeUtil::Compatible(on_true, on_false);
+  } else {
+    compatible = ShapeUtil::CompatibleIgnoringFpPrecision(on_true, on_false);
+  }
+  if (!compatible) {
     return InvalidArgument(
         "operands to select must be the same shape; got %s and %s",
         ShapeUtil::HumanString(on_true).c_str(),
@@ -2367,7 +2404,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
     // By this stage we know that pred's element type is PRED. Therefore, this
     // check restricts pred to be a PRED scalar, or a PRED array with the same
     // dimensions as on_true and on_false.
-    return on_true;
+    return ShapeUtil::ChangeElementType(
+        on_true, ShapeUtil::HigherPrecisionElementType(on_true, on_false));
   } else {
     return Unimplemented(
         "select operation with non-scalar predicate with dimensionality "
index d63e16c..604e017 100644 (file)
@@ -630,6 +630,19 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
   return SameDimensions(lhs, rhs);
 }
 
+/* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs,
+                                                           const Shape& rhs) {
+  if (lhs.element_type() == TUPLE) {
+    return rhs.element_type() == TUPLE &&
+           ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
+                           CompatibleIgnoringFpPrecision);
+  }
+  if (SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
+    return CompatibleIgnoringElementType(lhs, rhs);
+  }
+  return false;
+}
+
 /* static */ int64 ShapeUtil::GetDimension(const Shape& shape,
                                            int64 dimension_number) {
   return shape.dimensions(GetDimensionNumber(shape, dimension_number));
index 453d4ec..d8a0088 100644 (file)
@@ -23,6 +23,7 @@ limitations under the License.
 #include <string>
 
 #include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -211,6 +212,31 @@ class ShapeUtil {
     return lhs.element_type() == rhs.element_type();
   }
 
+  // As SameElementType, but allows floating point types to have different
+  // precisions.
+  static bool SameElementTypeIgnoringFpPrecision(const Shape& a,
+                                                 const Shape& b) {
+    if (ElementIsFloating(a) && ElementIsFloating(b)) {
+      return true;
+    }
+    return ShapeUtil::SameElementType(a, b);
+  }
+
+  // Returns the higher-precision element type if a and b are both floating
+  // point types; otherwise, checks that that they have the same element type
+  // and returns it.
+  static PrimitiveType HigherPrecisionElementType(const Shape& a,
+                                                  const Shape& b) {
+    if (SameElementType(a, b)) {
+      return a.element_type();
+    }
+    CHECK(SameElementTypeIgnoringFpPrecision(a, b));
+    return primitive_util::BitWidth(a.element_type()) <
+                   primitive_util::BitWidth(b.element_type())
+               ? b.element_type()
+               : a.element_type();
+  }
+
   // Returns true if the rank, dimension sizes, and element type are
   // identical. Layout is ignored. Tuple elements are compared recursively for
   // compatibility.
@@ -221,6 +247,10 @@ class ShapeUtil {
   // compatibility.
   static bool CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs);
 
+  // As Compatible, but allow one of lhs and rhs to be BF16 while the other
+  // being F32. Tuple elements are compared recursively for compatibility.
+  static bool CompatibleIgnoringFpPrecision(const Shape& lhs, const Shape& rhs);
+
   // Returns whether the lhs and rhs shapes are identical protobufs.
   static bool Equal(const Shape& lhs, const Shape& rhs);
 
index 81ba7af..4db97d4 100644 (file)
@@ -170,6 +170,18 @@ TEST(ShapeUtilTest, CompatibleNotIdenticalShapes) {
   EXPECT_TRUE(ShapeUtil::Compatible(shape_1, shape_2));
 }
 
+TEST(ShapeUtilTest, CompatibleIgnoringFpPrecision) {
+  Shape shape1 = ShapeUtil::MakeShape(BF16, {3, 2});
+  Shape shape2 = ShapeUtil::MakeShape(F32, {3, 2});
+  ASSERT_TRUE(ShapeUtil::CompatibleIgnoringFpPrecision(shape1, shape2));
+}
+
+TEST(ShapeUtilTest, IncompatibleIgnoringFpPrecision) {
+  Shape shape1 = ShapeUtil::MakeShape(BF16, {3, 2});
+  Shape shape2 = ShapeUtil::MakeShape(F32, {2, 2});
+  ASSERT_FALSE(ShapeUtil::CompatibleIgnoringFpPrecision(shape1, shape2));
+}
+
 TEST(ShapeUtilTest, IncompatibleDifferentElementShapes) {
   Shape shape_1 = ShapeUtil::MakeShape(F32, {3, 2});
   Shape shape_2 = ShapeUtil::MakeShape(PRED, {3, 2});
@@ -184,6 +196,14 @@ TEST(ShapeUtilTest, CompatibleTuples) {
   EXPECT_TRUE(ShapeUtil::Compatible(tuple1, tuple2));
 }
 
+TEST(ShapeUtilTest, CompatibleTuplesIgnoringFpPrecision) {
+  Shape tuple1 = ShapeUtil::MakeTupleShape(
+      {ShapeUtil::MakeShape(BF16, {3, 2}), ShapeUtil::MakeShape(F32, {4, 5})});
+  Shape tuple2 = ShapeUtil::MakeTupleShape(
+      {ShapeUtil::MakeShape(F64, {3, 2}), ShapeUtil::MakeShape(BF16, {4, 5})});
+  EXPECT_TRUE(ShapeUtil::CompatibleIgnoringFpPrecision(tuple1, tuple2));
+}
+
 TEST(ShapeUtilTest, IncompatibleTuplesWithSwappedElements) {
   Shape tuple1 = ShapeUtil::MakeTupleShape(
       {ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(F32, {3, 2})});
@@ -193,6 +213,14 @@ TEST(ShapeUtilTest, IncompatibleTuplesWithSwappedElements) {
   EXPECT_FALSE(ShapeUtil::CompatibleIgnoringElementType(tuple1, tuple2));
 }
 
+TEST(ShapeUtilTest, IncompatibleTuplesIgnoringFpPrecision) {
+  Shape tuple1 = ShapeUtil::MakeTupleShape(
+      {ShapeUtil::MakeShape(BF16, {4, 5}), ShapeUtil::MakeShape(F32, {3, 2})});
+  Shape tuple2 = ShapeUtil::MakeTupleShape(
+      {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(BF16, {4, 5})});
+  EXPECT_FALSE(ShapeUtil::CompatibleIgnoringFpPrecision(tuple1, tuple2));
+}
+
 TEST(ShapeUtilTest, IncompatibleTuplesWithDifferentPrimitiveType) {
   Shape tuple1 = ShapeUtil::MakeTupleShape(
       {ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(F32, {3, 2})});