Pre-factoring: Fix overly specific test expectations to prepare for multi-output...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 14 May 2018 16:06:25 +0000 (09:06 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 14 May 2018 16:08:50 +0000 (09:08 -0700)
PiperOrigin-RevId: 196514026

tensorflow/compiler/xla/service/instruction_fusion_test.cc

index 6dd8fa1..cf9673a 100644 (file)
@@ -92,7 +92,8 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusable) {
   EXPECT_FALSE(
       InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
           .Run(module.get())
-          .ValueOrDie());
+          .ValueOrDie())
+      << module->ToString();
 }
 
 // Counts the number of HLO ops with a given op code in the specified module.
@@ -151,7 +152,11 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) {
           .Run(module.get())
           .ValueOrDie())
       << module->ToString();
-  EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1) << module->ToString();
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, op::Fusion());
+  EXPECT_THAT(root->fused_expression_root(),
+              op::Subtract(op::Abs(op::Parameter()), op::Parameter()))
+      << module->ToString();
 
   // Make sure the add hasn't been duplicated.
   EXPECT_EQ(Count(*module, HloOpcode::kAdd), 1) << module->ToString();
@@ -244,7 +249,12 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) {
           .Run(module.get())
           .ValueOrDie())
       << module->ToString();
-  EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1) << module->ToString();
+  root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, op::Fusion());
+  EXPECT_THAT(root->fused_expression_root(),
+              op::Tuple(op::Subtract(op::Parameter(), op::Parameter()),
+                        op::Subtract(op::Parameter(), op::Parameter())))
+      << module->ToString();
 
   // Make sure we didn't duplicate any adds.
   EXPECT_EQ(Count(*module, HloOpcode::kAdd), 2) << module->ToString();