Improve our handling of bitcasts.
authorSanjoy Das <sanjoy@google.com>
Tue, 27 Feb 2018 19:57:09 +0000 (11:57 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 27 Feb 2018 20:01:12 +0000 (12:01 -0800)
 - Do not fuse bitcasts in the CPU backend.  Fused instructions lose their
   layout and a bitcast is meaningless without a layout.  We were explicitly
   testing for this so I've changed the corresponding tests to use a reshape
   instead.
 - Fail the layout assignment if we see a bitcast.  bitcasts are inherently
   layout sensitive and so a bitcast instruction present in the IR before layout
   assignment is a bug.

PiperOrigin-RevId: 187210151

tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc
tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
tensorflow/compiler/xla/service/layout_assignment.cc
tensorflow/compiler/xla/service/layout_assignment_test.cc

index 482e04052d5a914eab0e5bff2c7a83f3b698052f..0fc5a746bbbc7685ff5d4647111a750e7d7b1c19 100644 (file)
@@ -30,7 +30,6 @@ bool CanBeLoopFused(const HloInstruction& hlo) {
   // These are the only ones we fuse since we rely on effective elemental IR
   // generation.
   return hlo.IsElementwise() ||  //
-         hlo.opcode() == HloOpcode::kBitcast ||
          hlo.opcode() == HloOpcode::kBroadcast ||
          hlo.opcode() == HloOpcode::kConcatenate ||
          hlo.opcode() == HloOpcode::kDynamicSlice ||
index 595c3f55b321f47e2312b93e0c238c7637495d77..6ed1cd31b18f6360bdd7fd41bd5be2e657b310a5 100644 (file)
@@ -77,7 +77,7 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Basic_1) {
   EXPECT_THAT(computation->root_instruction(), op::Fusion());
 }
 
-TEST_F(InstructionFusionTest, DotOperationFusion_Bitcast) {
+TEST_F(InstructionFusionTest, DotOperationNoFusion_Bitcast) {
   HloComputation::Builder builder(TestName());
   HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
       0, ShapeUtil::MakeShape(F32, {2, 512, 2, 128}), "arg0"));
@@ -94,8 +94,7 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Bitcast) {
   auto module = CreateNewModule();
   auto computation = module->AddEntryComputation(builder.Build());
   EXPECT_EQ(dot, computation->root_instruction());
-  EXPECT_TRUE(CpuInstructionFusion().Run(module.get()).ValueOrDie());
-  EXPECT_THAT(computation->root_instruction(), op::Fusion());
+  EXPECT_FALSE(CpuInstructionFusion().Run(module.get()).ValueOrDie());
 }
 
 TEST_F(InstructionFusionTest, DotOperationFusion_Reshape) {
@@ -244,35 +243,33 @@ class OpcodeFusionTest : public InstructionFusionTest {
   }
 };
 
-TEST_F(OpcodeFusionTest, Exponential_Bitcast_Negate) {
+TEST_F(OpcodeFusionTest, Exponential_Reshape_Negate) {
   HloComputation::Builder builder(TestName());
   Shape param_shape = ShapeUtil::MakeShape(F32, {1, 4});
   Shape result_shape = ShapeUtil::MakeShape(F32, {4});
   HloInstruction* param0 = builder.AddInstruction(
       HloInstruction::CreateParameter(0, param_shape, "param"));
-  // InstructionFusion::ShouldFuse() precludes fusing a bitcast whose operand
-  // is a parameter, so create an operand between the parameter and bitcast.
   HloInstruction* exp1 = builder.AddInstruction(
       HloInstruction::CreateUnary(param_shape, HloOpcode::kExp, param0));
-  HloInstruction* bitcast2 = builder.AddInstruction(
-      HloInstruction::CreateUnary(result_shape, HloOpcode::kBitcast, exp1));
+  HloInstruction* reshape2 =
+      builder.AddInstruction(HloInstruction::CreateReshape(result_shape, exp1));
   builder.AddInstruction(
-      HloInstruction::CreateUnary(result_shape, HloOpcode::kNegate, bitcast2));
+      HloInstruction::CreateUnary(result_shape, HloOpcode::kNegate, reshape2));
 
   auto module = CreateNewModule();
   module->AddEntryComputation(builder.Build());
 
   RunFusionAndCheckOpcodesWereFused(
-      module.get(), {HloOpcode::kNegate, HloOpcode::kBitcast, HloOpcode::kExp,
+      module.get(), {HloOpcode::kNegate, HloOpcode::kReshape, HloOpcode::kExp,
                      HloOpcode::kParameter});
 }
 
-TEST_F(OpcodeFusionTest, Broadcast_Bitcast_DynamicSlice_Tanh) {
+TEST_F(OpcodeFusionTest, Broadcast_Reshape_DynamicSlice_Tanh) {
   HloComputation::Builder builder(TestName());
   Shape param_shape = ShapeUtil::MakeShape(F32, {8});
   Shape starts_shape = ShapeUtil::MakeShape(F32, {2});
   Shape broadcast_shape = ShapeUtil::MakeShape(F32, {1, 8, 8});
-  Shape bitcast_shape = ShapeUtil::MakeShape(F32, {8, 8});
+  Shape reshape_shape = ShapeUtil::MakeShape(F32, {8, 8});
   Shape dynamic_slice_shape = ShapeUtil::MakeShape(F32, {4, 4});
   HloInstruction* param0 = builder.AddInstruction(
       HloInstruction::CreateParameter(0, param_shape, "param"));
@@ -280,11 +277,11 @@ TEST_F(OpcodeFusionTest, Broadcast_Bitcast_DynamicSlice_Tanh) {
       HloInstruction::CreateParameter(1, starts_shape, "starts"));
   HloInstruction* broadcast2 = builder.AddInstruction(
       HloInstruction::CreateBroadcast(broadcast_shape, param0, {1}));
-  HloInstruction* bitcast3 = builder.AddInstruction(HloInstruction::CreateUnary(
-      bitcast_shape, HloOpcode::kBitcast, broadcast2));
+  HloInstruction* reshape3 = builder.AddInstruction(
+      HloInstruction::CreateReshape(reshape_shape, broadcast2));
   HloInstruction* dynamic_slice4 =
       builder.AddInstruction(HloInstruction::CreateDynamicSlice(
-          dynamic_slice_shape, bitcast3, param1, {4, 4}));
+          dynamic_slice_shape, reshape3, param1, {4, 4}));
   builder.AddInstruction(HloInstruction::CreateUnary(
       dynamic_slice_shape, HloOpcode::kTanh, dynamic_slice4));
 
@@ -293,7 +290,7 @@ TEST_F(OpcodeFusionTest, Broadcast_Bitcast_DynamicSlice_Tanh) {
 
   RunFusionAndCheckOpcodesWereFused(
       module.get(),
-      {HloOpcode::kTanh, HloOpcode::kDynamicSlice, HloOpcode::kBitcast,
+      {HloOpcode::kTanh, HloOpcode::kDynamicSlice, HloOpcode::kReshape,
        HloOpcode::kBroadcast, HloOpcode::kParameter, HloOpcode::kParameter});
 }
 
index 4929300f7d30ef6fa6c9e128a781e7780f54a520..39f9120e552f014dd2759bff2892157402d9c47a 100644 (file)
@@ -1561,6 +1561,13 @@ StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
     // infeeds.  Clearing the layouts here avoids hiding potential bugs in the
     // layout assignment pass that may accidently use the existing layout.
     for (HloInstruction* instruction : computation->instructions()) {
+      if (instruction->opcode() == HloOpcode::kBitcast) {
+        // bitcasts are inherently layout sensitive and so a bitcast instruction
+        // present in the IR before layout assignment is a bug.
+        return InternalError(
+            "Unexpected bitcast operation seen during layout assignment: %s.",
+            instruction->ToString().c_str());
+      }
       if (instruction->opcode() != HloOpcode::kInfeed) {
         LayoutUtil::ClearLayout(instruction->mutable_shape());
       }
index 62feb7c1e9da0b3ecb9c21b876d86935775531d7..4b1c9bad41de8030cf14bc6d1c0db21b9c56c3bf 100644 (file)
@@ -796,5 +796,26 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) {
   EXPECT_THAT(false_result->opcode(), HloOpcode::kCopy);
 }
 
+TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) {
+  auto builder = HloComputation::Builder(TestName());
+  auto constant0 = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
+  builder.AddInstruction(HloInstruction::CreateUnary(
+      constant0->shape(), HloOpcode::kBitcast, constant0));
+  auto module = CreateNewModule();
+  module->AddEntryComputation(builder.Build());
+
+  ComputationLayout computation_layout(
+      module->entry_computation()->ComputeProgramShape());
+  LayoutAssignment layout_assignment(&computation_layout);
+  Status error_status = layout_assignment.Run(module.get()).status();
+  EXPECT_FALSE(error_status.ok());
+  EXPECT_THAT(
+      error_status.error_message(),
+      ::testing::HasSubstr(
+          "Unexpected bitcast operation seen during layout assignment"));
+}
+
 }  // namespace
 }  // namespace xla