[XLA] Use HloVerifiedTestBase in AlgebraicSimplifierTest
authorSanjoy Das <sanjoy@google.com>
Wed, 7 Feb 2018 07:54:26 +0000 (23:54 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 7 Feb 2018 07:58:44 +0000 (23:58 -0800)
And fix the fallout.  Thanks to asbirlea@ for noticing this!

PiperOrigin-RevId: 184796949

tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
tensorflow/compiler/xla/window_util.cc

index e43ea50..0f08eb3 100644 (file)
@@ -61,13 +61,12 @@ TEST_F(AlgebraicSimplifierTest, AddZero) {
   builder.AddInstruction(
       HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, zero));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
   HloInstruction* root = computation->root_instruction();
   EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
   root = computation->root_instruction();
   EXPECT_EQ(root, param0);
 }
@@ -83,13 +82,12 @@ TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) {
   builder.AddInstruction(
       HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, constant, param0));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
   HloInstruction* root = computation->root_instruction();
   EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
   root = computation->root_instruction();
   EXPECT_THAT(root, op::Add(param0, op::Constant()));
 }
@@ -110,13 +108,12 @@ TEST_F(AlgebraicSimplifierTest, AddReassociateMergeConstants) {
   builder.AddInstruction(
       HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, add1, constant2));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
   HloInstruction* root = computation->root_instruction();
   EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
   root = computation->root_instruction();
   EXPECT_THAT(root, op::Add(param0, op::Add(constant1, constant2)));
 }
@@ -133,13 +130,12 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) {
   builder.AddInstruction(
       HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
   HloInstruction* root = computation->root_instruction();
   EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
   root = computation->root_instruction();
   EXPECT_EQ(root, param0);
 }
@@ -156,13 +152,12 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) {
   builder.AddInstruction(
       HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
   HloInstruction* root = computation->root_instruction();
   EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
   root = computation->root_instruction();
   EXPECT_EQ(root, param0);
 }
@@ -178,13 +173,12 @@ TEST_F(AlgebraicSimplifierTest, SubZero) {
   builder.AddInstruction(
       HloInstruction::CreateBinary(r0f32, HloOpcode::kSubtract, param0, zero));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
   HloInstruction* root = computation->root_instruction();
   EXPECT_EQ(root->opcode(), HloOpcode::kSubtract);
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
   root = computation->root_instruction();
   EXPECT_EQ(root, param0);
 }
@@ -200,13 +194,12 @@ TEST_F(AlgebraicSimplifierTest, SubConstCanonicalization) {
   builder.AddInstruction(HloInstruction::CreateBinary(
       r0f32, HloOpcode::kSubtract, param0, constant));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
   HloInstruction* root = computation->root_instruction();
   EXPECT_EQ(root->opcode(), HloOpcode::kSubtract);
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
   root = computation->root_instruction();
   EXPECT_THAT(root, op::Add(param0, op::Negate(constant)));
 }
@@ -226,15 +219,14 @@ TEST_F(AlgebraicSimplifierTest, LhsDivOfDiv) {
   builder.AddInstruction(
       HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, div, param2));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
 
   EXPECT_THAT(computation->root_instruction(),
               op::Divide(op::Divide(param0, param1), param2));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
 
   EXPECT_THAT(computation->root_instruction(),
               op::Divide(param0, op::Multiply(param1, param2)));
@@ -255,15 +247,14 @@ TEST_F(AlgebraicSimplifierTest, RhsDivOfDiv) {
   builder.AddInstruction(
       HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, div));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
 
   EXPECT_THAT(computation->root_instruction(),
               op::Divide(param0, op::Divide(param1, param2)));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
 
   EXPECT_THAT(computation->root_instruction(),
               op::Divide(op::Multiply(param0, param2), param1));
@@ -289,8 +280,7 @@ TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) {
   builder.AddInstruction(
       HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, div0, div1));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
 
   EXPECT_THAT(
       computation->root_instruction(),
@@ -298,7 +288,7 @@ TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) {
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
 
   EXPECT_THAT(
       computation->root_instruction(),
@@ -320,15 +310,14 @@ TEST_F(AlgebraicSimplifierTest, DivOfExp) {
   builder.AddInstruction(
       HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, exp));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
 
   EXPECT_THAT(computation->root_instruction(),
               op::Divide(param0, op::Exp(param1)));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
 
   EXPECT_THAT(computation->root_instruction(),
               op::Multiply(param0, op::Exp(op::Negate(param1))));
@@ -349,15 +338,14 @@ TEST_F(AlgebraicSimplifierTest, DivOfPower) {
   builder.AddInstruction(
       HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, power));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
 
   EXPECT_THAT(computation->root_instruction(),
               op::Divide(param0, op::Power(param1, param2)));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
 
   EXPECT_THAT(computation->root_instruction(),
               op::Multiply(param0, op::Power(param1, op::Negate(param2))));
@@ -380,15 +368,14 @@ TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) {
   builder.AddInstruction(
       HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide, param0, power));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
 
   EXPECT_THAT(computation->root_instruction(),
               op::Divide(param0, op::Power(param1, param2)));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
 
   ASSERT_THAT(computation->root_instruction(),
               op::Multiply(param0, op::Power(param1, op::Negate(param2))));
@@ -411,12 +398,11 @@ TEST_F(AlgebraicSimplifierTest, DivideByConstant) {
   builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide,
                                                       param0, constant));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
 
   EXPECT_THAT(computation->root_instruction(),
               op::Multiply(param0, op::Divide(op::Constant(), constant)));
@@ -438,11 +424,10 @@ TEST_F(AlgebraicSimplifierTest, PowerOfPower) {
   builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kPower,
                                                       inner_power, exp2));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
   EXPECT_THAT(computation->root_instruction(),
               op::Power(base, op::Multiply(exp1, exp2)));
 }
@@ -451,24 +436,23 @@ TEST_F(AlgebraicSimplifierTest, PowerOfPower) {
 // numbers.
 TEST_F(AlgebraicSimplifierTest, PowerOfPowerComplex) {
   Shape r0c64 = ShapeUtil::MakeShape(C64, {});
-  Shape r1f32 = ShapeUtil::MakeShape(F32, {7});
+  Shape r1c64 = ShapeUtil::MakeShape(C64, {7});
   HloComputation::Builder builder(TestName());
   HloInstruction* base = builder.AddInstruction(
-      HloInstruction::CreateParameter(0, r1f32, "param0"));
+      HloInstruction::CreateParameter(0, r1c64, "param0"));
   HloInstruction* exp1 = builder.AddInstruction(
       HloInstruction::CreateParameter(1, r0c64, "param1"));
   HloInstruction* exp2 = builder.AddInstruction(
       HloInstruction::CreateParameter(2, r0c64, "param2"));
   HloInstruction* inner_power = builder.AddInstruction(
-      HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, base, exp1));
-  builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kPower,
+      HloInstruction::CreateBinary(r1c64, HloOpcode::kPower, base, exp1));
+  builder.AddInstruction(HloInstruction::CreateBinary(r1c64, HloOpcode::kPower,
                                                       inner_power, exp2));
 
-  auto module = CreateNewModule();
-  module->AddEntryComputation(builder.Build());
+  module().AddEntryComputation(builder.Build());
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_FALSE(simplifier.Run(&module()).ValueOrDie());
 }
 
 // Test that A/1 is simplified to A for a scalar.
@@ -482,13 +466,12 @@ TEST_F(AlgebraicSimplifierTest, DivOneScalar) {
   HloInstruction* div = builder.AddInstruction(
       HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, one));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
   HloInstruction* root = computation->root_instruction();
   EXPECT_EQ(root, div);
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
   root = computation->root_instruction();
   EXPECT_EQ(root, param0);
 }
@@ -504,13 +487,12 @@ TEST_F(AlgebraicSimplifierTest, DivOneArray) {
   HloInstruction* div = builder.AddInstruction(
       HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, one));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
   HloInstruction* root = computation->root_instruction();
   EXPECT_EQ(root, div);
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
   root = computation->root_instruction();
   EXPECT_EQ(root, param0);
 }
@@ -529,13 +511,12 @@ TEST_F(AlgebraicSimplifierTest, ComplexOfRealImagC) {
   HloInstruction* cplx = builder.AddInstruction(
       HloInstruction::CreateBinary(r2c64, HloOpcode::kComplex, real, imag));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
   HloInstruction* root = computation->root_instruction();
   EXPECT_EQ(root, cplx);
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
   root = computation->root_instruction();
   EXPECT_EQ(root, param0);
 }
@@ -554,13 +535,12 @@ TEST_F(AlgebraicSimplifierTest, RealOfComplex) {
   HloInstruction* real = builder.AddInstruction(
       HloInstruction::CreateUnary(r2f32, HloOpcode::kReal, cplx));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
   HloInstruction* root = computation->root_instruction();
   EXPECT_EQ(root, real);
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
   root = computation->root_instruction();
   EXPECT_EQ(root, param0);
 }
@@ -579,13 +559,12 @@ TEST_F(AlgebraicSimplifierTest, ImagOfComplex) {
   HloInstruction* imag = builder.AddInstruction(
       HloInstruction::CreateUnary(r2f32, HloOpcode::kImag, cplx));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
   HloInstruction* root = computation->root_instruction();
   EXPECT_EQ(root, imag);
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
   root = computation->root_instruction();
   EXPECT_EQ(root, param1);
 }
@@ -607,13 +586,12 @@ TEST_F(AlgebraicSimplifierTest, SelectMakeTuple) {
   HloInstruction* add = builder.AddInstruction(
       HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, get, param2));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
   HloInstruction* root = computation->root_instruction();
   EXPECT_EQ(root, add);
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
   root = computation->root_instruction();
   EXPECT_THAT(root, op::Add(param1, param2));
 }
@@ -633,15 +611,14 @@ TEST_F(AlgebraicSimplifierTest, ExpDiv) {
   builder.AddInstruction(
       HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, exp0, exp1));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
 
   EXPECT_THAT(computation->root_instruction(),
               op::Divide(op::Exp(param0), op::Exp(param1)));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
 
   EXPECT_THAT(computation->root_instruction(),
               op::Exp(op::Subtract(param0, param1)));
@@ -662,15 +639,14 @@ TEST_F(AlgebraicSimplifierTest, ExpMul) {
   builder.AddInstruction(
       HloInstruction::CreateBinary(r0f32, HloOpcode::kMultiply, exp0, exp1));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
 
   EXPECT_THAT(computation->root_instruction(),
               op::Multiply(op::Exp(param0), op::Exp(param1)));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
 
   EXPECT_THAT(computation->root_instruction(),
               op::Exp(op::Add(param0, param1)));
@@ -689,15 +665,14 @@ TEST_F(AlgebraicSimplifierTest, PowExp) {
   builder.AddInstruction(
       HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, exp0, param1));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
 
   EXPECT_THAT(computation->root_instruction(),
               op::Power(op::Exp(param0), param1));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
 
   EXPECT_THAT(computation->root_instruction(),
               op::Exp(op::Multiply(param0, param1)));
@@ -716,15 +691,14 @@ TEST_F(AlgebraicSimplifierTest, LnPow) {
   builder.AddInstruction(
       HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, pow));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
 
   EXPECT_THAT(computation->root_instruction(),
               op::Log(op::Power(param0, param1)));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
 
   EXPECT_THAT(computation->root_instruction(),
               op::Multiply(op::Log(param0), param1));
@@ -741,14 +715,13 @@ TEST_F(AlgebraicSimplifierTest, LnExp) {
   builder.AddInstruction(
       HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, exp0));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
 
   EXPECT_THAT(computation->root_instruction(), op::Log(op::Exp(param0)));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
 
   EXPECT_EQ(computation->root_instruction(), param0);
 }
@@ -770,15 +743,14 @@ TEST_F(AlgebraicSimplifierTest, LnExpDiv) {
   builder.AddInstruction(
       HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, div));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
 
   EXPECT_THAT(computation->root_instruction(),
               op::Log(op::Divide(op::Exp(param0), op::Exp(param1))));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
 
   EXPECT_THAT(computation->root_instruction(), op::Subtract(param0, param1));
 }
@@ -795,14 +767,13 @@ TEST_F(AlgebraicSimplifierTest, Pow0Scalar) {
   builder.AddInstruction(
       HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, zero));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
 
   EXPECT_THAT(computation->root_instruction(), op::Power(param0, zero));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
 
   HloInstruction* root = computation->root_instruction();
   EXPECT_THAT(root, op::Constant());
@@ -820,14 +791,13 @@ TEST_F(AlgebraicSimplifierTest, Pow0Vector) {
   builder.AddInstruction(
       HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param0, zero));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
 
   EXPECT_THAT(computation->root_instruction(), op::Power(param0, zero));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
 
   HloInstruction* root = computation->root_instruction();
   EXPECT_THAT(root, op::Broadcast());
@@ -849,14 +819,13 @@ TEST_F(AlgebraicSimplifierTest, Pow1) {
   builder.AddInstruction(
       HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, one));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
 
   EXPECT_THAT(computation->root_instruction(), op::Power(param0, one));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
 
   EXPECT_EQ(computation->root_instruction(), param0);
 }
@@ -872,14 +841,13 @@ TEST_F(AlgebraicSimplifierTest, Pow2) {
   builder.AddInstruction(
       HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, two));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
 
   EXPECT_THAT(computation->root_instruction(), op::Power(param0, two));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
 
   EXPECT_THAT(computation->root_instruction(), op::Multiply(param0, param0));
 }
@@ -895,14 +863,13 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) {
   builder.AddInstruction(HloInstruction::CreateBinary(r0f32, HloOpcode::kPower,
                                                       param0, negative_one));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
 
   EXPECT_THAT(computation->root_instruction(), op::Power(param0, negative_one));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
 
   HloInstruction* root = computation->root_instruction();
   EXPECT_THAT(root, op::Divide(op::Broadcast(), param0));
@@ -941,16 +908,15 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) {
   dim->set_base_dilation(1);
   dim->set_window_reversal(false);
   // Create add computation.
-  std::unique_ptr<HloModule> module = CreateNewModule();
   builder.AddInstruction(HloInstruction::CreateConvolve(
       ShapeUtil::MakeShape(F32, {3, 3, 3}), lhs, rhs, window, dnums));
-  module->AddEntryComputation(builder.Build());
+  module().AddEntryComputation(builder.Build());
   HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false,
                                              non_bitcasting_callback());
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
+  EXPECT_THAT(module().entry_computation()->root_instruction(),
               op::Convolution(lhs, rhs));
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+  EXPECT_THAT(module().entry_computation()->root_instruction(),
               op::Broadcast(op::Constant()));
 }
 
@@ -969,7 +935,6 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) {
     dim->set_base_dilation(1);
   }
   // Create add computation.
-  std::unique_ptr<HloModule> module = CreateNewModule();
   HloComputation* add_computation = nullptr;
   {
     HloComputation::Builder builder(TestName() + ".add");
@@ -980,20 +945,20 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) {
         HloInstruction::CreateParameter(1, scalar_shape, "p1"));
     builder.AddInstruction(
         HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
-    add_computation = module->AddEmbeddedComputation(builder.Build());
+    add_computation = module().AddEmbeddedComputation(builder.Build());
   }
   builder.AddInstruction(HloInstruction::CreateReduceWindow(
       ShapeUtil::MakeShape(F32, {5, 2}), param,
       builder.AddInstruction(
           HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f))),
       window, add_computation));
-  module->AddEntryComputation(builder.Build());
+  module().AddEntryComputation(builder.Build());
   HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false,
                                              non_bitcasting_callback());
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
+  EXPECT_THAT(module().entry_computation()->root_instruction(),
               op::ReduceWindow(param, op::Constant()));
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+  EXPECT_THAT(module().entry_computation()->root_instruction(),
               op::Broadcast(op::Constant()));
 }
 
@@ -1014,14 +979,13 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedPad) {
       builder.AddInstruction(
           HloInstruction::CreateConstant(Literal::CreateR0(0.0f))),
       padding));
-  std::unique_ptr<HloModule> module = CreateNewModule();
-  module->AddEntryComputation(builder.Build());
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
+  module().AddEntryComputation(builder.Build());
+  EXPECT_THAT(module().entry_computation()->root_instruction(),
               op::Pad(param, op::Constant()));
   HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false,
                                              non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+  EXPECT_THAT(module().entry_computation()->root_instruction(),
               op::Broadcast(op::Constant()));
 }
 
@@ -1039,17 +1003,16 @@ TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) {
       ShapeUtil::MakeShape(F32, {3, 2}), broadcast));
 
   auto computation = builder.Build();
-  auto module = CreateNewModule();
-  module->AddEntryComputation(std::move(computation));
+  module().AddEntryComputation(std::move(computation));
 
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
+  EXPECT_THAT(module().entry_computation()->root_instruction(),
               op::Reshape(op::Broadcast(op::Reshape(op))));
 
   HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false,
                                              non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
 
-  EXPECT_THAT(module->entry_computation()->root_instruction(), op);
+  EXPECT_THAT(module().entry_computation()->root_instruction(), op);
 }
 
 // Test that convert(A, $TYPE) is simplified to A if A is of type $TYPE.
@@ -1060,14 +1023,13 @@ TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) {
   builder.AddInstruction(
       HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
 
   EXPECT_THAT(computation->root_instruction(), op::Convert(input));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
 
   EXPECT_THAT(computation->root_instruction(), input);
 }
@@ -1081,14 +1043,13 @@ TEST_F(AlgebraicSimplifierTest, RemoveCopy) {
   builder.AddInstruction(
       HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
 
   EXPECT_THAT(computation->root_instruction(), op::Copy(param0));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
 
   EXPECT_THAT(computation->root_instruction(), param0);
 }
@@ -1102,14 +1063,13 @@ TEST_F(AlgebraicSimplifierTest, RemoveUnaryConcatenate) {
   builder.AddInstruction(
       HloInstruction::CreateConcatenate(param0->shape(), {param0}, 0));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
 
   EXPECT_THAT(computation->root_instruction(), op::Concatenate(param0));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
 
   EXPECT_THAT(computation->root_instruction(), param0);
 }
@@ -1132,8 +1092,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) {
   builder.AddInstruction(HloInstruction::CreateConcatenate(
       result_shape, {empty_literal, param0, param0, empty_slice, param1}, 0));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
 
   EXPECT_THAT(
       computation->root_instruction(),
@@ -1141,7 +1100,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) {
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
 
   EXPECT_THAT(computation->root_instruction(),
               op::Concatenate(param0, param0, param1));
@@ -1163,15 +1122,14 @@ TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) {
   builder.AddInstruction(HloInstruction::CreateConcatenate(
       result_shape, {empty_literal, empty_slice}, 0));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
 
   EXPECT_THAT(computation->root_instruction(),
               op::Concatenate(empty_literal, empty_slice));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
 
   EXPECT_EQ(computation->root_instruction(), empty_literal);
 }
@@ -1188,14 +1146,13 @@ TEST_F(AlgebraicSimplifierTest, ConcatenateOfBroadcastBecomesPad) {
   HloInstruction* broadcast = builder.AddInstruction(
       HloInstruction::CreateBroadcast(r1f32, param1, {}));
   builder.AddInstruction(HloInstruction::CreateConcatenate(
-      param0->shape(), {broadcast, param0}, 0));
+      ShapeUtil::MakeShape(F32, {200}), {broadcast, param0}, 0));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
   EXPECT_THAT(computation->root_instruction(), op::Pad(param0, param1));
 }
 
@@ -1209,8 +1166,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) {
   HloInstruction* copy = builder.AddInstruction(
       HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
 
   // Set to different layouts.
   *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
@@ -1220,7 +1176,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) {
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
                                  non_bitcasting_callback());
-  EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
+  EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie());
 
   // Copy has not been removed.
   EXPECT_THAT(computation->root_instruction(), op::Copy(param0));
@@ -1236,8 +1192,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) {
   HloInstruction* copy = builder.AddInstruction(
       HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
 
   // Set to same layouts.
   *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
@@ -1247,7 +1202,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) {
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
 
   // Copy has been removed.
   EXPECT_THAT(computation->root_instruction(), param0);
@@ -1268,14 +1223,13 @@ TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) {
   *reshape->mutable_shape()->mutable_layout() =
       LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5});
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
 
   EXPECT_THAT(computation->root_instruction(), op::Reshape(param0));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
                                  non_bitcasting_callback());
-  EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
+  EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie());
 
   // Reshape is not replaced with a bitcast.
   EXPECT_THAT(computation->root_instruction(), op::Reshape(param0));
@@ -1314,8 +1268,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) {
   builder.AddInstruction(HloInstruction::CreateTuple(
       {transformable_reshape, dimensions_wrong_reshape, layout_wrong_reshape}));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
 
   EXPECT_THAT(computation->root_instruction(),
               op::Tuple(transformable_reshape, dimensions_wrong_reshape,
@@ -1323,7 +1276,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) {
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
                                  bitcasting_callback());
-  simplifier.Run(module.get()).ValueOrDie();
+  simplifier.Run(&module()).ValueOrDie();
 
   // Verify that only the first reshape is replaced.
   EXPECT_THAT(
@@ -1344,8 +1297,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeAfterEffectiveUnary) {
   builder.AddInstruction(
       HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {1, 2, 3, 4, 5}),
                                    HloOpcode::kMaximum, movable_reshape, zero));
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
 
   EXPECT_THAT(computation->root_instruction(),
               op::Maximum(op::Reshape(param), zero));
@@ -1353,7 +1305,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeAfterEffectiveUnary) {
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  bitcasting_callback());
 
-  simplifier.Run(module.get()).ValueOrDie();
+  simplifier.Run(&module()).ValueOrDie();
   EXPECT_THAT(computation->root_instruction(),
               op::Reshape(op::Maximum(param, zero)));
 }
@@ -1371,8 +1323,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeToScalarNotHoistedAfterEffectiveUnary) {
       HloInstruction::CreateConstant(Literal::CreateR1<float>({1., 2., 3.})));
   builder.AddInstruction(HloInstruction::CreateBinary(
       ShapeUtil::MakeShape(F32, {3}), HloOpcode::kMaximum, reshape, zero));
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
 
   EXPECT_THAT(computation->root_instruction(),
               op::Maximum(op::Reshape(param), zero));
@@ -1380,7 +1331,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeToScalarNotHoistedAfterEffectiveUnary) {
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  bitcasting_callback());
 
-  simplifier.Run(module.get()).ValueOrDie();
+  simplifier.Run(&module()).ValueOrDie();
 
   EXPECT_THAT(computation->root_instruction(),
               op::Maximum(op::Reshape(param), zero));
@@ -1405,9 +1356,8 @@ TEST_F(AlgebraicSimplifierTest, FailureToSinkReshapeDoesntAffectChangedBit) {
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  bitcasting_callback());
-  auto module = CreateNewModule();
-  module->AddEntryComputation(builder.Build());
-  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  module().AddEntryComputation(builder.Build());
+  EXPECT_TRUE(simplifier.Run(&module()).ValueOrDie());
 }
 
 // Regression test for a bug where if we failed to sink a reshape, we'd set the
@@ -1424,14 +1374,14 @@ TEST_F(AlgebraicSimplifierTest, FailureToSinkBroadcastDoesntAffectChangedBit) {
       builder.AddInstruction(HloInstruction::CreateConstant(
           Literal::CreateR2<float>({{0, 0}, {0, 0}})))));
 
-  builder.AddInstruction(HloInstruction::CreateBroadcast(
-      ShapeUtil::MakeShape(F32, {2, 2, 2}), add, /*broadcast_dimensions=*/{0}));
+  builder.AddInstruction(
+      HloInstruction::CreateBroadcast(ShapeUtil::MakeShape(F32, {2, 2, 2}), add,
+                                      /*broadcast_dimensions=*/{0, 1}));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  bitcasting_callback());
-  auto module = CreateNewModule();
-  module->AddEntryComputation(builder.Build());
-  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  module().AddEntryComputation(builder.Build());
+  EXPECT_TRUE(simplifier.Run(&module()).ValueOrDie());
 }
 
 TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) {
@@ -1448,14 +1398,13 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) {
   *transpose->mutable_shape()->mutable_layout() =
       LayoutUtil::MakeLayout({0, 1, 2, 3});
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
 
   EXPECT_THAT(computation->root_instruction(), op::Transpose(param));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
                                  bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
 
   // Verify that the reshape is replaced.
   EXPECT_THAT(computation->root_instruction(), op::Bitcast(param));
@@ -1475,14 +1424,13 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) {
   *transpose->mutable_shape()->mutable_layout() =
       LayoutUtil::MakeLayout({3, 1, 2, 0});
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
 
   EXPECT_THAT(computation->root_instruction(), op::Transpose(param));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
                                  bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
 
   // Verify that the reshape is replaced.
   EXPECT_THAT(computation->root_instruction(), op::Bitcast(param));
@@ -1501,15 +1449,14 @@ TEST_F(AlgebraicSimplifierTest, ReshapesMerged) {
   builder.AddInstruction(HloInstruction::CreateReshape(
       ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), reshape1));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
 
   EXPECT_THAT(computation->root_instruction(),
               op::Reshape(op::Reshape(param0)));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
 
   EXPECT_THAT(computation->root_instruction(), op::Reshape(param0));
 }
@@ -1529,14 +1476,13 @@ TEST_F(AlgebraicSimplifierTest, CopiesMerged) {
       ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 2, 1}),
       HloOpcode::kCopy, copy1));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
 
   EXPECT_THAT(computation->root_instruction(), op::Copy(op::Copy(param0)));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
 
   EXPECT_THAT(computation->root_instruction(), op::Copy(param0));
 }
@@ -1554,14 +1500,13 @@ TEST_F(AlgebraicSimplifierTest, TransposesMerged) {
   builder.AddInstruction(HloInstruction::CreateTranspose(
       ShapeUtil::MakeShape(F32, {4, 3, 2}), transpose1, {1, 0, 2}));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
 
   EXPECT_THAT(computation->root_instruction(), op::Transpose(transpose1));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
 
   EXPECT_THAT(computation->root_instruction(), op::Transpose(param0));
   EXPECT_EQ(std::vector<int64>({2, 1, 0}),
@@ -1576,17 +1521,16 @@ TEST_F(AlgebraicSimplifierTest, ReshapeAndBroadcastMerged) {
   auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(
       ShapeUtil::MakeShape(F32, {1, 5, 1}), param0));
   builder.AddInstruction(HloInstruction::CreateBroadcast(
-      ShapeUtil::MakeShape(F32, {1, 2, 3, 5, 1}), reshape1, {0, 2, 3}));
+      ShapeUtil::MakeShape(F32, {1, 2, 3, 5, 1}), reshape1, {0, 3, 2}));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
 
   EXPECT_THAT(computation->root_instruction(),
               op::Broadcast(op::Reshape(param0)));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
 
   EXPECT_THAT(computation->root_instruction(), op::Broadcast(param0));
 }
@@ -1601,15 +1545,14 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshapeMerged) {
   builder.AddInstruction(HloInstruction::CreateReshape(
       ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2}), broadcast1));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
 
   EXPECT_THAT(computation->root_instruction(),
               op::Reshape(op::Broadcast(param0)));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
 
   EXPECT_THAT(computation->root_instruction(), op::Broadcast(param0));
 }
@@ -1623,15 +1566,14 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x1_3) {
   builder.AddInstruction(
       HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3}), broadcast));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
 
   EXPECT_THAT(computation->root_instruction(),
               op::Reshape(op::Broadcast(param)));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
+  EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie());
 
   EXPECT_THAT(computation->root_instruction(),
               op::Reshape(op::Broadcast(param)));
@@ -1646,15 +1588,14 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4_6x1x1x4) {
   builder.AddInstruction(HloInstruction::CreateReshape(
       ShapeUtil::MakeShape(F32, {6, 1, 1, 4}), broadcast));
 
-  auto module = CreateNewModule();
-  HloComputation* computation = module->AddEntryComputation(builder.Build());
+  HloComputation* computation = module().AddEntryComputation(builder.Build());
 
   EXPECT_THAT(computation->root_instruction(),
               op::Reshape(op::Broadcast(param)));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
 
   EXPECT_THAT(computation->root_instruction(), op::Broadcast(param));
   EXPECT_THAT(computation->root_instruction()->dimensions(),
@@ -1670,15 +1611,14 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x2x1_6x1x1x1) {
   builder.AddInstruction(HloInstruction::CreateReshape(
       ShapeUtil::MakeShape(F32, {6, 1, 1, 1}), broadcast));
 
-  auto module = CreateNewModule();
-  HloComputation* computation = module->AddEntryComputation(builder.Build());
+  HloComputation* computation = module().AddEntryComputation(builder.Build());
 
   EXPECT_THAT(computation->root_instruction(),
               op::Reshape(op::Broadcast(param)));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
 
   EXPECT_THAT(computation->root_instruction(), op::Broadcast(param));
   const std::vector<int64> broadcast_dims =
@@ -1696,15 +1636,14 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) {
   builder.AddInstruction(HloInstruction::CreateReshape(
       ShapeUtil::MakeShape(F32, {6, 8}), broadcast));
 
-  auto module = CreateNewModule();
-  HloComputation* computation = module->AddEntryComputation(builder.Build());
+  HloComputation* computation = module().AddEntryComputation(builder.Build());
 
   EXPECT_THAT(computation->root_instruction(),
               op::Reshape(op::Broadcast(param)));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
+  EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie());
 
   EXPECT_THAT(computation->root_instruction(),
               op::Reshape(op::Broadcast(param)));
@@ -2410,12 +2349,11 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) {
   call_builder.AddInstruction(
       HloInstruction::CreateCall(r1f32, {zero, one}, dot_computation.get()));
 
-  auto module = CreateNewModule();
-  module->AddEmbeddedComputation(std::move(dot_computation));
-  module->AddEntryComputation(call_builder.Build());
+  module().AddEmbeddedComputation(std::move(dot_computation));
+  module().AddEntryComputation(call_builder.Build());
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
 }
 
 // Test that a constant with tuple shape becomes a tuple of constants.
@@ -2428,12 +2366,11 @@ TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) {
                           Literal::CreateR1<float>(constant_vector).get()});
   builder.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
   EXPECT_THAT(computation->root_instruction(),
               op::Tuple(op::Constant(), op::Constant()));
 }
@@ -2453,11 +2390,10 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) {
           HloInstruction::CreateConstant(Literal::CreateR1<int>({0, 0, 0}))),
       /*slice_sizes=*/{10, 100, 1000}));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
   EXPECT_THAT(computation->root_instruction(), op::Parameter());
 }
 
@@ -2487,11 +2423,10 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) {
       builder.AddInstruction(
           HloInstruction::CreateConstant(Literal::CreateR1<int>({0, 0, 0})))));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
   EXPECT_THAT(computation->root_instruction(),
               op::DynamicSlice(op::Parameter(), op::Parameter()));
 }
@@ -2554,15 +2489,16 @@ TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) {
 
   PaddingConfig padding = window_util::MakeSymmetricPadding(
       decorate_spatials(param.symmetric_pad_spatials, 0, 0));
+  TF_ASSERT_OK_AND_ASSIGN(
+      const Shape pad_shape,
+      ShapeInference::InferPadShape(input->shape(),
+                                    ShapeUtil::MakeShape(F32, {}), padding));
   HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
-      ShapeUtil::MakeShape(
-          F32, decorate_spatials(param.reduce_window_spatials, 128, 2048)),
-      input,
+      pad_shape, input,
       builder.AddInstruction(
           HloInstruction::CreateConstant(Literal::CreateR0(0.0f))),
       padding));
 
-  std::unique_ptr<HloModule> module = CreateNewModule();
   HloComputation* add_computation = nullptr;
   {
     HloComputation::Builder builder(TestName() + ".add");
@@ -2573,24 +2509,24 @@ TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) {
         HloInstruction::CreateParameter(1, scalar_shape, "p1"));
     builder.AddInstruction(
         HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
-    add_computation = module->AddEmbeddedComputation(builder.Build());
+    add_computation = module().AddEmbeddedComputation(builder.Build());
   }
 
-  TF_ASSERT_OK_AND_ASSIGN(
-      const Shape output_shape,
-      ShapeInference::InferPadShape(input_shape, ShapeUtil::MakeShape(F32, {}),
-                                    padding));
   Window window = window_util::MakeWindow(
       decorate_spatials(param.reduce_window_spatials, 1, 1));
   auto zero = builder.AddInstruction(
       HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+  TF_ASSERT_OK_AND_ASSIGN(const Shape output_shape,
+                          ShapeInference::InferReduceWindowShape(
+                              pad->shape(), zero->shape(), window,
+                              add_computation->ComputeProgramShape()));
   builder.AddInstruction(HloInstruction::CreateReduceWindow(
       output_shape, pad, zero, window, add_computation));
 
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(module.get()));
+  TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module()));
   ASSERT_TRUE(run_successful);
 
   EXPECT_TRUE(
@@ -2667,11 +2603,10 @@ TEST_P(DotStrengthReductionTest, DotStrengthReduction) {
   dot_dnums.add_rhs_contracting_dimensions(0);
   builder.AddInstruction(
       HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums));
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get()));
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(&module()));
   const bool dot_should_be_transformed = m == 1 || k == 1 || n == 1;
   const bool computation_should_be_modified =
       dot_should_be_transformed || (transpose_lhs && transpose_rhs);
@@ -2699,7 +2634,7 @@ struct DotOfConcatTestSpec {
 };
 
 class DotOfConcatSimplificationTest
-    : public HloTestBase,
+    : public HloVerifiedTestBase,
       public ::testing::WithParamInterface<DotOfConcatTestSpec> {};
 
 // Test that we transform
@@ -2745,11 +2680,10 @@ TEST_P(DotOfConcatSimplificationTest, ConstantLHS) {
   builder.AddInstruction(
       HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(module.get()));
+  TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module()));
   ASSERT_TRUE(run_successful);
 
   EXPECT_TRUE(
@@ -2790,17 +2724,17 @@ TEST_P(DotOfConcatSimplificationTest, ConstantRHS) {
   HloInstruction* lhs2 = builder.AddInstruction(
       HloInstruction::CreateParameter(2, lhs2_shape, "lhs2"));
   HloInstruction* lhs3 = builder.AddInstruction(
-      HloInstruction::CreateParameter(3, lhs2_shape, "lhs3"));
+      HloInstruction::CreateParameter(3, lhs3_shape, "lhs3"));
 
   Shape lhs_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.k});
   HloInstruction* lhs =
       builder.AddInstruction(HloInstruction::CreateConcatenate(
           lhs_shape, {lhs0, lhs1, lhs2, lhs3}, 1));
 
-  Shape rhs_shape = ShapeUtil::MakeShape(F32, {spec.k, spec.m});
+  Shape rhs_shape = ShapeUtil::MakeShape(F32, {spec.k, spec.n});
   auto* rhs = builder.AddInstruction(
       HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
-          /*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.k, /*cols=*/spec.m)));
+          /*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.k, /*cols=*/spec.n)));
 
   DotDimensionNumbers dot_dnums;
   dot_dnums.add_lhs_contracting_dimensions(1);
@@ -2810,11 +2744,10 @@ TEST_P(DotOfConcatSimplificationTest, ConstantRHS) {
   builder.AddInstruction(
       HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums));
 
-  auto module = CreateNewModule();
-  auto computation = module->AddEntryComputation(builder.Build());
+  auto computation = module().AddEntryComputation(builder.Build());
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(module.get()));
+  TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module()));
   ASSERT_TRUE(run_successful);
   EXPECT_TRUE(
       ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape));
index 55f42ed..93284b8 100644 (file)
@@ -32,6 +32,8 @@ Window MakeWindow(tensorflow::gtl::ArraySlice<int64> sizes) {
     auto* dimension = window.add_dimensions();
     dimension->set_size(size);
     dimension->set_stride(1);
+    dimension->set_base_dilation(1);
+    dimension->set_window_dilation(1);
   }
   return window;
 }