EXPECT_FALSE(
InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
.Run(module.get())
- .ValueOrDie());
+ .ValueOrDie())
+ << module->ToString();
}
// Counts the number of HLO ops with a given op code in the specified module.
.Run(module.get())
.ValueOrDie())
<< module->ToString();
- EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1) << module->ToString();
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Fusion());
+ EXPECT_THAT(root->fused_expression_root(),
+ op::Subtract(op::Abs(op::Parameter()), op::Parameter()))
+ << module->ToString();
// Make sure the add hasn't been duplicated.
EXPECT_EQ(Count(*module, HloOpcode::kAdd), 1) << module->ToString();
.Run(module.get())
.ValueOrDie())
<< module->ToString();
- EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1) << module->ToString();
+ root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Fusion());
+ EXPECT_THAT(root->fused_expression_root(),
+ op::Tuple(op::Subtract(op::Parameter(), op::Parameter()),
+ op::Subtract(op::Parameter(), op::Parameter())))
+ << module->ToString();
// Make sure we didn't duplicate any adds.
EXPECT_EQ(Count(*module, HloOpcode::kAdd), 2) << module->ToString();