From d02745e20c02ba7506a920cc4c8b00415f82ee79 Mon Sep 17 00:00:00 2001 From: Dimitris Vardoulakis Date: Sat, 28 Apr 2018 22:19:22 -0700 Subject: [PATCH] [TF:XLA] - Require a module config when creating an HloModule. - All tests using HloTestBase create a module using CreateNewModule. PiperOrigin-RevId: 194684585 --- tensorflow/compiler/xla/reference_util.cc | 3 +- tensorflow/compiler/xla/service/BUILD | 2 + .../xla/service/algebraic_simplifier_test.cc | 101 ++++++++--------- .../compiler/xla/service/buffer_assignment_test.cc | 12 +- .../compiler/xla/service/graphviz_example.cc | 3 +- .../compiler/xla/service/heap_simulator_test.cc | 6 +- .../compiler/xla/service/hlo_cost_analysis_test.cc | 8 +- .../xla/service/hlo_creation_utils_test.cc | 51 +++++---- tensorflow/compiler/xla/service/hlo_evaluator.cc | 3 +- .../compiler/xla/service/hlo_graph_dumper_test.cc | 18 +-- .../compiler/xla/service/hlo_instruction_test.cc | 122 ++++++++++----------- tensorflow/compiler/xla/service/hlo_module.cc | 6 +- tensorflow/compiler/xla/service/hlo_module.h | 1 - .../compiler/xla/service/transpose_folding_test.cc | 50 ++++----- .../xla/service/zero_sized_hlo_elimination_test.cc | 6 +- tensorflow/compiler/xla/tests/hlo_test_base.cc | 5 +- tensorflow/compiler/xla/tests/hlo_test_base.h | 3 +- 17 files changed, 205 insertions(+), 195 deletions(-) diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index df9dbc5..c289c84 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -572,7 +572,8 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, window, dnums)); - HloModule module("ReferenceUtil"); + HloModuleConfig config; + HloModule module("ReferenceUtil", config); auto computation = module.AddEntryComputation(b.Build()); HloEvaluator evaluator; diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index f39bfb8..ed0da47 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -1330,6 +1330,7 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", @@ -2420,6 +2421,7 @@ tf_cc_test( ":hlo_graph_dumper", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:xla_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 20c5495..d0c99bf 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/window_util.h" @@ -1699,14 +1700,14 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) { builder.AddInstruction(HloInstruction::CreatePad( ShapeUtil::MakeShape(F32, {2, 2}), param, zero, no_padding)); - HloModule module(TestName()); - HloComputation* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -1732,8 +1733,8 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) { HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad( ShapeUtil::MakeShape(F32, {11, 5}), param, zero, padding)); - HloModule module(TestName()); - HloComputation* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); @@ -1751,7 +1752,7 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) { EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero)); EXPECT_TRUE(has_negative_padding(pad)); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Slice(op::Pad(param, zero))); EXPECT_FALSE( @@ -1766,14 +1767,14 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) { builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {2, 3}), param)); - HloModule module(TestName()); - HloComputation* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(param)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -1789,14 +1790,14 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) { ShapeUtil::MakeShape(F32, {dim0, dim1}), param, /*start_indices=*/{0, 0}, /*limit_indices=*/{dim0, dim1}, /*strides=*/{1, 1})); - HloModule module(TestName()); - HloComputation* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Slice(param)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -1924,12 +1925,12 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { b.AddInstruction(HloInstruction::CreateConvolve(out_shape, input, filter, window, dnums)); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(b.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(b.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, bitcasting_callback()); - if (!simplifier.Run(&module).ValueOrDie()) { + if (!simplifier.Run(module.get()).ValueOrDie()) { return "NO_CHANGE"; } auto* root = computation->root_instruction(); @@ -2044,15 +2045,15 @@ TEST_F(AlgebraicSimplifierTest, MaxMinToClamp) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kMaximum, min, max_value)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Maximum(op::Minimum(param0, min_value), max_value)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Clamp(max_value, param0, min_value)); @@ -2074,15 +2075,15 @@ TEST_F(AlgebraicSimplifierTest, MinMaxToClamp) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Minimum(op::Maximum(param0, max_value), min_value)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Clamp(max_value, param0, min_value)); @@ -2105,15 +2106,15 @@ TEST_F(AlgebraicSimplifierTest, MinMaxWithBroadcastToClamp) { builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kMinimum, max, min_value)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Minimum(op::Maximum(param0, max_value), min_value)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Clamp(max_value, param0, min_value)); @@ -2135,15 +2136,15 @@ TEST_F(AlgebraicSimplifierTest, MinMaxNotToClamp) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Minimum(op::Maximum(param0, max_value), min_value)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(&module).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Minimum(op::Maximum(param0, max_value), min_value)); @@ -2167,8 +2168,8 @@ TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) { builder.AddInstruction(HloInstruction::CreateBinary( r0f32, HloOpcode::kMinimum, fmax, min_value)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Minimum(op::Add(op::Maximum(param0, max_value), max_value), @@ -2176,7 +2177,7 @@ TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(&module).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Minimum(op::Add(op::Maximum(param0, max_value), max_value), @@ -2201,8 +2202,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateSlice( slice_shape, broadcast, {0, 1, 2, 3}, {2, 3, 5, 6}, {1, 1, 1, 1})); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, slice); @@ -2211,10 +2212,10 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); // Running simplification again should not result in any further changes. - ASSERT_FALSE(simplifier.Run(&module).ValueOrDie()); + ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast(scalar_param)); @@ -2242,8 +2243,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { HloInstruction* reshape = builder.AddInstruction( HloInstruction::CreateReshape(reshape_shape, transpose)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, reshape); @@ -2251,7 +2252,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast(forty_two)); @@ -2260,7 +2261,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { // Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x). TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { - HloModule module(TestName()); + auto module = CreateNewModule(); HloComputation::Builder builder(TestName()); // Create operand to the pad. @@ -2289,7 +2290,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { HloInstruction::CreateParameter(1, scalar_shape, "p1")); builder.AddInstruction( HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); - add_computation = module.AddEmbeddedComputation(builder.Build()); + add_computation = module->AddEmbeddedComputation(builder.Build()); } // Create the reduce-window. @@ -2312,15 +2313,15 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { add_computation)); // Build the computation and run the simplifier. - auto computation = module.AddEntryComputation(builder.Build()); + auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, reduce_window); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); // Running simplification again should not result in any further changes. - ASSERT_FALSE(simplifier.Run(&module).ValueOrDie()); + ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); // Verify the result root = computation->root_instruction(); @@ -2341,7 +2342,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { // Test that ReduceWindow(Convert(Pad(op, x)), y) can simplify to // ReduceWindow(Convert(op), x). TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { - HloModule module(TestName()); + auto module = CreateNewModule(); HloComputation::Builder builder(TestName()); // Create operand to the pad. @@ -2374,7 +2375,7 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { HloInstruction::CreateParameter(1, scalar_shape, "p1")); builder.AddInstruction( HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); - add_computation = module.AddEmbeddedComputation(builder.Build()); + add_computation = module->AddEmbeddedComputation(builder.Build()); } // Create the reduce-window. @@ -2397,15 +2398,15 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { add_computation)); // Build the computation and run the simplifier. - auto computation = module.AddEntryComputation(builder.Build()); + auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, reduce_window); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); // Running simplification again should not result in any further changes. - ASSERT_FALSE(simplifier.Run(&module).ValueOrDie()); + ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); // Verify the result root = computation->root_instruction(); @@ -2431,12 +2432,12 @@ TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) { builder.AddInstruction( HloInstruction::CreateReverse(shape, a, /*dimensions=*/{2, 3})); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(a, root); diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 513a878..3ec9795 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -1641,7 +1641,7 @@ static void RunCopyInsertion(HloModule* module) { } TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) { - auto module = xla::MakeUnique(TestName()); + auto module = CreateNewModule(); auto builder = HloComputation::Builder("entry"); auto input0 = builder.AddInstruction( @@ -1816,7 +1816,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { }; // Build the entry computation as described in the comment above. - auto module = xla::MakeUnique(TestName()); + auto module = CreateNewModule(); auto builder = HloComputation::Builder("entry"); auto infeed = builder.AddInstruction(HloInstruction::CreateInfeed(r0s32, "")); @@ -1884,7 +1884,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { } TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { - auto module = xla::MakeUnique(TestName()); + auto module = CreateNewModule(); auto builder = HloComputation::Builder("entry"); auto input0 = builder.AddInstruction( @@ -1929,7 +1929,7 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { } TEST_F(BufferAssignmentTest, TwoCalls) { - auto module = xla::MakeUnique(TestName()); + auto module = CreateNewModule(); Shape r0f32 = ShapeUtil::MakeShape(xla::F32, {}); HloComputation* sub_computation; { @@ -1994,7 +1994,7 @@ static bool IsPostOrderTraversal( } TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { - auto module = xla::MakeUnique(TestName()); + auto module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); auto zero = builder.AddInstruction( @@ -2073,7 +2073,7 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { } TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) { - auto module = xla::MakeUnique(TestName()); + auto module = CreateNewModule(); auto builder = HloComputation::Builder("entry"); auto input0 = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc index 0501700..acf6611 100644 --- a/tensorflow/compiler/xla/service/graphviz_example.cc +++ b/tensorflow/compiler/xla/service/graphviz_example.cc @@ -82,7 +82,8 @@ HloComputation* CallForwardingComputation(HloComputation* computation, // instructions. Sets the computation as the entry to an HLO module and returns // the module. std::unique_ptr MakeBigGraph() { - auto module = MakeUnique("BigGraph"); + HloModuleConfig config; + auto module = MakeUnique("BigGraph", config); auto builder = HloComputation::Builder("TestBigGraphvizGraph"); diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 688a271..e983fd1 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -76,7 +76,8 @@ class HeapSimulatorTracker { HeapSimulatorTracker( const string& name, std::unique_ptr computation, const std::vector& instruction_sequence) { - module_ = MakeUnique(name); + HloModuleConfig config; + module_ = MakeUnique(name, config); module_->AddEntryComputation(std::move(computation)); points_to_analysis_ = TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); @@ -94,7 +95,8 @@ class HeapSimulatorTracker { } explicit HeapSimulatorTracker(const string& name) { - module_ = MakeUnique(name); + HloModuleConfig config; + module_ = MakeUnique(name, config); } // Similar to the single entry computation constructor above, but runs the diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index 3d055b3..81cc7c4 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -370,8 +370,8 @@ TEST_F(FusionCostAnalysis, LoopFusion) { HloInstruction::CreateBinary(r2f32, HloOpcode::kSubtract, mul, clamp)); auto tuple = HloInstruction::CreateTuple({sub, sub, mul, c1}); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop); @@ -412,8 +412,8 @@ TEST_F(FusionCostAnalysis, NoLayout) { auto add = builder.AddInstruction(HloInstruction::CreateBinary( shape_with_layout, HloOpcode::kAdd, c1, broadcast)); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {add, broadcast}, HloInstruction::FusionKind::kLoop); diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc index 6b681a5..7e7c4f9 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc @@ -19,27 +19,32 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/platform/test.h" namespace xla { namespace { using tensorflow::gtl::ArraySlice; -std::unique_ptr CreateModuleWithProgramShape( - PrimitiveType primitive_type, ArraySlice input_shape_dims, - ArraySlice output_shape_dims, HloInstruction** param, - HloComputation** entry_computation) { - Shape input_shape = ShapeUtil::MakeShape(primitive_type, input_shape_dims); - Shape output_shape = ShapeUtil::MakeShape(primitive_type, output_shape_dims); - std::unique_ptr module = MakeUnique("test"); - *entry_computation = module->AddEntryComputation( - CreateComputationWithSignature({&input_shape}, output_shape, "entry") - .ValueOrDie()); - *param = (*entry_computation)->parameter_instruction(0); - return module; -} - -TEST(HloCreationUtilsTest, CollapseFirst1Dim) { +class HloCreationUtilsTest : public HloTestBase { + protected: + static std::unique_ptr CreateModuleWithProgramShape( + PrimitiveType primitive_type, ArraySlice input_shape_dims, + ArraySlice output_shape_dims, HloInstruction** param, + HloComputation** entry_computation) { + Shape input_shape = ShapeUtil::MakeShape(primitive_type, input_shape_dims); + Shape output_shape = + ShapeUtil::MakeShape(primitive_type, output_shape_dims); + auto module = CreateNewModule("test"); + *entry_computation = module->AddEntryComputation( + CreateComputationWithSignature({&input_shape}, output_shape, "entry") + .ValueOrDie()); + *param = (*entry_computation)->parameter_instruction(0); + return module; + } +}; + +TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) { HloInstruction* param; HloComputation* entry_computation; @@ -59,7 +64,7 @@ TEST(HloCreationUtilsTest, CollapseFirst1Dim) { CHECK_EQ(*result_literal, *Literal::CreateR1({3, 4})); } -TEST(HloCreationUtilsTest, CollapseFirst2Dims) { +TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) { HloInstruction* param; HloComputation* entry_computation; @@ -84,7 +89,7 @@ TEST(HloCreationUtilsTest, CollapseFirst2Dims) { {{1, 2}, {3, 4}, {5, 6}, {-1, -2}, {-3, -4}, {-5, -6}})); } -TEST(HloCreationUtilsTest, Prepend1DegenerateDim) { +TEST_F(HloCreationUtilsTest, Prepend1DegenerateDim) { HloInstruction* param; HloComputation* entry_computation; @@ -104,7 +109,7 @@ TEST(HloCreationUtilsTest, Prepend1DegenerateDim) { CHECK_EQ(*result_literal, *Literal::CreateR2({{9, 10}})); } -TEST(HloCreationUtilsTest, Prepend2DegenerateDims) { +TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) { HloInstruction* param; HloComputation* entry_computation; @@ -124,7 +129,7 @@ TEST(HloCreationUtilsTest, Prepend2DegenerateDims) { CHECK_EQ(*result_literal, *Literal::CreateR3({{{9, 10}}})); } -TEST(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) { +TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) { HloInstruction* param; HloComputation* entry_computation; @@ -144,7 +149,7 @@ TEST(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) { CHECK_EQ(*result_literal, *Literal::CreateR2({{9}})); } -TEST(HloCreationUtilsTest, ExpandFirstDimInto3Dims) { +TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) { HloInstruction* param; HloComputation* entry_computation; @@ -166,7 +171,7 @@ TEST(HloCreationUtilsTest, ExpandFirstDimInto3Dims) { *Literal::CreateR3({{{1, 2}}, {{3, 4}}, {{5, 6}}})); } -TEST(HloCreationUtilsTest, PadVectorWithZeros) { +TEST_F(HloCreationUtilsTest, PadVectorWithZeros) { HloInstruction* param; HloComputation* entry_computation; @@ -187,7 +192,7 @@ TEST(HloCreationUtilsTest, PadVectorWithZeros) { CHECK_EQ(*result_literal, *Literal::CreateR1({0, 0, 0, 3, 4, 0})); } -TEST(HloCreationUtilsTest, BroadcastZeros_S32) { +TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) { HloInstruction* param; HloComputation* entry_computation; @@ -208,7 +213,7 @@ TEST(HloCreationUtilsTest, BroadcastZeros_S32) { CHECK_EQ(*result_literal, *Literal::CreateR2({{0, 0}, {0, 0}})); } -TEST(HloCreationUtilsTest, BroadcastZeros_F32) { +TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) { HloInstruction* param; HloComputation* entry_computation; diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index f1dcef1..8cf9412 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -2968,9 +2968,10 @@ Status HloEvaluator::HandleCall(HloInstruction* call) { } Status HloEvaluator::HandleFusion(HloInstruction* fusion) { + HloModuleConfig config; // Attach cloned computation to an empty HLO module so the existing ones are // not modified. - HloModule empty_hlo_module("EmptyModuleForFusion"); + HloModule empty_hlo_module("EmptyModuleForFusion", config); auto cloned_fused_computation = fusion->fused_instructions_computation()->Clone( /*suffix=*/"clone_with_layout", &empty_hlo_module); diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc index 1f00aa4..b589cd5 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -47,7 +48,9 @@ class DotRenderer : public hlo_graph_dumper::GraphRendererInterface { XLA_REGISTER_GRAPH_RENDERER(DotRenderer); -TEST(HloGraphDumperTest, NestedFusion) { +class HloGraphDumperTest : public HloTestBase {}; + +TEST_F(HloGraphDumperTest, NestedFusion) { HloComputation::Builder b("b"); // Build param0 + param1 + param2 + param3 + param4. @@ -64,10 +67,9 @@ TEST(HloGraphDumperTest, NestedFusion) { sums.push_back(b.AddInstruction(HloInstruction::CreateBinary( shape, HloOpcode::kAdd, sums[i], params[i + 2]))); } - - HloModule m(TestName()); - m.AddEntryComputation(b.Build()); - HloComputation* root_computation = m.entry_computation(); + auto m = CreateNewModule(); + m->AddEntryComputation(b.Build()); + HloComputation* root_computation = m->entry_computation(); // Fuse into fusion(param0 + param1 + param2 + param3 + param4). auto* outer_fusion = root_computation->CreateFusionInstruction( @@ -117,13 +119,13 @@ TEST(HloGraphDumperTest, NestedFusion) { HasSubstr(inner_sum->name())); } -TEST(HloGraphDumperTest, Constant) { +TEST_F(HloGraphDumperTest, Constant) { HloComputation::Builder b("b"); auto instruction = b.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(-42))); instruction->set_name("i_am_a_constant_root_instruction"); - HloModule m(TestName()); - HloComputation* root_computation = m.AddEntryComputation(b.Build()); + auto m = CreateNewModule(); + HloComputation* root_computation = m->AddEntryComputation(b.Build()); string graph = hlo_graph_dumper::DumpGraph( *root_computation, /*label=*/"an_empty_graph", DebugOptions()); EXPECT_THAT(graph, HasSubstr("an_empty_graph")); diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index f2980d3..5b65b11 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -149,8 +149,8 @@ TEST_F(HloInstructionTest, UserWithTwoOperands) { builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar")); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); EXPECT_THAT(add->operands(), UnorderedElementsAre(foo, bar)); EXPECT_THAT(foo->users(), UnorderedElementsAre(add)); @@ -186,8 +186,8 @@ TEST_F(HloInstructionTest, MultipleUsers) { HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); EXPECT_EQ(3, foo->user_count()); EXPECT_EQ(1, bar->user_count()); @@ -219,8 +219,8 @@ TEST_F(HloInstructionTest, RepeatedUser) { builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo")); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); EXPECT_EQ(1, foo->user_count()); @@ -254,8 +254,8 @@ TEST_F(HloInstructionTest, MultipleUsersAndOperands) { HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c0, param1)); auto addtotal = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, addleft, addright)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); OpAndUserCollectingVisitor visitor; ASSERT_IS_OK(addtotal->Accept(&visitor)); @@ -303,8 +303,8 @@ TEST_F(HloInstructionTest, MultipleUsersAndOperandsWithUnaryOps) { HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, addleft, addright)); auto neg2 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, addtotal)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); OpAndUserCollectingVisitor visitor; ASSERT_IS_OK(neg2->Accept(&visitor)); @@ -325,7 +325,7 @@ TEST_F(HloInstructionTest, TrivialMap) { // Shape r0f32 = ShapeUtil::MakeShape(F32, {}); Shape f32a100x10 = ShapeUtil::MakeShape(F32, {100, 10}); - HloModule module(TestName()); + auto module = CreateNewModule(); // Builds an x+1.0 computation to use in a Map. auto embedded_builder = HloComputation::Builder("f32+1"); @@ -335,7 +335,7 @@ TEST_F(HloInstructionTest, TrivialMap) { HloInstruction::CreateConstant(Literal::CreateR0(1.0))); embedded_builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param, value)); - auto add_f32 = module.AddEmbeddedComputation(embedded_builder.Build()); + auto add_f32 = module->AddEmbeddedComputation(embedded_builder.Build()); // Builds a parameter and feeds it to the map. HloComputation::Builder builder(TestName()); @@ -343,7 +343,7 @@ TEST_F(HloInstructionTest, TrivialMap) { HloInstruction::CreateParameter(0, f32a100x10, "")); auto map = builder.AddInstruction( HloInstruction::CreateMap(f32a100x10, {param0}, add_f32)); - module.AddEntryComputation(builder.Build()); + module->AddEntryComputation(builder.Build()); OpAndUserCollectingVisitor visitor; ASSERT_IS_OK(map->Accept(&visitor)); @@ -373,8 +373,8 @@ TEST_F(HloInstructionTest, TrivialReduce) { HloInstruction::CreateParameter(1, r0f32, "y")); embedded_builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, paramx, paramy)); - HloModule module(TestName()); - auto add_f32 = module.AddEmbeddedComputation(embedded_builder.Build()); + auto module = CreateNewModule(); + auto add_f32 = module->AddEmbeddedComputation(embedded_builder.Build()); // Builds a parameter and an initial value and feeds them to the reduce. HloComputation::Builder builder(TestName()); @@ -387,7 +387,7 @@ TEST_F(HloInstructionTest, TrivialReduce) { auto reduce = builder.AddInstruction( HloInstruction::CreateReduce(f32v100, param0, const0, /*dimensions_to_reduce=*/{1}, add_f32)); - module.AddEntryComputation(builder.Build()); + module->AddEntryComputation(builder.Build()); OpAndUserCollectingVisitor visitor; ASSERT_IS_OK(reduce->Accept(&visitor)); @@ -414,8 +414,8 @@ TEST_F(HloInstructionTest, ReplaceUseInBinaryOps) { HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo)); builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, add_foobar, add_foofoo)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, foo->user_count()); EXPECT_EQ(1, bar->user_count()); @@ -449,8 +449,8 @@ TEST_F(HloInstructionTest, ReplaceUseInVariadicOp) { builder.AddInstruction(HloInstruction::CreateTuple({foo, bar, baz, foo})); auto add_foobar = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, foo->user_count()); EXPECT_THAT(foo->users(), UnorderedElementsAre(tuple, add_foobar)); @@ -477,8 +477,8 @@ TEST_F(HloInstructionTest, ReplaceUseInUnaryOp) { HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo)); auto log = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kLog, foo)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, foo->user_count()); EXPECT_THAT(foo->users(), UnorderedElementsAre(exp, log)); @@ -514,8 +514,8 @@ TEST_F(HloInstructionTest, ReplaceAllUsesWithInBinaryOps) { HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo)); builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, add_foobar, add_foofoo)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, foo->user_count()); EXPECT_EQ(1, bar->user_count()); @@ -544,8 +544,8 @@ TEST_F(HloInstructionTest, ReplaceAllUsesInMultipleOps) { auto exp = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo)); auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({foo, bar})); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); EXPECT_EQ(3, foo->user_count()); EXPECT_EQ(2, bar->user_count()); @@ -609,8 +609,8 @@ TEST_F(HloInstructionTest, PostProcessAllVisitedNodes) { HloInstruction::CreateUnary(r0f32_, HloOpcode::kLog, foo)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, exp, log)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); NodeCollectorAndPostProcessor visitor; ASSERT_IS_OK(add->Accept(&visitor)); @@ -627,8 +627,8 @@ TEST_F(HloInstructionTest, SingletonFusionOp) { HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); auto exp = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant)); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {exp}, HloInstruction::FusionKind::kLoop); @@ -645,8 +645,8 @@ TEST_F(HloInstructionTest, BinaryFusionOp) { HloInstruction::CreateConstant(Literal::CreateR0(42.1f))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( r0f32_, HloOpcode::kAdd, constant1, constant2)); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {add}, HloInstruction::FusionKind::kLoop); @@ -667,8 +667,8 @@ TEST_F(HloInstructionTest, ChainFusionOp) { auto exp3 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp2)); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {exp3, exp2, exp1}, HloInstruction::FusionKind::kLoop); @@ -690,8 +690,8 @@ TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) { exp1->set_metadata(metadata); exp2->set_metadata(metadata); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {exp2, exp1}, HloInstruction::FusionKind::kLoop); @@ -746,13 +746,13 @@ TEST_F(HloInstructionTest, PreserveTupleShapeThroughClone) { TEST_F(HloInstructionTest, FusionOpWithCalledComputations) { // Create a fusion instruction containing a single unary operation. const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); - HloModule module(TestName()); + auto module = CreateNewModule(); auto make_map_computation = [&]() { auto builder = HloComputation::Builder("FusionMap"); builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape, "param")); - return module.AddEmbeddedComputation(builder.Build()); + return module->AddEmbeddedComputation(builder.Build()); }; HloComputation* computation_x = make_map_computation(); @@ -767,7 +767,7 @@ TEST_F(HloInstructionTest, FusionOpWithCalledComputations) { scalar_shape, {map_1_x}, computation_x, /*static_operands=*/{})); auto map_3_y = builder.AddInstruction(HloInstruction::CreateMap( scalar_shape, {map_2_x}, computation_y, /*static_operands=*/{})); - auto* computation = module.AddEntryComputation(builder.Build()); + auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {map_3_y}, HloInstruction::FusionKind::kLoop); @@ -814,8 +814,8 @@ TEST_F(HloInstructionTest, ComplexFusionOp) { auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({sub, sub, mul, c1})); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {tuple, sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop); @@ -940,8 +940,8 @@ TEST_F(HloInstructionTest, FunctionVisitor) { HloInstruction::CreateUnary(f32, HloOpcode::kExp, param)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32, HloOpcode::kAdd, negate, exp)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); int visit_num = 0; std::unordered_map visit_order; @@ -969,8 +969,8 @@ TEST_F(HloInstructionTest, FullyElementwise) { builder.AddInstruction(HloInstruction::CreateParameter(1, r1f32, "y")); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, x, y)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); EXPECT_TRUE(add->IsElementwise()); for (int i = 0; i < add->operand_count(); ++i) { @@ -1013,8 +1013,8 @@ TEST_F(HloInstructionTest, PartiallyElementwise) { HloInstruction* max = builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kMaximum, div, broadcast)); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); HloInstruction* fusion = computation->CreateFusionInstruction( {max, broadcast, div, mul}, HloInstruction::FusionKind::kLoop); EXPECT_FALSE(fusion->IsElementwise()); @@ -1056,8 +1056,8 @@ TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) { HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary( r1f32, HloOpcode::kSubtract, min, broadcast)); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); HloInstruction* fusion = computation->CreateFusionInstruction( {sub, broadcast, min}, HloInstruction::FusionKind::kLoop); EXPECT_FALSE(fusion->IsElementwise()); @@ -1099,8 +1099,8 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { HloInstruction* dot = builder.AddInstruction( HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); HloInstruction* fusion = computation->CreateFusionInstruction( {dot, reshape}, HloInstruction::FusionKind::kTransposeDot); @@ -1118,7 +1118,7 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { } TEST_F(HloInstructionTest, FusionEquality) { - HloModule module(TestName()); + auto module = CreateNewModule(); HloComputation::Builder builder(TestName()); // Create two fusion instructions containing a single unary operation. @@ -1128,7 +1128,7 @@ TEST_F(HloInstructionTest, FusionEquality) { HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, parameter)); auto neg = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, parameter)); - auto* computation = module.AddEntryComputation(builder.Build()); + auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {exp}, HloInstruction::FusionKind::kLoop); auto* fusion2 = computation->CreateFusionInstruction( @@ -1140,7 +1140,7 @@ TEST_F(HloInstructionTest, FusionEquality) { } TEST_F(HloInstructionTest, NestedFusionEquality) { - HloModule module(TestName()); + auto module = CreateNewModule(); HloComputation::Builder builder(TestName()); // Build a nested fusion computation. @@ -1166,7 +1166,7 @@ TEST_F(HloInstructionTest, NestedFusionEquality) { data_shape, HloOpcode::kSubtract, dot, add_operand)); builder.AddInstruction( HloInstruction::CreateBinary(data_shape, HloOpcode::kMultiply, add, sub)); - auto computation = module.AddEntryComputation(builder.Build()); + auto computation = module->AddEntryComputation(builder.Build()); auto nested_fusion = computation->CreateFusionInstruction( {dot, b_t}, HloInstruction::FusionKind::kTransposeDot); @@ -1244,8 +1244,8 @@ TEST_F(HloInstructionTest, Stringification) { "%dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} " "%transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}"); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); HloInstruction* fusion = computation->CreateFusionInstruction( {dot, reshape}, HloInstruction::FusionKind::kTransposeDot); @@ -1295,8 +1295,8 @@ TEST_F(HloInstructionTest, StringifyGather_0) { /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 28, 27, 26})); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); EXPECT_EQ(gather_instruction->ToString(), "%gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} " @@ -1331,8 +1331,8 @@ TEST_F(HloInstructionTest, StringifyGather_1) { /*index_vector_dim=*/2), /*window_bounds=*/{30, 29, 28, 27, 26})); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); EXPECT_EQ(gather_instruction->ToString(), "%gather = f32[10,9,7,6,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} " diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 08b9a29..d4bad16 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -41,9 +41,6 @@ HloModule::HloModule(const string& name, entry_computation_handle_(entry_computation_handle), unique_id_(next_unique_module_id_++) {} -HloModule::HloModule(const string& name) - : name_(NameUniquer::GetSanitizedName(name)), - unique_id_(next_unique_module_id_++) {} HloModule::HloModule(const string& name, const HloModuleConfig& config) : name_(NameUniquer::GetSanitizedName(name)), config_(config), @@ -479,8 +476,7 @@ std::vector HloModule::MakeNonfusionComputations() const { std::unique_ptr HloModule::Clone(const string& suffix) const { VLOG(1) << "Cloning module :" << name_ << " --> " << suffix << "\n"; - auto module = MakeUnique(name_ + "-" + suffix); - module->config_ = config_; + auto module = MakeUnique(name_ + "-" + suffix, config_); module->entry_computation_handle_ = entry_computation_handle_; module->has_entry_computation_handle_ = has_entry_computation_handle_; diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 9f7f252..aa843ea 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -55,7 +55,6 @@ class HloModule { // only be used for HloModules used outside of the XLA service (eg // tests). The versioned handle is used by the service in the compilation // cache. A default configuration is created for this module. - explicit HloModule(const string& name); explicit HloModule(const string& name, const HloModuleConfig& config); // Adds an entry computation to the module. A module can only have one entry diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index caa1a11..c7c4160 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -71,10 +71,10 @@ TEST_F(TransposeFoldingTest, FoldDotTranspose) { HloInstruction::CreateDot(ShapeUtil::MakeShape(F32, {2, 2}), /*lhs=*/x, /*rhs=*/transpose_y, dot_dnums)); - HloModule module("test_module"); + auto module = CreateNewModule("test_module"); HloComputation* entry_computation = - module.AddEntryComputation(builder.Build(dot)); - FoldTranspose(&module); + module->AddEntryComputation(builder.Build(dot)); + FoldTranspose(module.get()); // Instructions after folding: x, y, and the fusion. std::unordered_set instruction_set( @@ -114,10 +114,10 @@ TEST_F(TransposeFoldingTest, FoldDotTransposeConstant) { ShapeUtil::MakeShape(F32, {1, 3}), /*lhs=*/transpose0, /*rhs=*/transpose1, dot_dnums)); - HloModule module("test_module"); + auto module = CreateNewModule("test_module"); HloComputation* entry_computation = - module.AddEntryComputation(builder.Build(dot)); - FoldTranspose(&module); + module->AddEntryComputation(builder.Build(dot)); + FoldTranspose(module.get()); for (auto* instruction : entry_computation->instructions()) { if (instruction->opcode() == HloOpcode::kFusion) { @@ -149,10 +149,10 @@ TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) { HloInstruction* mul = builder.AddInstruction(HloInstruction::CreateBinary( add->shape(), HloOpcode::kMultiply, add, sub)); - HloModule module("fuse_with_constant_operands"); + auto module = CreateNewModule("fuse_with_constant_operands"); HloComputation* entry_computation = - module.AddEntryComputation(builder.Build(mul)); - HloInstruction* call = module.OutlineExpressionFromComputation( + module->AddEntryComputation(builder.Build(mul)); + HloInstruction* call = module->OutlineExpressionFromComputation( {add, sub, mul}, "", entry_computation); EXPECT_EQ(call, entry_computation->root_instruction()); HloComputation* callee_computation = call->to_apply(); @@ -182,14 +182,14 @@ TEST_F(TransposeFoldingTest, FoldDotTransposeInWhile) { HloInstruction::CreateDot(ShapeUtil::MakeShape(F32, {2, 2}), /*lhs=*/x, /*rhs=*/transpose_y, dot_dnums)); - HloModule module("test_module"); + auto module = CreateNewModule("test_module"); HloComputation* entry_computation = - module.AddEntryComputation(builder.Build(dot)); + module->AddEntryComputation(builder.Build(dot)); - HloInstruction* call = module.OutlineExpressionFromComputation( + HloInstruction* call = module->OutlineExpressionFromComputation( {transpose_y, dot}, "outlined", entry_computation); - FoldTranspose(&module); + FoldTranspose(module.get()); // Instructions after folding: x, y, and the fusion. std::unordered_set instruction_set( @@ -240,10 +240,10 @@ TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) { HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( conv_shape.ValueOrDie(), x, transpose_y, window, dnums)); - HloModule module("test_module"); + auto module = CreateNewModule("test_module"); HloComputation* entry_computation = - module.AddEntryComputation(builder.Build(conv)); - FoldTranspose(&module); + module->AddEntryComputation(builder.Build(conv)); + FoldTranspose(module.get()); // Instructions after folding: x, y, and the convolution. std::unordered_set instruction_set( @@ -293,10 +293,10 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) { HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( conv_shape.ValueOrDie(), x, transpose_y, window, dnums)); - HloModule module("test_module"); + auto module = CreateNewModule("test_module"); HloComputation* entry_computation = - module.AddEntryComputation(builder.Build(conv)); - FoldTranspose(&module); + module->AddEntryComputation(builder.Build(conv)); + FoldTranspose(module.get()); // Instructions after folding: x, y, and the convolution. std::unordered_set instruction_set( @@ -351,10 +351,10 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) { HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( conv_shape.ValueOrDie(), transpose_x, y, window, dnums)); - HloModule module("test_module"); + auto module = CreateNewModule("test_module"); HloComputation* entry_computation = - module.AddEntryComputation(builder.Build(conv)); - FoldTranspose(&module); + module->AddEntryComputation(builder.Build(conv)); + FoldTranspose(module.get()); // Instructions after folding: x, y, and the convolution. std::unordered_set instruction_set( @@ -415,10 +415,10 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) { HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( conv_shape.ValueOrDie(), transpose_x, y, window, dnums)); - HloModule module("test_module"); + auto module = CreateNewModule("test_module"); HloComputation* entry_computation = - module.AddEntryComputation(builder.Build(conv)); - FoldTranspose(&module); + module->AddEntryComputation(builder.Build(conv)); + FoldTranspose(module.get()); // Instructions after folding: x, y, and the convolution. std::unordered_set instruction_set( diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc index 4f8cdc1..a4e67cc 100644 --- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc @@ -46,9 +46,9 @@ class ZeroSizedHloEliminationTest : public HloTestBase { 0, ShapeUtil::MakeShape(F32, {3, 0}), "zero sized param"))) {} StatusOr RunZeroSizedElimination() { - HloModule module("zero_sized_elimination_test_module"); - module.AddEntryComputation(builder_.Build()); - return ZeroSizedHloElimination{}.Run(&module); + auto module = CreateNewModule("zero_sized_elimination_test_module"); + module->AddEntryComputation(builder_.Build()); + return ZeroSizedHloElimination{}.Run(module.get()); } HloComputation::Builder builder_; diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 9984aba..8b64f2e 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -93,11 +93,10 @@ HloTestBase::HloTestBase(se::Platform* test_platform, } /* static */ -std::unique_ptr HloTestBase::CreateNewModule() { +std::unique_ptr HloTestBase::CreateNewModule(const string& name) { HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); - return MakeUnique(TestName(), VersionedComputationHandle(), - config); + return MakeUnique(name, VersionedComputationHandle(), config); } /*static*/ DebugOptions HloTestBase::GetDebugOptionsForTest() { diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 79fcea9..6491208 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -85,7 +85,8 @@ class HloTestBase : public ::testing::Test { // options from command-line flags. If you want a fresh HloModule object and // then add HloComputations to it, it's recommended to use this method in your // tests. - static std::unique_ptr CreateNewModule(); + static std::unique_ptr CreateNewModule( + const string& name = TestName()); // Populates debug options from command-line flags and adjusts the options for // testing. It is recommended to use this when you need to pass in -- 2.7.4