builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, zero));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
root = computation->root_instruction();
EXPECT_EQ(root, param0);
}
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, constant, param0));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
root = computation->root_instruction();
EXPECT_THAT(root, op::Add(param0, op::Constant()));
}
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, add1, constant2));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
root = computation->root_instruction();
EXPECT_THAT(root, op::Add(param0, op::Add(constant1, constant2)));
}
builder.AddInstruction(
HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
root = computation->root_instruction();
EXPECT_EQ(root, param0);
}
builder.AddInstruction(
HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
root = computation->root_instruction();
EXPECT_EQ(root, param0);
}
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kSubtract, param0, zero));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root->opcode(), HloOpcode::kSubtract);
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
root = computation->root_instruction();
EXPECT_EQ(root, param0);
}
builder.AddInstruction(HloInstruction::CreateBinary(
r0f32, HloOpcode::kSubtract, param0, constant));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root->opcode(), HloOpcode::kSubtract);
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
root = computation->root_instruction();
EXPECT_THAT(root, op::Add(param0, op::Negate(constant)));
}
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, div, param2));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
op::Divide(op::Divide(param0, param1), param2));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Divide(param0, op::Multiply(param1, param2)));
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, div));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
op::Divide(param0, op::Divide(param1, param2)));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Divide(op::Multiply(param0, param2), param1));
builder.AddInstruction(
HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, div0, div1));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
EXPECT_THAT(
computation->root_instruction(),
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
EXPECT_THAT(
computation->root_instruction(),
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, exp));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
op::Divide(param0, op::Exp(param1)));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Multiply(param0, op::Exp(op::Negate(param1))));
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, power));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
op::Divide(param0, op::Power(param1, param2)));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Multiply(param0, op::Power(param1, op::Negate(param2))));
builder.AddInstruction(
HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide, param0, power));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
op::Divide(param0, op::Power(param1, param2)));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
ASSERT_THAT(computation->root_instruction(),
op::Multiply(param0, op::Power(param1, op::Negate(param2))));
builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide,
param0, constant));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Multiply(param0, op::Divide(op::Constant(), constant)));
builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kPower,
inner_power, exp2));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Power(base, op::Multiply(exp1, exp2)));
}
// numbers.
TEST_F(AlgebraicSimplifierTest, PowerOfPowerComplex) {
Shape r0c64 = ShapeUtil::MakeShape(C64, {});
- Shape r1f32 = ShapeUtil::MakeShape(F32, {7});
+ Shape r1c64 = ShapeUtil::MakeShape(C64, {7});
HloComputation::Builder builder(TestName());
HloInstruction* base = builder.AddInstruction(
- HloInstruction::CreateParameter(0, r1f32, "param0"));
+ HloInstruction::CreateParameter(0, r1c64, "param0"));
HloInstruction* exp1 = builder.AddInstruction(
HloInstruction::CreateParameter(1, r0c64, "param1"));
HloInstruction* exp2 = builder.AddInstruction(
HloInstruction::CreateParameter(2, r0c64, "param2"));
HloInstruction* inner_power = builder.AddInstruction(
- HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, base, exp1));
- builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kPower,
+ HloInstruction::CreateBinary(r1c64, HloOpcode::kPower, base, exp1));
+ builder.AddInstruction(HloInstruction::CreateBinary(r1c64, HloOpcode::kPower,
inner_power, exp2));
- auto module = CreateNewModule();
- module->AddEntryComputation(builder.Build());
+ module().AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_FALSE(simplifier.Run(&module()).ValueOrDie());
}
// Test that A/1 is simplified to A for a scalar.
HloInstruction* div = builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, one));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root, div);
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
root = computation->root_instruction();
EXPECT_EQ(root, param0);
}
HloInstruction* div = builder.AddInstruction(
HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, one));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root, div);
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
root = computation->root_instruction();
EXPECT_EQ(root, param0);
}
HloInstruction* cplx = builder.AddInstruction(
HloInstruction::CreateBinary(r2c64, HloOpcode::kComplex, real, imag));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root, cplx);
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
root = computation->root_instruction();
EXPECT_EQ(root, param0);
}
HloInstruction* real = builder.AddInstruction(
HloInstruction::CreateUnary(r2f32, HloOpcode::kReal, cplx));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root, real);
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
root = computation->root_instruction();
EXPECT_EQ(root, param0);
}
HloInstruction* imag = builder.AddInstruction(
HloInstruction::CreateUnary(r2f32, HloOpcode::kImag, cplx));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root, imag);
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
root = computation->root_instruction();
EXPECT_EQ(root, param1);
}
HloInstruction* add = builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, get, param2));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root, add);
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
root = computation->root_instruction();
EXPECT_THAT(root, op::Add(param1, param2));
}
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, exp0, exp1));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
op::Divide(op::Exp(param0), op::Exp(param1)));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Exp(op::Subtract(param0, param1)));
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kMultiply, exp0, exp1));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
op::Multiply(op::Exp(param0), op::Exp(param1)));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Exp(op::Add(param0, param1)));
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, exp0, param1));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
op::Power(op::Exp(param0), param1));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Exp(op::Multiply(param0, param1)));
builder.AddInstruction(
HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, pow));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
op::Log(op::Power(param0, param1)));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Multiply(op::Log(param0), param1));
builder.AddInstruction(
HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, exp0));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Log(op::Exp(param0)));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
EXPECT_EQ(computation->root_instruction(), param0);
}
builder.AddInstruction(
HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, div));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
op::Log(op::Divide(op::Exp(param0), op::Exp(param1))));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), op::Subtract(param0, param1));
}
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, zero));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Power(param0, zero));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
HloInstruction* root = computation->root_instruction();
EXPECT_THAT(root, op::Constant());
builder.AddInstruction(
HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param0, zero));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Power(param0, zero));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
HloInstruction* root = computation->root_instruction();
EXPECT_THAT(root, op::Broadcast());
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, one));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Power(param0, one));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
EXPECT_EQ(computation->root_instruction(), param0);
}
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, two));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Power(param0, two));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), op::Multiply(param0, param0));
}
builder.AddInstruction(HloInstruction::CreateBinary(r0f32, HloOpcode::kPower,
param0, negative_one));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Power(param0, negative_one));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
HloInstruction* root = computation->root_instruction();
EXPECT_THAT(root, op::Divide(op::Broadcast(), param0));
dim->set_base_dilation(1);
dim->set_window_reversal(false);
// Create add computation.
- std::unique_ptr<HloModule> module = CreateNewModule();
builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeUtil::MakeShape(F32, {3, 3, 3}), lhs, rhs, window, dnums));
- module->AddEntryComputation(builder.Build());
+ module().AddEntryComputation(builder.Build());
HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- EXPECT_THAT(module->entry_computation()->root_instruction(),
+ EXPECT_THAT(module().entry_computation()->root_instruction(),
op::Convolution(lhs, rhs));
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
- EXPECT_THAT(module->entry_computation()->root_instruction(),
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+ EXPECT_THAT(module().entry_computation()->root_instruction(),
op::Broadcast(op::Constant()));
}
dim->set_base_dilation(1);
}
// Create add computation.
- std::unique_ptr<HloModule> module = CreateNewModule();
HloComputation* add_computation = nullptr;
{
HloComputation::Builder builder(TestName() + ".add");
HloInstruction::CreateParameter(1, scalar_shape, "p1"));
builder.AddInstruction(
HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
- add_computation = module->AddEmbeddedComputation(builder.Build());
+ add_computation = module().AddEmbeddedComputation(builder.Build());
}
builder.AddInstruction(HloInstruction::CreateReduceWindow(
ShapeUtil::MakeShape(F32, {5, 2}), param,
builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f))),
window, add_computation));
- module->AddEntryComputation(builder.Build());
+ module().AddEntryComputation(builder.Build());
HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- EXPECT_THAT(module->entry_computation()->root_instruction(),
+ EXPECT_THAT(module().entry_computation()->root_instruction(),
op::ReduceWindow(param, op::Constant()));
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
- EXPECT_THAT(module->entry_computation()->root_instruction(),
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+ EXPECT_THAT(module().entry_computation()->root_instruction(),
op::Broadcast(op::Constant()));
}
builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0(0.0f))),
padding));
- std::unique_ptr<HloModule> module = CreateNewModule();
- module->AddEntryComputation(builder.Build());
- EXPECT_THAT(module->entry_computation()->root_instruction(),
+ module().AddEntryComputation(builder.Build());
+ EXPECT_THAT(module().entry_computation()->root_instruction(),
op::Pad(param, op::Constant()));
HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
- EXPECT_THAT(module->entry_computation()->root_instruction(),
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+ EXPECT_THAT(module().entry_computation()->root_instruction(),
op::Broadcast(op::Constant()));
}
ShapeUtil::MakeShape(F32, {3, 2}), broadcast));
auto computation = builder.Build();
- auto module = CreateNewModule();
- module->AddEntryComputation(std::move(computation));
+ module().AddEntryComputation(std::move(computation));
- EXPECT_THAT(module->entry_computation()->root_instruction(),
+ EXPECT_THAT(module().entry_computation()->root_instruction(),
op::Reshape(op::Broadcast(op::Reshape(op))));
HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
- EXPECT_THAT(module->entry_computation()->root_instruction(), op);
+ EXPECT_THAT(module().entry_computation()->root_instruction(), op);
}
// Test that convert(A, $TYPE) is simplified to A if A is of type $TYPE.
builder.AddInstruction(
HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Convert(input));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), input);
}
builder.AddInstruction(
HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Copy(param0));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), param0);
}
builder.AddInstruction(
HloInstruction::CreateConcatenate(param0->shape(), {param0}, 0));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Concatenate(param0));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), param0);
}
builder.AddInstruction(HloInstruction::CreateConcatenate(
result_shape, {empty_literal, param0, param0, empty_slice, param1}, 0));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
EXPECT_THAT(
computation->root_instruction(),
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Concatenate(param0, param0, param1));
builder.AddInstruction(HloInstruction::CreateConcatenate(
result_shape, {empty_literal, empty_slice}, 0));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
op::Concatenate(empty_literal, empty_slice));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
EXPECT_EQ(computation->root_instruction(), empty_literal);
}
HloInstruction* broadcast = builder.AddInstruction(
HloInstruction::CreateBroadcast(r1f32, param1, {}));
builder.AddInstruction(HloInstruction::CreateConcatenate(
- param0->shape(), {broadcast, param0}, 0));
+ ShapeUtil::MakeShape(F32, {200}), {broadcast, param0}, 0));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), op::Pad(param0, param1));
}
HloInstruction* copy = builder.AddInstruction(
HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
// Set to different layouts.
*param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
non_bitcasting_callback());
- EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie());
// Copy has not been removed.
EXPECT_THAT(computation->root_instruction(), op::Copy(param0));
HloInstruction* copy = builder.AddInstruction(
HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
// Set to same layouts.
*param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
// Copy has been removed.
EXPECT_THAT(computation->root_instruction(), param0);
*reshape->mutable_shape()->mutable_layout() =
LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5});
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Reshape(param0));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
non_bitcasting_callback());
- EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie());
// Reshape is not replaced with a bitcast.
EXPECT_THAT(computation->root_instruction(), op::Reshape(param0));
builder.AddInstruction(HloInstruction::CreateTuple(
{transformable_reshape, dimensions_wrong_reshape, layout_wrong_reshape}));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
op::Tuple(transformable_reshape, dimensions_wrong_reshape,
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
bitcasting_callback());
- simplifier.Run(module.get()).ValueOrDie();
+ simplifier.Run(&module()).ValueOrDie();
// Verify that only the first reshape is replaced.
EXPECT_THAT(
builder.AddInstruction(
HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {1, 2, 3, 4, 5}),
HloOpcode::kMaximum, movable_reshape, zero));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
op::Maximum(op::Reshape(param), zero));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
bitcasting_callback());
- simplifier.Run(module.get()).ValueOrDie();
+ simplifier.Run(&module()).ValueOrDie();
EXPECT_THAT(computation->root_instruction(),
op::Reshape(op::Maximum(param, zero)));
}
HloInstruction::CreateConstant(Literal::CreateR1<float>({1., 2., 3.})));
builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(F32, {3}), HloOpcode::kMaximum, reshape, zero));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
op::Maximum(op::Reshape(param), zero));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
bitcasting_callback());
- simplifier.Run(module.get()).ValueOrDie();
+ simplifier.Run(&module()).ValueOrDie();
EXPECT_THAT(computation->root_instruction(),
op::Maximum(op::Reshape(param), zero));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
bitcasting_callback());
- auto module = CreateNewModule();
- module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ module().AddEntryComputation(builder.Build());
+ EXPECT_TRUE(simplifier.Run(&module()).ValueOrDie());
}
// Regression test for a bug where if we failed to sink a reshape, we'd set the
builder.AddInstruction(HloInstruction::CreateConstant(
Literal::CreateR2<float>({{0, 0}, {0, 0}})))));
- builder.AddInstruction(HloInstruction::CreateBroadcast(
- ShapeUtil::MakeShape(F32, {2, 2, 2}), add, /*broadcast_dimensions=*/{0}));
+ builder.AddInstruction(
+ HloInstruction::CreateBroadcast(ShapeUtil::MakeShape(F32, {2, 2, 2}), add,
+ /*broadcast_dimensions=*/{0, 1}));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
bitcasting_callback());
- auto module = CreateNewModule();
- module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ module().AddEntryComputation(builder.Build());
+ EXPECT_TRUE(simplifier.Run(&module()).ValueOrDie());
}
TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) {
*transpose->mutable_shape()->mutable_layout() =
LayoutUtil::MakeLayout({0, 1, 2, 3});
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Transpose(param));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
// Verify that the reshape is replaced.
EXPECT_THAT(computation->root_instruction(), op::Bitcast(param));
*transpose->mutable_shape()->mutable_layout() =
LayoutUtil::MakeLayout({3, 1, 2, 0});
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Transpose(param));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
// Verify that the reshape is replaced.
EXPECT_THAT(computation->root_instruction(), op::Bitcast(param));
builder.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), reshape1));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
op::Reshape(op::Reshape(param0)));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), op::Reshape(param0));
}
ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 2, 1}),
HloOpcode::kCopy, copy1));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Copy(op::Copy(param0)));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), op::Copy(param0));
}
builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(F32, {4, 3, 2}), transpose1, {1, 0, 2}));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Transpose(transpose1));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), op::Transpose(param0));
EXPECT_EQ(std::vector<int64>({2, 1, 0}),
auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(F32, {1, 5, 1}), param0));
builder.AddInstruction(HloInstruction::CreateBroadcast(
- ShapeUtil::MakeShape(F32, {1, 2, 3, 5, 1}), reshape1, {0, 2, 3}));
+ ShapeUtil::MakeShape(F32, {1, 2, 3, 5, 1}), reshape1, {0, 3, 2}));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
op::Broadcast(op::Reshape(param0)));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), op::Broadcast(param0));
}
builder.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2}), broadcast1));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
op::Reshape(op::Broadcast(param0)));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), op::Broadcast(param0));
}
builder.AddInstruction(
HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3}), broadcast));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
op::Reshape(op::Broadcast(param)));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Reshape(op::Broadcast(param)));
builder.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(F32, {6, 1, 1, 4}), broadcast));
- auto module = CreateNewModule();
- HloComputation* computation = module->AddEntryComputation(builder.Build());
+ HloComputation* computation = module().AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
op::Reshape(op::Broadcast(param)));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), op::Broadcast(param));
EXPECT_THAT(computation->root_instruction()->dimensions(),
builder.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(F32, {6, 1, 1, 1}), broadcast));
- auto module = CreateNewModule();
- HloComputation* computation = module->AddEntryComputation(builder.Build());
+ HloComputation* computation = module().AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
op::Reshape(op::Broadcast(param)));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), op::Broadcast(param));
const std::vector<int64> broadcast_dims =
builder.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(F32, {6, 8}), broadcast));
- auto module = CreateNewModule();
- HloComputation* computation = module->AddEntryComputation(builder.Build());
+ HloComputation* computation = module().AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
op::Reshape(op::Broadcast(param)));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Reshape(op::Broadcast(param)));
call_builder.AddInstruction(
HloInstruction::CreateCall(r1f32, {zero, one}, dot_computation.get()));
- auto module = CreateNewModule();
- module->AddEmbeddedComputation(std::move(dot_computation));
- module->AddEntryComputation(call_builder.Build());
+ module().AddEmbeddedComputation(std::move(dot_computation));
+ module().AddEntryComputation(call_builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
}
// Test that a constant with tuple shape becomes a tuple of constants.
Literal::CreateR1<float>(constant_vector).get()});
builder.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Tuple(op::Constant(), op::Constant()));
}
HloInstruction::CreateConstant(Literal::CreateR1<int>({0, 0, 0}))),
/*slice_sizes=*/{10, 100, 1000}));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), op::Parameter());
}
builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR1<int>({0, 0, 0})))));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::DynamicSlice(op::Parameter(), op::Parameter()));
}
PaddingConfig padding = window_util::MakeSymmetricPadding(
decorate_spatials(param.symmetric_pad_spatials, 0, 0));
+ TF_ASSERT_OK_AND_ASSIGN(
+ const Shape pad_shape,
+ ShapeInference::InferPadShape(input->shape(),
+ ShapeUtil::MakeShape(F32, {}), padding));
HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
- ShapeUtil::MakeShape(
- F32, decorate_spatials(param.reduce_window_spatials, 128, 2048)),
- input,
+ pad_shape, input,
builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0(0.0f))),
padding));
- std::unique_ptr<HloModule> module = CreateNewModule();
HloComputation* add_computation = nullptr;
{
HloComputation::Builder builder(TestName() + ".add");
HloInstruction::CreateParameter(1, scalar_shape, "p1"));
builder.AddInstruction(
HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
- add_computation = module->AddEmbeddedComputation(builder.Build());
+ add_computation = module().AddEmbeddedComputation(builder.Build());
}
- TF_ASSERT_OK_AND_ASSIGN(
- const Shape output_shape,
- ShapeInference::InferPadShape(input_shape, ShapeUtil::MakeShape(F32, {}),
- padding));
Window window = window_util::MakeWindow(
decorate_spatials(param.reduce_window_spatials, 1, 1));
auto zero = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ TF_ASSERT_OK_AND_ASSIGN(const Shape output_shape,
+ ShapeInference::InferReduceWindowShape(
+ pad->shape(), zero->shape(), window,
+ add_computation->ComputeProgramShape()));
builder.AddInstruction(HloInstruction::CreateReduceWindow(
output_shape, pad, zero, window, add_computation));
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module()));
ASSERT_TRUE(run_successful);
EXPECT_TRUE(
dot_dnums.add_rhs_contracting_dimensions(0);
builder.AddInstruction(
HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(&module()));
const bool dot_should_be_transformed = m == 1 || k == 1 || n == 1;
const bool computation_should_be_modified =
dot_should_be_transformed || (transpose_lhs && transpose_rhs);
};
class DotOfConcatSimplificationTest
- : public HloTestBase,
+ : public HloVerifiedTestBase,
public ::testing::WithParamInterface<DotOfConcatTestSpec> {};
// Test that we transform
builder.AddInstruction(
HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module()));
ASSERT_TRUE(run_successful);
EXPECT_TRUE(
HloInstruction* lhs2 = builder.AddInstruction(
HloInstruction::CreateParameter(2, lhs2_shape, "lhs2"));
HloInstruction* lhs3 = builder.AddInstruction(
- HloInstruction::CreateParameter(3, lhs2_shape, "lhs3"));
+ HloInstruction::CreateParameter(3, lhs3_shape, "lhs3"));
Shape lhs_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.k});
HloInstruction* lhs =
builder.AddInstruction(HloInstruction::CreateConcatenate(
lhs_shape, {lhs0, lhs1, lhs2, lhs3}, 1));
- Shape rhs_shape = ShapeUtil::MakeShape(F32, {spec.k, spec.m});
+ Shape rhs_shape = ShapeUtil::MakeShape(F32, {spec.k, spec.n});
auto* rhs = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
- /*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.k, /*cols=*/spec.m)));
+ /*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.k, /*cols=*/spec.n)));
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
builder.AddInstruction(
HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
+ auto computation = module().AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module()));
ASSERT_TRUE(run_successful);
EXPECT_TRUE(
ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape));