[TF:XLA]
authorDimitris Vardoulakis <dimvar@google.com>
Sun, 29 Apr 2018 05:19:22 +0000 (22:19 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sun, 29 Apr 2018 05:21:47 +0000 (22:21 -0700)
- Require a module config when creating an HloModule.
- All tests using HloTestBase create a module using CreateNewModule.

PiperOrigin-RevId: 194684585

17 files changed:
tensorflow/compiler/xla/reference_util.cc
tensorflow/compiler/xla/service/BUILD
tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
tensorflow/compiler/xla/service/buffer_assignment_test.cc
tensorflow/compiler/xla/service/graphviz_example.cc
tensorflow/compiler/xla/service/heap_simulator_test.cc
tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
tensorflow/compiler/xla/service/hlo_creation_utils_test.cc
tensorflow/compiler/xla/service/hlo_evaluator.cc
tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc
tensorflow/compiler/xla/service/hlo_instruction_test.cc
tensorflow/compiler/xla/service/hlo_module.cc
tensorflow/compiler/xla/service/hlo_module.h
tensorflow/compiler/xla/service/transpose_folding_test.cc
tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc
tensorflow/compiler/xla/tests/hlo_test_base.cc
tensorflow/compiler/xla/tests/hlo_test_base.h

index df9dbc5..c289c84 100644 (file)
@@ -572,7 +572,8 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
 
   b.AddInstruction(HloInstruction::CreateConvolve(
       shape, lhs_instruction, rhs_instruction, window, dnums));
-  HloModule module("ReferenceUtil");
+  HloModuleConfig config;
+  HloModule module("ReferenceUtil", config);
   auto computation = module.AddEntryComputation(b.Build());
 
   HloEvaluator evaluator;
index f39bfb8..ed0da47 100644 (file)
@@ -1330,6 +1330,7 @@ tf_cc_test(
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:window_util",
         "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/tests:hlo_test_base",
         "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",  # fixdeps: keep
         "//tensorflow/core:lib",
@@ -2420,6 +2421,7 @@ tf_cc_test(
         ":hlo_graph_dumper",
         "//tensorflow/compiler/xla:test",
         "//tensorflow/compiler/xla:xla_proto",
+        "//tensorflow/compiler/xla/tests:hlo_test_base",
         "//tensorflow/compiler/xla/tests:test_utils",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",  # fixdeps: keep
         "//tensorflow/core:lib",
index 20c5495..d0c99bf 100644 (file)
@@ -28,6 +28,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
 #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/window_util.h"
@@ -1699,14 +1700,14 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) {
   builder.AddInstruction(HloInstruction::CreatePad(
       ShapeUtil::MakeShape(F32, {2, 2}), param, zero, no_padding));
 
-  HloModule module(TestName());
-  HloComputation* computation = module.AddEntryComputation(builder.Build());
+  auto module = CreateNewModule();
+  HloComputation* computation = module->AddEntryComputation(builder.Build());
 
   EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
 
   EXPECT_THAT(computation->root_instruction(), param);
 }
@@ -1732,8 +1733,8 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) {
   HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
       ShapeUtil::MakeShape(F32, {11, 5}), param, zero, padding));
 
-  HloModule module(TestName());
-  HloComputation* computation = module.AddEntryComputation(builder.Build());
+  auto module = CreateNewModule();
+  HloComputation* computation = module->AddEntryComputation(builder.Build());
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
@@ -1751,7 +1752,7 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) {
   EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero));
   EXPECT_TRUE(has_negative_padding(pad));
 
-  ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
 
   EXPECT_THAT(computation->root_instruction(), op::Slice(op::Pad(param, zero)));
   EXPECT_FALSE(
@@ -1766,14 +1767,14 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) {
   builder.AddInstruction(
       HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {2, 3}), param));
 
-  HloModule module(TestName());
-  HloComputation* computation = module.AddEntryComputation(builder.Build());
+  auto module = CreateNewModule();
+  HloComputation* computation = module->AddEntryComputation(builder.Build());
 
   EXPECT_THAT(computation->root_instruction(), op::Reshape(param));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
 
   EXPECT_THAT(computation->root_instruction(), param);
 }
@@ -1789,14 +1790,14 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) {
       ShapeUtil::MakeShape(F32, {dim0, dim1}), param, /*start_indices=*/{0, 0},
       /*limit_indices=*/{dim0, dim1}, /*strides=*/{1, 1}));
 
-  HloModule module(TestName());
-  HloComputation* computation = module.AddEntryComputation(builder.Build());
+  auto module = CreateNewModule();
+  HloComputation* computation = module->AddEntryComputation(builder.Build());
 
   EXPECT_THAT(computation->root_instruction(), op::Slice(param));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
 
   EXPECT_THAT(computation->root_instruction(), param);
 }
@@ -1924,12 +1925,12 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
     b.AddInstruction(HloInstruction::CreateConvolve(out_shape, input, filter,
                                                     window, dnums));
 
-    HloModule module(TestName());
-    auto* computation = module.AddEntryComputation(b.Build());
+    auto module = CreateNewModule();
+    auto* computation = module->AddEntryComputation(b.Build());
 
     AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
                                    bitcasting_callback());
-    if (!simplifier.Run(&module).ValueOrDie()) {
+    if (!simplifier.Run(module.get()).ValueOrDie()) {
       return "NO_CHANGE";
     }
     auto* root = computation->root_instruction();
@@ -2044,15 +2045,15 @@ TEST_F(AlgebraicSimplifierTest, MaxMinToClamp) {
   builder.AddInstruction(
       HloInstruction::CreateBinary(r0f32, HloOpcode::kMaximum, min, max_value));
 
-  HloModule module(TestName());
-  auto computation = module.AddEntryComputation(builder.Build());
+  auto module = CreateNewModule();
+  auto computation = module->AddEntryComputation(builder.Build());
 
   EXPECT_THAT(computation->root_instruction(),
               op::Maximum(op::Minimum(param0, min_value), max_value));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
 
   EXPECT_THAT(computation->root_instruction(),
               op::Clamp(max_value, param0, min_value));
@@ -2074,15 +2075,15 @@ TEST_F(AlgebraicSimplifierTest, MinMaxToClamp) {
   builder.AddInstruction(
       HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value));
 
-  HloModule module(TestName());
-  auto computation = module.AddEntryComputation(builder.Build());
+  auto module = CreateNewModule();
+  auto computation = module->AddEntryComputation(builder.Build());
 
   EXPECT_THAT(computation->root_instruction(),
               op::Minimum(op::Maximum(param0, max_value), min_value));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
 
   EXPECT_THAT(computation->root_instruction(),
               op::Clamp(max_value, param0, min_value));
@@ -2105,15 +2106,15 @@ TEST_F(AlgebraicSimplifierTest, MinMaxWithBroadcastToClamp) {
   builder.AddInstruction(
       HloInstruction::CreateBinary(r1f32, HloOpcode::kMinimum, max, min_value));
 
-  HloModule module(TestName());
-  auto computation = module.AddEntryComputation(builder.Build());
+  auto module = CreateNewModule();
+  auto computation = module->AddEntryComputation(builder.Build());
 
   EXPECT_THAT(computation->root_instruction(),
               op::Minimum(op::Maximum(param0, max_value), min_value));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
 
   EXPECT_THAT(computation->root_instruction(),
               op::Clamp(max_value, param0, min_value));
@@ -2135,15 +2136,15 @@ TEST_F(AlgebraicSimplifierTest, MinMaxNotToClamp) {
   builder.AddInstruction(
       HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value));
 
-  HloModule module(TestName());
-  auto computation = module.AddEntryComputation(builder.Build());
+  auto module = CreateNewModule();
+  auto computation = module->AddEntryComputation(builder.Build());
 
   EXPECT_THAT(computation->root_instruction(),
               op::Minimum(op::Maximum(param0, max_value), min_value));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  EXPECT_FALSE(simplifier.Run(&module).ValueOrDie());
+  EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
 
   EXPECT_THAT(computation->root_instruction(),
               op::Minimum(op::Maximum(param0, max_value), min_value));
@@ -2167,8 +2168,8 @@ TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) {
   builder.AddInstruction(HloInstruction::CreateBinary(
       r0f32, HloOpcode::kMinimum, fmax, min_value));
 
-  HloModule module(TestName());
-  auto computation = module.AddEntryComputation(builder.Build());
+  auto module = CreateNewModule();
+  auto computation = module->AddEntryComputation(builder.Build());
 
   EXPECT_THAT(computation->root_instruction(),
               op::Minimum(op::Add(op::Maximum(param0, max_value), max_value),
@@ -2176,7 +2177,7 @@ TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) {
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  EXPECT_FALSE(simplifier.Run(&module).ValueOrDie());
+  EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
 
   EXPECT_THAT(computation->root_instruction(),
               op::Minimum(op::Add(op::Maximum(param0, max_value), max_value),
@@ -2201,8 +2202,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) {
   HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateSlice(
       slice_shape, broadcast, {0, 1, 2, 3}, {2, 3, 5, 6}, {1, 1, 1, 1}));
 
-  HloModule module(TestName());
-  auto computation = module.AddEntryComputation(builder.Build());
+  auto module = CreateNewModule();
+  auto computation = module->AddEntryComputation(builder.Build());
 
   HloInstruction* root = computation->root_instruction();
   EXPECT_EQ(root, slice);
@@ -2211,10 +2212,10 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) {
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
 
-  ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
 
   // Running simplification again should not result in any further changes.
-  ASSERT_FALSE(simplifier.Run(&module).ValueOrDie());
+  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
 
   root = computation->root_instruction();
   EXPECT_THAT(root, op::Broadcast(scalar_param));
@@ -2242,8 +2243,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) {
   HloInstruction* reshape = builder.AddInstruction(
       HloInstruction::CreateReshape(reshape_shape, transpose));
 
-  HloModule module(TestName());
-  auto computation = module.AddEntryComputation(builder.Build());
+  auto module = CreateNewModule();
+  auto computation = module->AddEntryComputation(builder.Build());
 
   HloInstruction* root = computation->root_instruction();
   EXPECT_EQ(root, reshape);
@@ -2251,7 +2252,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) {
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
 
   root = computation->root_instruction();
   EXPECT_THAT(root, op::Broadcast(forty_two));
@@ -2260,7 +2261,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) {
 
 // Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x).
 TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
-  HloModule module(TestName());
+  auto module = CreateNewModule();
   HloComputation::Builder builder(TestName());
 
   // Create operand to the pad.
@@ -2289,7 +2290,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
         HloInstruction::CreateParameter(1, scalar_shape, "p1"));
     builder.AddInstruction(
         HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
-    add_computation = module.AddEmbeddedComputation(builder.Build());
+    add_computation = module->AddEmbeddedComputation(builder.Build());
   }
 
   // Create the reduce-window.
@@ -2312,15 +2313,15 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
           add_computation));
 
   // Build the computation and run the simplifier.
-  auto computation = module.AddEntryComputation(builder.Build());
+  auto computation = module->AddEntryComputation(builder.Build());
   HloInstruction* root = computation->root_instruction();
   EXPECT_EQ(root, reduce_window);
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
 
   // Running simplification again should not result in any further changes.
-  ASSERT_FALSE(simplifier.Run(&module).ValueOrDie());
+  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
 
   // Verify the result
   root = computation->root_instruction();
@@ -2341,7 +2342,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
 // Test that ReduceWindow(Convert(Pad(op, x)), y) can simplify to
 // ReduceWindow(Convert(op), x).
 TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) {
-  HloModule module(TestName());
+  auto module = CreateNewModule();
   HloComputation::Builder builder(TestName());
 
   // Create operand to the pad.
@@ -2374,7 +2375,7 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) {
         HloInstruction::CreateParameter(1, scalar_shape, "p1"));
     builder.AddInstruction(
         HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
-    add_computation = module.AddEmbeddedComputation(builder.Build());
+    add_computation = module->AddEmbeddedComputation(builder.Build());
   }
 
   // Create the reduce-window.
@@ -2397,15 +2398,15 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) {
           add_computation));
 
   // Build the computation and run the simplifier.
-  auto computation = module.AddEntryComputation(builder.Build());
+  auto computation = module->AddEntryComputation(builder.Build());
   HloInstruction* root = computation->root_instruction();
   EXPECT_EQ(root, reduce_window);
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
 
   // Running simplification again should not result in any further changes.
-  ASSERT_FALSE(simplifier.Run(&module).ValueOrDie());
+  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
 
   // Verify the result
   root = computation->root_instruction();
@@ -2431,12 +2432,12 @@ TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) {
   builder.AddInstruction(
       HloInstruction::CreateReverse(shape, a, /*dimensions=*/{2, 3}));
 
-  HloModule module(TestName());
-  auto computation = module.AddEntryComputation(builder.Build());
+  auto module = CreateNewModule();
+  auto computation = module->AddEntryComputation(builder.Build());
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
 
   HloInstruction* root = computation->root_instruction();
   EXPECT_EQ(a, root);
index 513a878..3ec9795 100644 (file)
@@ -1641,7 +1641,7 @@ static void RunCopyInsertion(HloModule* module) {
 }
 
 TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) {
-  auto module = xla::MakeUnique<HloModule>(TestName());
+  auto module = CreateNewModule();
   auto builder = HloComputation::Builder("entry");
 
   auto input0 = builder.AddInstruction(
@@ -1816,7 +1816,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) {
   };
 
   // Build the entry computation as described in the comment above.
-  auto module = xla::MakeUnique<HloModule>(TestName());
+  auto module = CreateNewModule();
   auto builder = HloComputation::Builder("entry");
 
   auto infeed = builder.AddInstruction(HloInstruction::CreateInfeed(r0s32, ""));
@@ -1884,7 +1884,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) {
 }
 
 TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) {
-  auto module = xla::MakeUnique<HloModule>(TestName());
+  auto module = CreateNewModule();
   auto builder = HloComputation::Builder("entry");
 
   auto input0 = builder.AddInstruction(
@@ -1929,7 +1929,7 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) {
 }
 
 TEST_F(BufferAssignmentTest, TwoCalls) {
-  auto module = xla::MakeUnique<HloModule>(TestName());
+  auto module = CreateNewModule();
   Shape r0f32 = ShapeUtil::MakeShape(xla::F32, {});
   HloComputation* sub_computation;
   {
@@ -1994,7 +1994,7 @@ static bool IsPostOrderTraversal(
 }
 
 TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
-  auto module = xla::MakeUnique<HloModule>(TestName());
+  auto module = CreateNewModule();
   auto builder = HloComputation::Builder(TestName());
 
   auto zero = builder.AddInstruction(
@@ -2073,7 +2073,7 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
 }
 
 TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) {
-  auto module = xla::MakeUnique<HloModule>(TestName());
+  auto module = CreateNewModule();
   auto builder = HloComputation::Builder("entry");
 
   auto input0 = builder.AddInstruction(
index 0501700..acf6611 100644 (file)
@@ -82,7 +82,8 @@ HloComputation* CallForwardingComputation(HloComputation* computation,
 // instructions. Sets the computation as the entry to an HLO module and returns
 // the module.
 std::unique_ptr<HloModule> MakeBigGraph() {
-  auto module = MakeUnique<HloModule>("BigGraph");
+  HloModuleConfig config;
+  auto module = MakeUnique<HloModule>("BigGraph", config);
 
   auto builder = HloComputation::Builder("TestBigGraphvizGraph");
 
index 688a271..e983fd1 100644 (file)
@@ -76,7 +76,8 @@ class HeapSimulatorTracker {
   HeapSimulatorTracker(
       const string& name, std::unique_ptr<HloComputation> computation,
       const std::vector<const HloInstruction*>& instruction_sequence) {
-    module_ = MakeUnique<HloModule>(name);
+    HloModuleConfig config;
+    module_ = MakeUnique<HloModule>(name, config);
     module_->AddEntryComputation(std::move(computation));
     points_to_analysis_ =
         TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
@@ -94,7 +95,8 @@ class HeapSimulatorTracker {
   }
 
   explicit HeapSimulatorTracker(const string& name) {
-    module_ = MakeUnique<HloModule>(name);
+    HloModuleConfig config;
+    module_ = MakeUnique<HloModule>(name, config);
   }
 
   // Similar to the single entry computation constructor above, but runs the
index 3d055b3..81cc7c4 100644 (file)
@@ -370,8 +370,8 @@ TEST_F(FusionCostAnalysis, LoopFusion) {
         HloInstruction::CreateBinary(r2f32, HloOpcode::kSubtract, mul, clamp));
     auto tuple = HloInstruction::CreateTuple({sub, sub, mul, c1});
 
-    HloModule module(TestName());
-    auto* computation = module.AddEntryComputation(builder.Build());
+    auto module = CreateNewModule();
+    auto* computation = module->AddEntryComputation(builder.Build());
     auto* fusion = computation->CreateFusionInstruction(
         {sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop);
 
@@ -412,8 +412,8 @@ TEST_F(FusionCostAnalysis, NoLayout) {
   auto add = builder.AddInstruction(HloInstruction::CreateBinary(
       shape_with_layout, HloOpcode::kAdd, c1, broadcast));
 
-  HloModule module(TestName());
-  auto* computation = module.AddEntryComputation(builder.Build());
+  auto module = CreateNewModule();
+  auto* computation = module->AddEntryComputation(builder.Build());
   auto* fusion = computation->CreateFusionInstruction(
       {add, broadcast}, HloInstruction::FusionKind::kLoop);
 
index 6b681a5..7e7c4f9 100644 (file)
@@ -19,27 +19,32 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace xla {
 namespace {
 using tensorflow::gtl::ArraySlice;
 
-std::unique_ptr<HloModule> CreateModuleWithProgramShape(
-    PrimitiveType primitive_type, ArraySlice<int64> input_shape_dims,
-    ArraySlice<int64> output_shape_dims, HloInstruction** param,
-    HloComputation** entry_computation) {
-  Shape input_shape = ShapeUtil::MakeShape(primitive_type, input_shape_dims);
-  Shape output_shape = ShapeUtil::MakeShape(primitive_type, output_shape_dims);
-  std::unique_ptr<HloModule> module = MakeUnique<HloModule>("test");
-  *entry_computation = module->AddEntryComputation(
-      CreateComputationWithSignature({&input_shape}, output_shape, "entry")
-          .ValueOrDie());
-  *param = (*entry_computation)->parameter_instruction(0);
-  return module;
-}
-
-TEST(HloCreationUtilsTest, CollapseFirst1Dim) {
+class HloCreationUtilsTest : public HloTestBase {
+ protected:
+  static std::unique_ptr<HloModule> CreateModuleWithProgramShape(
+      PrimitiveType primitive_type, ArraySlice<int64> input_shape_dims,
+      ArraySlice<int64> output_shape_dims, HloInstruction** param,
+      HloComputation** entry_computation) {
+    Shape input_shape = ShapeUtil::MakeShape(primitive_type, input_shape_dims);
+    Shape output_shape =
+        ShapeUtil::MakeShape(primitive_type, output_shape_dims);
+    auto module = CreateNewModule("test");
+    *entry_computation = module->AddEntryComputation(
+        CreateComputationWithSignature({&input_shape}, output_shape, "entry")
+            .ValueOrDie());
+    *param = (*entry_computation)->parameter_instruction(0);
+    return module;
+  }
+};
+
+TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) {
   HloInstruction* param;
   HloComputation* entry_computation;
 
@@ -59,7 +64,7 @@ TEST(HloCreationUtilsTest, CollapseFirst1Dim) {
   CHECK_EQ(*result_literal, *Literal::CreateR1<int32>({3, 4}));
 }
 
-TEST(HloCreationUtilsTest, CollapseFirst2Dims) {
+TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) {
   HloInstruction* param;
   HloComputation* entry_computation;
 
@@ -84,7 +89,7 @@ TEST(HloCreationUtilsTest, CollapseFirst2Dims) {
                {{1, 2}, {3, 4}, {5, 6}, {-1, -2}, {-3, -4}, {-5, -6}}));
 }
 
-TEST(HloCreationUtilsTest, Prepend1DegenerateDim) {
+TEST_F(HloCreationUtilsTest, Prepend1DegenerateDim) {
   HloInstruction* param;
   HloComputation* entry_computation;
 
@@ -104,7 +109,7 @@ TEST(HloCreationUtilsTest, Prepend1DegenerateDim) {
   CHECK_EQ(*result_literal, *Literal::CreateR2<int32>({{9, 10}}));
 }
 
-TEST(HloCreationUtilsTest, Prepend2DegenerateDims) {
+TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) {
   HloInstruction* param;
   HloComputation* entry_computation;
 
@@ -124,7 +129,7 @@ TEST(HloCreationUtilsTest, Prepend2DegenerateDims) {
   CHECK_EQ(*result_literal, *Literal::CreateR3<int32>({{{9, 10}}}));
 }
 
-TEST(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) {
+TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) {
   HloInstruction* param;
   HloComputation* entry_computation;
 
@@ -144,7 +149,7 @@ TEST(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) {
   CHECK_EQ(*result_literal, *Literal::CreateR2<int32>({{9}}));
 }
 
-TEST(HloCreationUtilsTest, ExpandFirstDimInto3Dims) {
+TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) {
   HloInstruction* param;
   HloComputation* entry_computation;
 
@@ -166,7 +171,7 @@ TEST(HloCreationUtilsTest, ExpandFirstDimInto3Dims) {
            *Literal::CreateR3<int32>({{{1, 2}}, {{3, 4}}, {{5, 6}}}));
 }
 
-TEST(HloCreationUtilsTest, PadVectorWithZeros) {
+TEST_F(HloCreationUtilsTest, PadVectorWithZeros) {
   HloInstruction* param;
   HloComputation* entry_computation;
 
@@ -187,7 +192,7 @@ TEST(HloCreationUtilsTest, PadVectorWithZeros) {
   CHECK_EQ(*result_literal, *Literal::CreateR1<int32>({0, 0, 0, 3, 4, 0}));
 }
 
-TEST(HloCreationUtilsTest, BroadcastZeros_S32) {
+TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) {
   HloInstruction* param;
   HloComputation* entry_computation;
 
@@ -208,7 +213,7 @@ TEST(HloCreationUtilsTest, BroadcastZeros_S32) {
   CHECK_EQ(*result_literal, *Literal::CreateR2<int32>({{0, 0}, {0, 0}}));
 }
 
-TEST(HloCreationUtilsTest, BroadcastZeros_F32) {
+TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) {
   HloInstruction* param;
   HloComputation* entry_computation;
 
index f1dcef1..8cf9412 100644 (file)
@@ -2968,9 +2968,10 @@ Status HloEvaluator::HandleCall(HloInstruction* call) {
 }
 
 Status HloEvaluator::HandleFusion(HloInstruction* fusion) {
+  HloModuleConfig config;
   // Attach cloned computation to an empty HLO module so the existing ones are
   // not modified.
-  HloModule empty_hlo_module("EmptyModuleForFusion");
+  HloModule empty_hlo_module("EmptyModuleForFusion", config);
   auto cloned_fused_computation =
       fusion->fused_instructions_computation()->Clone(
           /*suffix=*/"clone_with_layout", &empty_hlo_module);
index 1f00aa4..b589cd5 100644 (file)
@@ -20,6 +20,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
 #include "tensorflow/compiler/xla/tests/test_utils.h"
 #include "tensorflow/compiler/xla/xla.pb.h"
 #include "tensorflow/core/lib/strings/strcat.h"
@@ -47,7 +48,9 @@ class DotRenderer : public hlo_graph_dumper::GraphRendererInterface {
 
 XLA_REGISTER_GRAPH_RENDERER(DotRenderer);
 
-TEST(HloGraphDumperTest, NestedFusion) {
+class HloGraphDumperTest : public HloTestBase {};
+
+TEST_F(HloGraphDumperTest, NestedFusion) {
   HloComputation::Builder b("b");
 
   // Build param0 + param1 + param2 + param3 + param4.
@@ -64,10 +67,9 @@ TEST(HloGraphDumperTest, NestedFusion) {
     sums.push_back(b.AddInstruction(HloInstruction::CreateBinary(
         shape, HloOpcode::kAdd, sums[i], params[i + 2])));
   }
-
-  HloModule m(TestName());
-  m.AddEntryComputation(b.Build());
-  HloComputation* root_computation = m.entry_computation();
+  auto m = CreateNewModule();
+  m->AddEntryComputation(b.Build());
+  HloComputation* root_computation = m->entry_computation();
 
   // Fuse into fusion(param0 + param1 + param2 + param3 + param4).
   auto* outer_fusion = root_computation->CreateFusionInstruction(
@@ -117,13 +119,13 @@ TEST(HloGraphDumperTest, NestedFusion) {
       HasSubstr(inner_sum->name()));
 }
 
-TEST(HloGraphDumperTest, Constant) {
+TEST_F(HloGraphDumperTest, Constant) {
   HloComputation::Builder b("b");
   auto instruction = b.AddInstruction(
       HloInstruction::CreateConstant(Literal::CreateR0<float>(-42)));
   instruction->set_name("i_am_a_constant_root_instruction");
-  HloModule m(TestName());
-  HloComputation* root_computation = m.AddEntryComputation(b.Build());
+  auto m = CreateNewModule();
+  HloComputation* root_computation = m->AddEntryComputation(b.Build());
   string graph = hlo_graph_dumper::DumpGraph(
       *root_computation, /*label=*/"an_empty_graph", DebugOptions());
   EXPECT_THAT(graph, HasSubstr("an_empty_graph"));
index f2980d3..5b65b11 100644 (file)
@@ -149,8 +149,8 @@ TEST_F(HloInstructionTest, UserWithTwoOperands) {
       builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar"));
   auto add = builder.AddInstruction(
       HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar));
-  HloModule module(TestName());
-  module.AddEntryComputation(builder.Build());
+  auto module = CreateNewModule();
+  module->AddEntryComputation(builder.Build());
 
   EXPECT_THAT(add->operands(), UnorderedElementsAre(foo, bar));
   EXPECT_THAT(foo->users(), UnorderedElementsAre(add));
@@ -186,8 +186,8 @@ TEST_F(HloInstructionTest, MultipleUsers) {
       HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo));
   auto add = builder.AddInstruction(
       HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar));
-  HloModule module(TestName());
-  module.AddEntryComputation(builder.Build());
+  auto module = CreateNewModule();
+  module->AddEntryComputation(builder.Build());
 
   EXPECT_EQ(3, foo->user_count());
   EXPECT_EQ(1, bar->user_count());
@@ -219,8 +219,8 @@ TEST_F(HloInstructionTest, RepeatedUser) {
       builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo"));
   auto add = builder.AddInstruction(
       HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo));
-  HloModule module(TestName());
-  module.AddEntryComputation(builder.Build());
+  auto module = CreateNewModule();
+  module->AddEntryComputation(builder.Build());
 
   EXPECT_EQ(1, foo->user_count());
 
@@ -254,8 +254,8 @@ TEST_F(HloInstructionTest, MultipleUsersAndOperands) {
       HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c0, param1));
   auto addtotal = builder.AddInstruction(
       HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, addleft, addright));
-  HloModule module(TestName());
-  module.AddEntryComputation(builder.Build());
+  auto module = CreateNewModule();
+  module->AddEntryComputation(builder.Build());
 
   OpAndUserCollectingVisitor visitor;
   ASSERT_IS_OK(addtotal->Accept(&visitor));
@@ -303,8 +303,8 @@ TEST_F(HloInstructionTest, MultipleUsersAndOperandsWithUnaryOps) {
       HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, addleft, addright));
   auto neg2 = builder.AddInstruction(
       HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, addtotal));
-  HloModule module(TestName());
-  module.AddEntryComputation(builder.Build());
+  auto module = CreateNewModule();
+  module->AddEntryComputation(builder.Build());
 
   OpAndUserCollectingVisitor visitor;
   ASSERT_IS_OK(neg2->Accept(&visitor));
@@ -325,7 +325,7 @@ TEST_F(HloInstructionTest, TrivialMap) {
   //
   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
   Shape f32a100x10 = ShapeUtil::MakeShape(F32, {100, 10});
-  HloModule module(TestName());
+  auto module = CreateNewModule();
 
   // Builds an x+1.0 computation to use in a Map.
   auto embedded_builder = HloComputation::Builder("f32+1");
@@ -335,7 +335,7 @@ TEST_F(HloInstructionTest, TrivialMap) {
       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
   embedded_builder.AddInstruction(
       HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param, value));
-  auto add_f32 = module.AddEmbeddedComputation(embedded_builder.Build());
+  auto add_f32 = module->AddEmbeddedComputation(embedded_builder.Build());
 
   // Builds a parameter and feeds it to the map.
   HloComputation::Builder builder(TestName());
@@ -343,7 +343,7 @@ TEST_F(HloInstructionTest, TrivialMap) {
       HloInstruction::CreateParameter(0, f32a100x10, ""));
   auto map = builder.AddInstruction(
       HloInstruction::CreateMap(f32a100x10, {param0}, add_f32));
-  module.AddEntryComputation(builder.Build());
+  module->AddEntryComputation(builder.Build());
 
   OpAndUserCollectingVisitor visitor;
   ASSERT_IS_OK(map->Accept(&visitor));
@@ -373,8 +373,8 @@ TEST_F(HloInstructionTest, TrivialReduce) {
       HloInstruction::CreateParameter(1, r0f32, "y"));
   embedded_builder.AddInstruction(
       HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, paramx, paramy));
-  HloModule module(TestName());
-  auto add_f32 = module.AddEmbeddedComputation(embedded_builder.Build());
+  auto module = CreateNewModule();
+  auto add_f32 = module->AddEmbeddedComputation(embedded_builder.Build());
 
   // Builds a parameter and an initial value and feeds them to the reduce.
   HloComputation::Builder builder(TestName());
@@ -387,7 +387,7 @@ TEST_F(HloInstructionTest, TrivialReduce) {
   auto reduce = builder.AddInstruction(
       HloInstruction::CreateReduce(f32v100, param0, const0,
                                    /*dimensions_to_reduce=*/{1}, add_f32));
-  module.AddEntryComputation(builder.Build());
+  module->AddEntryComputation(builder.Build());
 
   OpAndUserCollectingVisitor visitor;
   ASSERT_IS_OK(reduce->Accept(&visitor));
@@ -414,8 +414,8 @@ TEST_F(HloInstructionTest, ReplaceUseInBinaryOps) {
       HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo));
   builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd,
                                                       add_foobar, add_foofoo));
-  HloModule module(TestName());
-  module.AddEntryComputation(builder.Build());
+  auto module = CreateNewModule();
+  module->AddEntryComputation(builder.Build());
 
   EXPECT_EQ(2, foo->user_count());
   EXPECT_EQ(1, bar->user_count());
@@ -449,8 +449,8 @@ TEST_F(HloInstructionTest, ReplaceUseInVariadicOp) {
       builder.AddInstruction(HloInstruction::CreateTuple({foo, bar, baz, foo}));
   auto add_foobar = builder.AddInstruction(
       HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar));
-  HloModule module(TestName());
-  module.AddEntryComputation(builder.Build());
+  auto module = CreateNewModule();
+  module->AddEntryComputation(builder.Build());
 
   EXPECT_EQ(2, foo->user_count());
   EXPECT_THAT(foo->users(), UnorderedElementsAre(tuple, add_foobar));
@@ -477,8 +477,8 @@ TEST_F(HloInstructionTest, ReplaceUseInUnaryOp) {
       HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo));
   auto log = builder.AddInstruction(
       HloInstruction::CreateUnary(r0f32_, HloOpcode::kLog, foo));
-  HloModule module(TestName());
-  module.AddEntryComputation(builder.Build());
+  auto module = CreateNewModule();
+  module->AddEntryComputation(builder.Build());
 
   EXPECT_EQ(2, foo->user_count());
   EXPECT_THAT(foo->users(), UnorderedElementsAre(exp, log));
@@ -514,8 +514,8 @@ TEST_F(HloInstructionTest, ReplaceAllUsesWithInBinaryOps) {
       HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo));
   builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd,
                                                       add_foobar, add_foofoo));
-  HloModule module(TestName());
-  module.AddEntryComputation(builder.Build());
+  auto module = CreateNewModule();
+  module->AddEntryComputation(builder.Build());
 
   EXPECT_EQ(2, foo->user_count());
   EXPECT_EQ(1, bar->user_count());
@@ -544,8 +544,8 @@ TEST_F(HloInstructionTest, ReplaceAllUsesInMultipleOps) {
   auto exp = builder.AddInstruction(
       HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo));
   auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({foo, bar}));
-  HloModule module(TestName());
-  module.AddEntryComputation(builder.Build());
+  auto module = CreateNewModule();
+  module->AddEntryComputation(builder.Build());
 
   EXPECT_EQ(3, foo->user_count());
   EXPECT_EQ(2, bar->user_count());
@@ -609,8 +609,8 @@ TEST_F(HloInstructionTest, PostProcessAllVisitedNodes) {
       HloInstruction::CreateUnary(r0f32_, HloOpcode::kLog, foo));
   auto add = builder.AddInstruction(
       HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, exp, log));
-  HloModule module(TestName());
-  module.AddEntryComputation(builder.Build());
+  auto module = CreateNewModule();
+  module->AddEntryComputation(builder.Build());
 
   NodeCollectorAndPostProcessor visitor;
   ASSERT_IS_OK(add->Accept(&visitor));
@@ -627,8 +627,8 @@ TEST_F(HloInstructionTest, SingletonFusionOp) {
       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
   auto exp = builder.AddInstruction(
       HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant));
-  HloModule module(TestName());
-  auto* computation = module.AddEntryComputation(builder.Build());
+  auto module = CreateNewModule();
+  auto* computation = module->AddEntryComputation(builder.Build());
   auto* fusion = computation->CreateFusionInstruction(
       {exp}, HloInstruction::FusionKind::kLoop);
 
@@ -645,8 +645,8 @@ TEST_F(HloInstructionTest, BinaryFusionOp) {
       HloInstruction::CreateConstant(Literal::CreateR0<float>(42.1f)));
   auto add = builder.AddInstruction(HloInstruction::CreateBinary(
       r0f32_, HloOpcode::kAdd, constant1, constant2));
-  HloModule module(TestName());
-  auto* computation = module.AddEntryComputation(builder.Build());
+  auto module = CreateNewModule();
+  auto* computation = module->AddEntryComputation(builder.Build());
   auto* fusion = computation->CreateFusionInstruction(
       {add}, HloInstruction::FusionKind::kLoop);
 
@@ -667,8 +667,8 @@ TEST_F(HloInstructionTest, ChainFusionOp) {
   auto exp3 = builder.AddInstruction(
       HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp2));
 
-  HloModule module(TestName());
-  auto* computation = module.AddEntryComputation(builder.Build());
+  auto module = CreateNewModule();
+  auto* computation = module->AddEntryComputation(builder.Build());
   auto* fusion = computation->CreateFusionInstruction(
       {exp3, exp2, exp1}, HloInstruction::FusionKind::kLoop);
 
@@ -690,8 +690,8 @@ TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) {
   exp1->set_metadata(metadata);
   exp2->set_metadata(metadata);
 
-  HloModule module(TestName());
-  auto* computation = module.AddEntryComputation(builder.Build());
+  auto module = CreateNewModule();
+  auto* computation = module->AddEntryComputation(builder.Build());
   auto* fusion = computation->CreateFusionInstruction(
       {exp2, exp1}, HloInstruction::FusionKind::kLoop);
 
@@ -746,13 +746,13 @@ TEST_F(HloInstructionTest, PreserveTupleShapeThroughClone) {
 TEST_F(HloInstructionTest, FusionOpWithCalledComputations) {
   // Create a fusion instruction containing a single unary operation.
   const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
-  HloModule module(TestName());
+  auto module = CreateNewModule();
 
   auto make_map_computation = [&]() {
     auto builder = HloComputation::Builder("FusionMap");
     builder.AddInstruction(
         HloInstruction::CreateParameter(0, scalar_shape, "param"));
-    return module.AddEmbeddedComputation(builder.Build());
+    return module->AddEmbeddedComputation(builder.Build());
   };
 
   HloComputation* computation_x = make_map_computation();
@@ -767,7 +767,7 @@ TEST_F(HloInstructionTest, FusionOpWithCalledComputations) {
       scalar_shape, {map_1_x}, computation_x, /*static_operands=*/{}));
   auto map_3_y = builder.AddInstruction(HloInstruction::CreateMap(
       scalar_shape, {map_2_x}, computation_y, /*static_operands=*/{}));
-  auto* computation = module.AddEntryComputation(builder.Build());
+  auto* computation = module->AddEntryComputation(builder.Build());
 
   auto* fusion = computation->CreateFusionInstruction(
       {map_3_y}, HloInstruction::FusionKind::kLoop);
@@ -814,8 +814,8 @@ TEST_F(HloInstructionTest, ComplexFusionOp) {
   auto tuple =
       builder.AddInstruction(HloInstruction::CreateTuple({sub, sub, mul, c1}));
 
-  HloModule module(TestName());
-  auto* computation = module.AddEntryComputation(builder.Build());
+  auto module = CreateNewModule();
+  auto* computation = module->AddEntryComputation(builder.Build());
   auto* fusion = computation->CreateFusionInstruction(
       {tuple, sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop);
 
@@ -940,8 +940,8 @@ TEST_F(HloInstructionTest, FunctionVisitor) {
       HloInstruction::CreateUnary(f32, HloOpcode::kExp, param));
   auto add = builder.AddInstruction(
       HloInstruction::CreateBinary(f32, HloOpcode::kAdd, negate, exp));
-  HloModule module(TestName());
-  module.AddEntryComputation(builder.Build());
+  auto module = CreateNewModule();
+  module->AddEntryComputation(builder.Build());
 
   int visit_num = 0;
   std::unordered_map<HloInstruction*, int> visit_order;
@@ -969,8 +969,8 @@ TEST_F(HloInstructionTest, FullyElementwise) {
       builder.AddInstruction(HloInstruction::CreateParameter(1, r1f32, "y"));
   auto add = builder.AddInstruction(
       HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, x, y));
-  HloModule module(TestName());
-  module.AddEntryComputation(builder.Build());
+  auto module = CreateNewModule();
+  module->AddEntryComputation(builder.Build());
 
   EXPECT_TRUE(add->IsElementwise());
   for (int i = 0; i < add->operand_count(); ++i) {
@@ -1013,8 +1013,8 @@ TEST_F(HloInstructionTest, PartiallyElementwise) {
   HloInstruction* max = builder.AddInstruction(
       HloInstruction::CreateBinary(r2f32, HloOpcode::kMaximum, div, broadcast));
 
-  HloModule module(TestName());
-  auto* computation = module.AddEntryComputation(builder.Build());
+  auto module = CreateNewModule();
+  auto* computation = module->AddEntryComputation(builder.Build());
   HloInstruction* fusion = computation->CreateFusionInstruction(
       {max, broadcast, div, mul}, HloInstruction::FusionKind::kLoop);
   EXPECT_FALSE(fusion->IsElementwise());
@@ -1056,8 +1056,8 @@ TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) {
   HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary(
       r1f32, HloOpcode::kSubtract, min, broadcast));
 
-  HloModule module(TestName());
-  auto* computation = module.AddEntryComputation(builder.Build());
+  auto module = CreateNewModule();
+  auto* computation = module->AddEntryComputation(builder.Build());
   HloInstruction* fusion = computation->CreateFusionInstruction(
       {sub, broadcast, min}, HloInstruction::FusionKind::kLoop);
   EXPECT_FALSE(fusion->IsElementwise());
@@ -1099,8 +1099,8 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) {
   HloInstruction* dot = builder.AddInstruction(
       HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
 
-  HloModule module(TestName());
-  auto* computation = module.AddEntryComputation(builder.Build());
+  auto module = CreateNewModule();
+  auto* computation = module->AddEntryComputation(builder.Build());
   HloInstruction* fusion = computation->CreateFusionInstruction(
       {dot, reshape}, HloInstruction::FusionKind::kTransposeDot);
 
@@ -1118,7 +1118,7 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) {
 }
 
 TEST_F(HloInstructionTest, FusionEquality) {
-  HloModule module(TestName());
+  auto module = CreateNewModule();
   HloComputation::Builder builder(TestName());
 
   // Create two fusion instructions containing a single unary operation.
@@ -1128,7 +1128,7 @@ TEST_F(HloInstructionTest, FusionEquality) {
       HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, parameter));
   auto neg = builder.AddInstruction(
       HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, parameter));
-  auto* computation = module.AddEntryComputation(builder.Build());
+  auto* computation = module->AddEntryComputation(builder.Build());
   auto* fusion = computation->CreateFusionInstruction(
       {exp}, HloInstruction::FusionKind::kLoop);
   auto* fusion2 = computation->CreateFusionInstruction(
@@ -1140,7 +1140,7 @@ TEST_F(HloInstructionTest, FusionEquality) {
 }
 
 TEST_F(HloInstructionTest, NestedFusionEquality) {
-  HloModule module(TestName());
+  auto module = CreateNewModule();
   HloComputation::Builder builder(TestName());
 
   // Build a nested fusion computation.
@@ -1166,7 +1166,7 @@ TEST_F(HloInstructionTest, NestedFusionEquality) {
       data_shape, HloOpcode::kSubtract, dot, add_operand));
   builder.AddInstruction(
       HloInstruction::CreateBinary(data_shape, HloOpcode::kMultiply, add, sub));
-  auto computation = module.AddEntryComputation(builder.Build());
+  auto computation = module->AddEntryComputation(builder.Build());
 
   auto nested_fusion = computation->CreateFusionInstruction(
       {dot, b_t}, HloInstruction::FusionKind::kTransposeDot);
@@ -1244,8 +1244,8 @@ TEST_F(HloInstructionTest, Stringification) {
             "%dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} "
             "%transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}");
 
-  HloModule module(TestName());
-  auto* computation = module.AddEntryComputation(builder.Build());
+  auto module = CreateNewModule();
+  auto* computation = module->AddEntryComputation(builder.Build());
   HloInstruction* fusion = computation->CreateFusionInstruction(
       {dot, reshape}, HloInstruction::FusionKind::kTransposeDot);
 
@@ -1295,8 +1295,8 @@ TEST_F(HloInstructionTest, StringifyGather_0) {
               /*index_vector_dim=*/4),
           /*window_bounds=*/{30, 29, 28, 27, 26}));
 
-  HloModule module(TestName());
-  module.AddEntryComputation(builder.Build());
+  auto module = CreateNewModule();
+  module->AddEntryComputation(builder.Build());
 
   EXPECT_EQ(gather_instruction->ToString(),
             "%gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} "
@@ -1331,8 +1331,8 @@ TEST_F(HloInstructionTest, StringifyGather_1) {
               /*index_vector_dim=*/2),
           /*window_bounds=*/{30, 29, 28, 27, 26}));
 
-  HloModule module(TestName());
-  module.AddEntryComputation(builder.Build());
+  auto module = CreateNewModule();
+  module->AddEntryComputation(builder.Build());
 
   EXPECT_EQ(gather_instruction->ToString(),
             "%gather = f32[10,9,7,6,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} "
index 08b9a29..d4bad16 100644 (file)
@@ -41,9 +41,6 @@ HloModule::HloModule(const string& name,
       entry_computation_handle_(entry_computation_handle),
       unique_id_(next_unique_module_id_++) {}
 
-HloModule::HloModule(const string& name)
-    : name_(NameUniquer::GetSanitizedName(name)),
-      unique_id_(next_unique_module_id_++) {}
 HloModule::HloModule(const string& name, const HloModuleConfig& config)
     : name_(NameUniquer::GetSanitizedName(name)),
       config_(config),
@@ -479,8 +476,7 @@ std::vector<HloComputation*> HloModule::MakeNonfusionComputations() const {
 
 std::unique_ptr<HloModule> HloModule::Clone(const string& suffix) const {
   VLOG(1) << "Cloning module :" << name_ << " --> " << suffix << "\n";
-  auto module = MakeUnique<HloModule>(name_ + "-" + suffix);
-  module->config_ = config_;
+  auto module = MakeUnique<HloModule>(name_ + "-" + suffix, config_);
   module->entry_computation_handle_ = entry_computation_handle_;
   module->has_entry_computation_handle_ = has_entry_computation_handle_;
 
index 9f7f252..aa843ea 100644 (file)
@@ -55,7 +55,6 @@ class HloModule {
   // only be used for HloModules used outside of the XLA service (eg
   // tests). The versioned handle is used by the service in the compilation
   // cache. A default configuration is created for this module.
-  explicit HloModule(const string& name);
   explicit HloModule(const string& name, const HloModuleConfig& config);
 
   // Adds an entry computation to the module. A module can only have one entry
index caa1a11..c7c4160 100644 (file)
@@ -71,10 +71,10 @@ TEST_F(TransposeFoldingTest, FoldDotTranspose) {
       HloInstruction::CreateDot(ShapeUtil::MakeShape(F32, {2, 2}), /*lhs=*/x,
                                 /*rhs=*/transpose_y, dot_dnums));
 
-  HloModule module("test_module");
+  auto module = CreateNewModule("test_module");
   HloComputation* entry_computation =
-      module.AddEntryComputation(builder.Build(dot));
-  FoldTranspose(&module);
+      module->AddEntryComputation(builder.Build(dot));
+  FoldTranspose(module.get());
 
   // Instructions after folding: x, y, and the fusion.
   std::unordered_set<HloInstruction*> instruction_set(
@@ -114,10 +114,10 @@ TEST_F(TransposeFoldingTest, FoldDotTransposeConstant) {
       ShapeUtil::MakeShape(F32, {1, 3}),
       /*lhs=*/transpose0, /*rhs=*/transpose1, dot_dnums));
 
-  HloModule module("test_module");
+  auto module = CreateNewModule("test_module");
   HloComputation* entry_computation =
-      module.AddEntryComputation(builder.Build(dot));
-  FoldTranspose(&module);
+      module->AddEntryComputation(builder.Build(dot));
+  FoldTranspose(module.get());
 
   for (auto* instruction : entry_computation->instructions()) {
     if (instruction->opcode() == HloOpcode::kFusion) {
@@ -149,10 +149,10 @@ TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) {
   HloInstruction* mul = builder.AddInstruction(HloInstruction::CreateBinary(
       add->shape(), HloOpcode::kMultiply, add, sub));
 
-  HloModule module("fuse_with_constant_operands");
+  auto module = CreateNewModule("fuse_with_constant_operands");
   HloComputation* entry_computation =
-      module.AddEntryComputation(builder.Build(mul));
-  HloInstruction* call = module.OutlineExpressionFromComputation(
+      module->AddEntryComputation(builder.Build(mul));
+  HloInstruction* call = module->OutlineExpressionFromComputation(
       {add, sub, mul}, "", entry_computation);
   EXPECT_EQ(call, entry_computation->root_instruction());
   HloComputation* callee_computation = call->to_apply();
@@ -182,14 +182,14 @@ TEST_F(TransposeFoldingTest, FoldDotTransposeInWhile) {
       HloInstruction::CreateDot(ShapeUtil::MakeShape(F32, {2, 2}), /*lhs=*/x,
                                 /*rhs=*/transpose_y, dot_dnums));
 
-  HloModule module("test_module");
+  auto module = CreateNewModule("test_module");
   HloComputation* entry_computation =
-      module.AddEntryComputation(builder.Build(dot));
+      module->AddEntryComputation(builder.Build(dot));
 
-  HloInstruction* call = module.OutlineExpressionFromComputation(
+  HloInstruction* call = module->OutlineExpressionFromComputation(
       {transpose_y, dot}, "outlined", entry_computation);
 
-  FoldTranspose(&module);
+  FoldTranspose(module.get());
 
   // Instructions after folding: x, y, and the fusion.
   std::unordered_set<HloInstruction*> instruction_set(
@@ -240,10 +240,10 @@ TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) {
   HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
       conv_shape.ValueOrDie(), x, transpose_y, window, dnums));
 
-  HloModule module("test_module");
+  auto module = CreateNewModule("test_module");
   HloComputation* entry_computation =
-      module.AddEntryComputation(builder.Build(conv));
-  FoldTranspose(&module);
+      module->AddEntryComputation(builder.Build(conv));
+  FoldTranspose(module.get());
 
   // Instructions after folding: x, y, and the convolution.
   std::unordered_set<HloInstruction*> instruction_set(
@@ -293,10 +293,10 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) {
   HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
       conv_shape.ValueOrDie(), x, transpose_y, window, dnums));
 
-  HloModule module("test_module");
+  auto module = CreateNewModule("test_module");
   HloComputation* entry_computation =
-      module.AddEntryComputation(builder.Build(conv));
-  FoldTranspose(&module);
+      module->AddEntryComputation(builder.Build(conv));
+  FoldTranspose(module.get());
 
   // Instructions after folding: x, y, and the convolution.
   std::unordered_set<HloInstruction*> instruction_set(
@@ -351,10 +351,10 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) {
   HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
       conv_shape.ValueOrDie(), transpose_x, y, window, dnums));
 
-  HloModule module("test_module");
+  auto module = CreateNewModule("test_module");
   HloComputation* entry_computation =
-      module.AddEntryComputation(builder.Build(conv));
-  FoldTranspose(&module);
+      module->AddEntryComputation(builder.Build(conv));
+  FoldTranspose(module.get());
 
   // Instructions after folding: x, y, and the convolution.
   std::unordered_set<HloInstruction*> instruction_set(
@@ -415,10 +415,10 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) {
   HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
       conv_shape.ValueOrDie(), transpose_x, y, window, dnums));
 
-  HloModule module("test_module");
+  auto module = CreateNewModule("test_module");
   HloComputation* entry_computation =
-      module.AddEntryComputation(builder.Build(conv));
-  FoldTranspose(&module);
+      module->AddEntryComputation(builder.Build(conv));
+  FoldTranspose(module.get());
 
   // Instructions after folding: x, y, and the convolution.
   std::unordered_set<HloInstruction*> instruction_set(
index 4f8cdc1..a4e67cc 100644 (file)
@@ -46,9 +46,9 @@ class ZeroSizedHloEliminationTest : public HloTestBase {
                 0, ShapeUtil::MakeShape(F32, {3, 0}), "zero sized param"))) {}
 
   StatusOr<bool> RunZeroSizedElimination() {
-    HloModule module("zero_sized_elimination_test_module");
-    module.AddEntryComputation(builder_.Build());
-    return ZeroSizedHloElimination{}.Run(&module);
+    auto module = CreateNewModule("zero_sized_elimination_test_module");
+    module->AddEntryComputation(builder_.Build());
+    return ZeroSizedHloElimination{}.Run(module.get());
   }
 
   HloComputation::Builder builder_;
index 9984aba..8b64f2e 100644 (file)
@@ -93,11 +93,10 @@ HloTestBase::HloTestBase(se::Platform* test_platform,
 }
 
 /* static */
-std::unique_ptr<HloModule> HloTestBase::CreateNewModule() {
+std::unique_ptr<HloModule> HloTestBase::CreateNewModule(const string& name) {
   HloModuleConfig config;
   config.set_debug_options(GetDebugOptionsForTest());
-  return MakeUnique<HloModule>(TestName(), VersionedComputationHandle(),
-                               config);
+  return MakeUnique<HloModule>(name, VersionedComputationHandle(), config);
 }
 
 /*static*/ DebugOptions HloTestBase::GetDebugOptionsForTest() {
index 79fcea9..6491208 100644 (file)
@@ -85,7 +85,8 @@ class HloTestBase : public ::testing::Test {
   // options from command-line flags. If you want a fresh HloModule object and
   // then add HloComputations to it, it's recommended to use this method in your
   // tests.
-  static std::unique_ptr<HloModule> CreateNewModule();
+  static std::unique_ptr<HloModule> CreateNewModule(
+      const string& name = TestName());
 
   // Populates debug options from command-line flags and adjusts the options for
   // testing. It is recommended to use this when you need to pass in