EXPECT_THAT(computation->root_instruction(), op::Fusion());
}
-TEST_F(InstructionFusionTest, DotOperationFusion_Bitcast) {
+TEST_F(InstructionFusionTest, DotOperationNoFusion_Bitcast) {
HloComputation::Builder builder(TestName());
HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {2, 512, 2, 128}), "arg0"));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(dot, computation->root_instruction());
- EXPECT_TRUE(CpuInstructionFusion().Run(module.get()).ValueOrDie());
- EXPECT_THAT(computation->root_instruction(), op::Fusion());
+ EXPECT_FALSE(CpuInstructionFusion().Run(module.get()).ValueOrDie());
}
TEST_F(InstructionFusionTest, DotOperationFusion_Reshape) {
}
};
-TEST_F(OpcodeFusionTest, Exponential_Bitcast_Negate) {
+TEST_F(OpcodeFusionTest, Exponential_Reshape_Negate) {
HloComputation::Builder builder(TestName());
Shape param_shape = ShapeUtil::MakeShape(F32, {1, 4});
Shape result_shape = ShapeUtil::MakeShape(F32, {4});
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, param_shape, "param"));
- // InstructionFusion::ShouldFuse() precludes fusing a bitcast whose operand
- // is a parameter, so create an operand between the parameter and bitcast.
HloInstruction* exp1 = builder.AddInstruction(
HloInstruction::CreateUnary(param_shape, HloOpcode::kExp, param0));
- HloInstruction* bitcast2 = builder.AddInstruction(
- HloInstruction::CreateUnary(result_shape, HloOpcode::kBitcast, exp1));
+ HloInstruction* reshape2 =
+ builder.AddInstruction(HloInstruction::CreateReshape(result_shape, exp1));
builder.AddInstruction(
- HloInstruction::CreateUnary(result_shape, HloOpcode::kNegate, bitcast2));
+ HloInstruction::CreateUnary(result_shape, HloOpcode::kNegate, reshape2));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
RunFusionAndCheckOpcodesWereFused(
- module.get(), {HloOpcode::kNegate, HloOpcode::kBitcast, HloOpcode::kExp,
+ module.get(), {HloOpcode::kNegate, HloOpcode::kReshape, HloOpcode::kExp,
HloOpcode::kParameter});
}
-TEST_F(OpcodeFusionTest, Broadcast_Bitcast_DynamicSlice_Tanh) {
+TEST_F(OpcodeFusionTest, Broadcast_Reshape_DynamicSlice_Tanh) {
HloComputation::Builder builder(TestName());
Shape param_shape = ShapeUtil::MakeShape(F32, {8});
Shape starts_shape = ShapeUtil::MakeShape(F32, {2});
Shape broadcast_shape = ShapeUtil::MakeShape(F32, {1, 8, 8});
- Shape bitcast_shape = ShapeUtil::MakeShape(F32, {8, 8});
+ Shape reshape_shape = ShapeUtil::MakeShape(F32, {8, 8});
Shape dynamic_slice_shape = ShapeUtil::MakeShape(F32, {4, 4});
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, param_shape, "param"));
HloInstruction::CreateParameter(1, starts_shape, "starts"));
HloInstruction* broadcast2 = builder.AddInstruction(
HloInstruction::CreateBroadcast(broadcast_shape, param0, {1}));
- HloInstruction* bitcast3 = builder.AddInstruction(HloInstruction::CreateUnary(
- bitcast_shape, HloOpcode::kBitcast, broadcast2));
+ HloInstruction* reshape3 = builder.AddInstruction(
+ HloInstruction::CreateReshape(reshape_shape, broadcast2));
HloInstruction* dynamic_slice4 =
builder.AddInstruction(HloInstruction::CreateDynamicSlice(
- dynamic_slice_shape, bitcast3, param1, {4, 4}));
+ dynamic_slice_shape, reshape3, param1, {4, 4}));
builder.AddInstruction(HloInstruction::CreateUnary(
dynamic_slice_shape, HloOpcode::kTanh, dynamic_slice4));
RunFusionAndCheckOpcodesWereFused(
module.get(),
- {HloOpcode::kTanh, HloOpcode::kDynamicSlice, HloOpcode::kBitcast,
+ {HloOpcode::kTanh, HloOpcode::kDynamicSlice, HloOpcode::kReshape,
HloOpcode::kBroadcast, HloOpcode::kParameter, HloOpcode::kParameter});
}
EXPECT_THAT(false_result->opcode(), HloOpcode::kCopy);
}
+TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) {
+ auto builder = HloComputation::Builder(TestName());
+ auto constant0 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+ {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
+ builder.AddInstruction(HloInstruction::CreateUnary(
+ constant0->shape(), HloOpcode::kBitcast, constant0));
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
+
+ ComputationLayout computation_layout(
+ module->entry_computation()->ComputeProgramShape());
+ LayoutAssignment layout_assignment(&computation_layout);
+ Status error_status = layout_assignment.Run(module.get()).status();
+ EXPECT_FALSE(error_status.ok());
+ EXPECT_THAT(
+ error_status.error_message(),
+ ::testing::HasSubstr(
+ "Unexpected bitcast operation seen during layout assignment"));
+}
+
} // namespace
} // namespace xla