From 0c59fdb9497dba218857dbfab5616ee77fdb70b7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 14 May 2018 09:06:25 -0700 Subject: [PATCH] Pre-factoring: Fix overly specific test expectations to prepare for multi-output fusion. PiperOrigin-RevId: 196514026 --- .../compiler/xla/service/instruction_fusion_test.cc | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc index 6dd8fa1..cf9673a 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -92,7 +92,8 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusable) { EXPECT_FALSE( InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) .Run(module.get()) - .ValueOrDie()); + .ValueOrDie()) + << module->ToString(); } // Counts the number of HLO ops with a given op code in the specified module. @@ -151,7 +152,11 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { .Run(module.get()) .ValueOrDie()) << module->ToString(); - EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_THAT(root->fused_expression_root(), + op::Subtract(op::Abs(op::Parameter()), op::Parameter())) + << module->ToString(); // Make sure the add hasn't been duplicated. EXPECT_EQ(Count(*module, HloOpcode::kAdd), 1) << module->ToString(); @@ -244,7 +249,12 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { .Run(module.get()) .ValueOrDie()) << module->ToString(); - EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1) << module->ToString(); + root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_THAT(root->fused_expression_root(), + op::Tuple(op::Subtract(op::Parameter(), op::Parameter()), + op::Subtract(op::Parameter(), op::Parameter()))) + << module->ToString(); // Make sure we didn't duplicate any adds. EXPECT_EQ(Count(*module, HloOpcode::kAdd), 2) << module->ToString(); -- 2.7.4