[XLA] Fix BF16 normalizer for CrossReplicaSum.
author     Yuanzhong Xu <yuanzx@google.com>
           Mon, 5 Mar 2018 20:54:27 +0000 (12:54 -0800)
committer  TensorFlower Gardener <gardener@tensorflow.org>
           Mon, 5 Mar 2018 20:59:01 +0000 (12:59 -0800)
1. The pass could produce an incorrect result when mixed precision is not
supported and BF16 is unsupported only for a particular operand: it would then
introduce new mixed precision into an all-BF16 CRS. This is unlikely in
practical settings, but removing this constraint makes it possible to
auto-generate corner-case tests using this pass (see the first sketch below).

2. A cycle could be introduced in the tuple-shaped output case: the users of
the CRS were snapshotted only after the new get-tuple-element/convert
instructions had been created, so those new instructions could be rewired to
consume the replacement tuple built from them (see the second sketch below).
This wasn't caught by the test because the DFS happened to succeed; the test
now runs the HLO verifier explicitly.
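
Sketch 1 (illustrative only, not part of the commit): a minimal standalone
model of the per-operand conversion decision for a two-operand, all-BF16 CRS
on a hypothetical backend that supports BF16 for operand 0 but not operand 1
and does not support mixed precision. The free functions below stand in for
the BFloat16Support queries and are not the real API; the output-side and
F32-operand terms of the new logic are omitted for brevity.

    #include <cstdio>

    // Hypothetical backend: BF16 allowed for operand 0 only, no mixed precision.
    bool SupportsBF16Operand(int i) { return i == 0; }
    bool SupportsMixedPrecisions() { return false; }

    int main() {
      const int n = 2;  // two all-BF16 operands
      bool old_convert[n], new_convert[n];

      // Old logic: convert operand i only if its own BF16 support is missing,
      // or if F32 is already present and mixed precision is unsupported.
      bool has_f32 = false;
      for (int i = 0; i < n; ++i) {
        old_convert[i] =
            !(SupportsBF16Operand(i) && (SupportsMixedPrecisions() || !has_f32));
        if (old_convert[i]) has_f32 = true;
      }
      // -> {false, true}: operand 1 becomes F32 while operand 0 stays BF16,
      //    introducing mixed precision on a backend that cannot handle it.

      // New logic (simplified): once any BF16 operand is unsupported and mixed
      // precision is unsupported, every BF16 operand gets converted.
      bool has_unsupported_bf16_operand =
          !SupportsBF16Operand(0) || !SupportsBF16Operand(1);
      for (int i = 0; i < n; ++i) {
        new_convert[i] = !SupportsBF16Operand(i) ||
                         (!SupportsMixedPrecisions() &&
                          has_unsupported_bf16_operand);
      }
      // -> {true, true}: the whole CRS is converted to F32 and stays uniform.

      for (int i = 0; i < n; ++i) {
        std::printf("operand %d: old=%d new=%d\n", i, old_convert[i],
                    new_convert[i]);
      }
      return 0;
    }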
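
Sketch 2 (illustrative only): a tiny graph model of the ordering bug behind
point 2, using a hypothetical Node type rather than HloInstruction. Taking the
crs->users() snapshot after the new get-tuple-element/convert instructions
exist lets those instructions be rewired to consume the replacement tuple they
feed, forming the cycle tuple -> convert -> gte -> tuple; the fix in the diff
below takes the snapshot first.

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    // Hypothetical stand-in for an HLO instruction (leaks memory; sketch only).
    struct Node {
      const char* name;
      std::vector<Node*> operands;
      std::vector<Node*> users;
    };

    Node* MakeNode(const char* name, std::vector<Node*> operands) {
      Node* n = new Node{name, operands, {}};
      for (Node* op : n->operands) op->users.push_back(n);
      return n;
    }

    void ReplaceUse(Node* user, Node* old_op, Node* new_op) {
      std::replace(user->operands.begin(), user->operands.end(), old_op, new_op);
      new_op->users.push_back(user);
    }

    int main() {
      Node* param = MakeNode("param", {});
      Node* crs = MakeNode("crs", {param});
      Node* consumer = MakeNode("consumer", {crs});

      // Buggy order: the new gte/convert/tuple nodes are created first...
      Node* gte = MakeNode("gte", {crs});
      Node* convert = MakeNode("convert", {gte});
      Node* tuple = MakeNode("tuple", {convert});
      // ...and only then are crs's users snapshotted, so gte is included.
      std::vector<Node*> materialized_users = crs->users;

      for (Node* user : materialized_users) ReplaceUse(user, crs, tuple);

      // gte now consumes tuple while tuple consumes gte via convert: a cycle.
      std::printf("%s's operand is now: %s\n", gte->name, gte->operands[0]->name);
      (void)consumer;
      return 0;
    }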

PiperOrigin-RevId: 187908099

tensorflow/compiler/xla/service/BUILD
tensorflow/compiler/xla/service/bfloat16_normalization.cc
tensorflow/compiler/xla/service/bfloat16_normalization_test.cc

tensorflow/compiler/xla/service/BUILD
index d71790f..6f52703 100644
@@ -106,6 +106,7 @@ tf_cc_test(
         ":bfloat16_normalization",
         ":bfloat16_support",
         ":hlo",
+        ":hlo_verifier",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:test",
tensorflow/compiler/xla/service/bfloat16_normalization.cc
index 6176f5d..14c54dd 100644
@@ -152,44 +152,64 @@ Status BFloat16NormalizationVisitor::HandleCrossReplicaSum(
 
   std::vector<PrimitiveType> operand_types(crs->operand_count());
   std::vector<PrimitiveType> output_types(crs->operand_count());
-  bool has_f32 = false;
-  bool has_bf16 = false;
-  bool has_bf16_output = false;
+  int64 f32_count = 0;
+  int64 bf16_count = 0;
+  bool has_unsupported_bf16_operand = false;
+  bool has_unsupported_bf16_output = false;
   for (int64 i = 0; i < crs->operand_count(); ++i) {
     operand_types[i] = crs->operand(i)->shape().element_type();
     output_types[i] = ShapeUtil::GetSubshape(crs->shape(), {i}).element_type();
-    if (operand_types[i] == F32 || output_types[i] == F32) {
-      has_f32 = true;
+    if (operand_types[i] == F32) {
+      f32_count += 1;
     } else if (operand_types[i] == BF16) {
-      has_bf16 = true;
+      bf16_count += 1;
+      if (!bfloat16_support_->SupportsBF16Operand(*crs, i)) {
+        has_unsupported_bf16_operand = true;
+      }
     }
-    if (output_types[i] == BF16) {
-      has_bf16 = true;
-      has_bf16_output = true;
+    if (output_types[i] == F32) {
+      f32_count += 1;
+    } else if (output_types[i] == BF16) {
+      bf16_count += 1;
+      if (!bfloat16_support_->SupportsBF16Output(*crs)) {
+        has_unsupported_bf16_output = true;
+      }
     }
   }
 
-  for (int64 i = 0; i < crs->operand_count(); ++i) {
+  if (bf16_count == 0) {
+    return Status::OK();
+  }
+
+  auto should_convert_operand = [&](int64 i) {
     if (operand_types[i] != BF16) {
-      continue;
+      return false;
     }
-    if (bfloat16_support_->SupportsBF16Operand(*crs, i) &&
-        (bfloat16_support_->SupportsMixedPrecisions(*crs) || !has_f32)) {
-      continue;
+    if (!bfloat16_support_->SupportsBF16Operand(*crs, i)) {
+      return true;
     }
-    TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(crs, i, F32, computation_));
-    has_f32 = true;
-  }
+    if (bfloat16_support_->SupportsMixedPrecisions(*crs)) {
+      return false;
+    }
+    return has_unsupported_bf16_operand || has_unsupported_bf16_output ||
+           f32_count > 0;
+  };
 
-  if (!has_bf16_output) {
-    return Status::OK();
+  for (int64 i = 0; i < crs->operand_count(); ++i) {
+    if (should_convert_operand(i)) {
+      TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(crs, i, F32, computation_));
+      f32_count += 1;
+      bf16_count -= 1;
+    }
   }
 
-  if (bfloat16_support_->SupportsBF16Output(*crs) &&
-      (bfloat16_support_->SupportsMixedPrecisions(*crs) || !has_f32)) {
+  if (!has_unsupported_bf16_output &&
+      (bfloat16_support_->SupportsMixedPrecisions(*crs) || f32_count == 0 ||
+       bf16_count == 0)) {
     return Status::OK();
   }
 
+  std::vector<HloInstruction*> materialized_users = crs->users();
   std::vector<HloInstruction*> output_elements(crs->operand_count());
   auto original_shape = crs->shape();
   for (int64 i = 0; i < crs->operand_count(); ++i) {
@@ -209,7 +229,6 @@ Status BFloat16NormalizationVisitor::HandleCrossReplicaSum(
   auto tuple = computation_->AddInstruction(
       HloInstruction::CreateTuple(output_elements));
 
-  std::vector<HloInstruction*> materialized_users = crs->users();
   // Use the crs' shape temporarily, in order to pass checks in
   // ReplaceUseWith.
   *tuple->mutable_shape() = crs->shape();
tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
index fc0f6f1..1afaefd 100644
@@ -19,6 +19,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_verifier.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/compiler/xla/test_helpers.h"
@@ -74,6 +75,10 @@ class BFloat16NormalizationTest : public HloTestBase {
     BFloat16Normalization normalization(&bfloat16_support_);
     StatusOr<bool> result = normalization.Run(module);
     EXPECT_IS_OK(result.status());
+
+    HloVerifier verifier(/*allow_mixed_precision=*/true);
+    EXPECT_IS_OK(verifier.Run(module).status());
+
     return result.ValueOrDie();
   }
 };
@@ -170,7 +175,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) {
   Shape f32_input_shape = ShapeUtil::MakeShape(F32, {2, 4});
   Shape f32_output_shape = ShapeUtil::MakeShape(F32, {4});
 
-  Shape bf16_scalar_shape = ShapeUtil::MakeShape(BF16, {2, 4});
+  Shape bf16_scalar_shape = ShapeUtil::MakeShape(BF16, {});
 
   auto reduce_comp_builder = HloComputation::Builder("reduce_comp");
   auto reduce_comp_param0 = reduce_comp_builder.AddInstruction(
@@ -260,8 +265,11 @@ TEST_F(BFloat16NormalizationTest, DoNotAddUnsupportedMixedPrecision) {
   HloInstruction* b = builder.AddInstruction(
       HloInstruction::CreateParameter(1, bf16_shape, "b"));
 
+  DotDimensionNumbers dot_dnums;
+  dot_dnums.add_lhs_contracting_dimensions(1);
+  dot_dnums.add_rhs_contracting_dimensions(0);
   HloInstruction* dot = builder.AddInstruction(
-      HloInstruction::CreateBinary(bf16_shape, HloOpcode::kDot, a, b));
+      HloInstruction::CreateDot(bf16_shape, a, b, dot_dnums));
 
   auto module = CreateNewModule();
   auto computation = module->AddEntryComputation(builder.Build());