b.AddInstruction(HloInstruction::CreateConvolve(
shape, lhs_instruction, rhs_instruction, window, dnums));
- HloModule module("ReferenceUtil");
+ HloModuleConfig config;
+ HloModule module("ReferenceUtil", config);
auto computation = module.AddEntryComputation(b.Build());
HloEvaluator evaluator;
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
"//tensorflow/core:lib",
":hlo_graph_dumper",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:xla_proto",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
"//tensorflow/core:lib",
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/window_util.h"
builder.AddInstruction(HloInstruction::CreatePad(
ShapeUtil::MakeShape(F32, {2, 2}), param, zero, no_padding));
- HloModule module(TestName());
- HloComputation* computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ HloComputation* computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), param);
}
HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
ShapeUtil::MakeShape(F32, {11, 5}), param, zero, padding));
- HloModule module(TestName());
- HloComputation* computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ HloComputation* computation = module->AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero));
EXPECT_TRUE(has_negative_padding(pad));
- ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), op::Slice(op::Pad(param, zero)));
EXPECT_FALSE(
builder.AddInstruction(
HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {2, 3}), param));
- HloModule module(TestName());
- HloComputation* computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ HloComputation* computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Reshape(param));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), param);
}
ShapeUtil::MakeShape(F32, {dim0, dim1}), param, /*start_indices=*/{0, 0},
/*limit_indices=*/{dim0, dim1}, /*strides=*/{1, 1}));
- HloModule module(TestName());
- HloComputation* computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ HloComputation* computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Slice(param));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), param);
}
b.AddInstruction(HloInstruction::CreateConvolve(out_shape, input, filter,
window, dnums));
- HloModule module(TestName());
- auto* computation = module.AddEntryComputation(b.Build());
+ auto module = CreateNewModule();
+ auto* computation = module->AddEntryComputation(b.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
bitcasting_callback());
- if (!simplifier.Run(&module).ValueOrDie()) {
+ if (!simplifier.Run(module.get()).ValueOrDie()) {
return "NO_CHANGE";
}
auto* root = computation->root_instruction();
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kMaximum, min, max_value));
- HloModule module(TestName());
- auto computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
op::Maximum(op::Minimum(param0, min_value), max_value));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Clamp(max_value, param0, min_value));
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value));
- HloModule module(TestName());
- auto computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
op::Minimum(op::Maximum(param0, max_value), min_value));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Clamp(max_value, param0, min_value));
builder.AddInstruction(
HloInstruction::CreateBinary(r1f32, HloOpcode::kMinimum, max, min_value));
- HloModule module(TestName());
- auto computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
op::Minimum(op::Maximum(param0, max_value), min_value));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Clamp(max_value, param0, min_value));
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value));
- HloModule module(TestName());
- auto computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
op::Minimum(op::Maximum(param0, max_value), min_value));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- EXPECT_FALSE(simplifier.Run(&module).ValueOrDie());
+ EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Minimum(op::Maximum(param0, max_value), min_value));
builder.AddInstruction(HloInstruction::CreateBinary(
r0f32, HloOpcode::kMinimum, fmax, min_value));
- HloModule module(TestName());
- auto computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
op::Minimum(op::Add(op::Maximum(param0, max_value), max_value),
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- EXPECT_FALSE(simplifier.Run(&module).ValueOrDie());
+ EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Minimum(op::Add(op::Maximum(param0, max_value), max_value),
HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateSlice(
slice_shape, broadcast, {0, 1, 2, 3}, {2, 3, 5, 6}, {1, 1, 1, 1}));
- HloModule module(TestName());
- auto computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root, slice);
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
// Running simplification again should not result in any further changes.
- ASSERT_FALSE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
root = computation->root_instruction();
EXPECT_THAT(root, op::Broadcast(scalar_param));
HloInstruction* reshape = builder.AddInstruction(
HloInstruction::CreateReshape(reshape_shape, transpose));
- HloModule module(TestName());
- auto computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root, reshape);
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
root = computation->root_instruction();
EXPECT_THAT(root, op::Broadcast(forty_two));
// Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x).
TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
- HloModule module(TestName());
+ auto module = CreateNewModule();
HloComputation::Builder builder(TestName());
// Create operand to the pad.
HloInstruction::CreateParameter(1, scalar_shape, "p1"));
builder.AddInstruction(
HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
- add_computation = module.AddEmbeddedComputation(builder.Build());
+ add_computation = module->AddEmbeddedComputation(builder.Build());
}
// Create the reduce-window.
add_computation));
// Build the computation and run the simplifier.
- auto computation = module.AddEntryComputation(builder.Build());
+ auto computation = module->AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root, reduce_window);
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
// Running simplification again should not result in any further changes.
- ASSERT_FALSE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
// Verify the result
root = computation->root_instruction();
// Test that ReduceWindow(Convert(Pad(op, x)), y) can simplify to
// ReduceWindow(Convert(op), x).
TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) {
- HloModule module(TestName());
+ auto module = CreateNewModule();
HloComputation::Builder builder(TestName());
// Create operand to the pad.
HloInstruction::CreateParameter(1, scalar_shape, "p1"));
builder.AddInstruction(
HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
- add_computation = module.AddEmbeddedComputation(builder.Build());
+ add_computation = module->AddEmbeddedComputation(builder.Build());
}
// Create the reduce-window.
add_computation));
// Build the computation and run the simplifier.
- auto computation = module.AddEntryComputation(builder.Build());
+ auto computation = module->AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root, reduce_window);
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
// Running simplification again should not result in any further changes.
- ASSERT_FALSE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
// Verify the result
root = computation->root_instruction();
builder.AddInstruction(
HloInstruction::CreateReverse(shape, a, /*dimensions=*/{2, 3}));
- HloModule module(TestName());
- auto computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(a, root);
}
TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) {
- auto module = xla::MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto builder = HloComputation::Builder("entry");
auto input0 = builder.AddInstruction(
};
// Build the entry computation as described in the comment above.
- auto module = xla::MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto builder = HloComputation::Builder("entry");
auto infeed = builder.AddInstruction(HloInstruction::CreateInfeed(r0s32, ""));
}
TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) {
- auto module = xla::MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto builder = HloComputation::Builder("entry");
auto input0 = builder.AddInstruction(
}
TEST_F(BufferAssignmentTest, TwoCalls) {
- auto module = xla::MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
Shape r0f32 = ShapeUtil::MakeShape(xla::F32, {});
HloComputation* sub_computation;
{
}
TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
- auto module = xla::MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
auto zero = builder.AddInstruction(
}
TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) {
- auto module = xla::MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto builder = HloComputation::Builder("entry");
auto input0 = builder.AddInstruction(
// instructions. Sets the computation as the entry to an HLO module and returns
// the module.
std::unique_ptr<HloModule> MakeBigGraph() {
- auto module = MakeUnique<HloModule>("BigGraph");
+ HloModuleConfig config;
+ auto module = MakeUnique<HloModule>("BigGraph", config);
auto builder = HloComputation::Builder("TestBigGraphvizGraph");
HeapSimulatorTracker(
const string& name, std::unique_ptr<HloComputation> computation,
const std::vector<const HloInstruction*>& instruction_sequence) {
- module_ = MakeUnique<HloModule>(name);
+ HloModuleConfig config;
+ module_ = MakeUnique<HloModule>(name, config);
module_->AddEntryComputation(std::move(computation));
points_to_analysis_ =
TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
}
explicit HeapSimulatorTracker(const string& name) {
- module_ = MakeUnique<HloModule>(name);
+ HloModuleConfig config;
+ module_ = MakeUnique<HloModule>(name, config);
}
// Similar to the single entry computation constructor above, but runs the
HloInstruction::CreateBinary(r2f32, HloOpcode::kSubtract, mul, clamp));
auto tuple = HloInstruction::CreateTuple({sub, sub, mul, c1});
- HloModule module(TestName());
- auto* computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto* computation = module->AddEntryComputation(builder.Build());
auto* fusion = computation->CreateFusionInstruction(
{sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop);
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
shape_with_layout, HloOpcode::kAdd, c1, broadcast));
- HloModule module(TestName());
- auto* computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto* computation = module->AddEntryComputation(builder.Build());
auto* fusion = computation->CreateFusionInstruction(
{add, broadcast}, HloInstruction::FusionKind::kLoop);
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
namespace {
using tensorflow::gtl::ArraySlice;
-std::unique_ptr<HloModule> CreateModuleWithProgramShape(
- PrimitiveType primitive_type, ArraySlice<int64> input_shape_dims,
- ArraySlice<int64> output_shape_dims, HloInstruction** param,
- HloComputation** entry_computation) {
- Shape input_shape = ShapeUtil::MakeShape(primitive_type, input_shape_dims);
- Shape output_shape = ShapeUtil::MakeShape(primitive_type, output_shape_dims);
- std::unique_ptr<HloModule> module = MakeUnique<HloModule>("test");
- *entry_computation = module->AddEntryComputation(
- CreateComputationWithSignature({&input_shape}, output_shape, "entry")
- .ValueOrDie());
- *param = (*entry_computation)->parameter_instruction(0);
- return module;
-}
-
-TEST(HloCreationUtilsTest, CollapseFirst1Dim) {
+class HloCreationUtilsTest : public HloTestBase {
+ protected:
+ static std::unique_ptr<HloModule> CreateModuleWithProgramShape(
+ PrimitiveType primitive_type, ArraySlice<int64> input_shape_dims,
+ ArraySlice<int64> output_shape_dims, HloInstruction** param,
+ HloComputation** entry_computation) {
+ Shape input_shape = ShapeUtil::MakeShape(primitive_type, input_shape_dims);
+ Shape output_shape =
+ ShapeUtil::MakeShape(primitive_type, output_shape_dims);
+ auto module = CreateNewModule("test");
+ *entry_computation = module->AddEntryComputation(
+ CreateComputationWithSignature({&input_shape}, output_shape, "entry")
+ .ValueOrDie());
+ *param = (*entry_computation)->parameter_instruction(0);
+ return module;
+ }
+};
+
+TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) {
HloInstruction* param;
HloComputation* entry_computation;
CHECK_EQ(*result_literal, *Literal::CreateR1<int32>({3, 4}));
}
-TEST(HloCreationUtilsTest, CollapseFirst2Dims) {
+TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) {
HloInstruction* param;
HloComputation* entry_computation;
{{1, 2}, {3, 4}, {5, 6}, {-1, -2}, {-3, -4}, {-5, -6}}));
}
-TEST(HloCreationUtilsTest, Prepend1DegenerateDim) {
+TEST_F(HloCreationUtilsTest, Prepend1DegenerateDim) {
HloInstruction* param;
HloComputation* entry_computation;
CHECK_EQ(*result_literal, *Literal::CreateR2<int32>({{9, 10}}));
}
-TEST(HloCreationUtilsTest, Prepend2DegenerateDims) {
+TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) {
HloInstruction* param;
HloComputation* entry_computation;
CHECK_EQ(*result_literal, *Literal::CreateR3<int32>({{{9, 10}}}));
}
-TEST(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) {
+TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) {
HloInstruction* param;
HloComputation* entry_computation;
CHECK_EQ(*result_literal, *Literal::CreateR2<int32>({{9}}));
}
-TEST(HloCreationUtilsTest, ExpandFirstDimInto3Dims) {
+TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) {
HloInstruction* param;
HloComputation* entry_computation;
*Literal::CreateR3<int32>({{{1, 2}}, {{3, 4}}, {{5, 6}}}));
}
-TEST(HloCreationUtilsTest, PadVectorWithZeros) {
+TEST_F(HloCreationUtilsTest, PadVectorWithZeros) {
HloInstruction* param;
HloComputation* entry_computation;
CHECK_EQ(*result_literal, *Literal::CreateR1<int32>({0, 0, 0, 3, 4, 0}));
}
-TEST(HloCreationUtilsTest, BroadcastZeros_S32) {
+TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) {
HloInstruction* param;
HloComputation* entry_computation;
CHECK_EQ(*result_literal, *Literal::CreateR2<int32>({{0, 0}, {0, 0}}));
}
-TEST(HloCreationUtilsTest, BroadcastZeros_F32) {
+TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) {
HloInstruction* param;
HloComputation* entry_computation;
}
Status HloEvaluator::HandleFusion(HloInstruction* fusion) {
+ HloModuleConfig config;
// Attach cloned computation to an empty HLO module so the existing ones are
// not modified.
- HloModule empty_hlo_module("EmptyModuleForFusion");
+ HloModule empty_hlo_module("EmptyModuleForFusion", config);
auto cloned_fused_computation =
fusion->fused_instructions_computation()->Clone(
/*suffix=*/"clone_with_layout", &empty_hlo_module);
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/core/lib/strings/strcat.h"
XLA_REGISTER_GRAPH_RENDERER(DotRenderer);
-TEST(HloGraphDumperTest, NestedFusion) {
+class HloGraphDumperTest : public HloTestBase {};
+
+TEST_F(HloGraphDumperTest, NestedFusion) {
HloComputation::Builder b("b");
// Build param0 + param1 + param2 + param3 + param4.
sums.push_back(b.AddInstruction(HloInstruction::CreateBinary(
shape, HloOpcode::kAdd, sums[i], params[i + 2])));
}
-
- HloModule m(TestName());
- m.AddEntryComputation(b.Build());
- HloComputation* root_computation = m.entry_computation();
+ auto m = CreateNewModule();
+ m->AddEntryComputation(b.Build());
+ HloComputation* root_computation = m->entry_computation();
// Fuse into fusion(param0 + param1 + param2 + param3 + param4).
auto* outer_fusion = root_computation->CreateFusionInstruction(
HasSubstr(inner_sum->name()));
}
-TEST(HloGraphDumperTest, Constant) {
+TEST_F(HloGraphDumperTest, Constant) {
HloComputation::Builder b("b");
auto instruction = b.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(-42)));
instruction->set_name("i_am_a_constant_root_instruction");
- HloModule m(TestName());
- HloComputation* root_computation = m.AddEntryComputation(b.Build());
+ auto m = CreateNewModule();
+ HloComputation* root_computation = m->AddEntryComputation(b.Build());
string graph = hlo_graph_dumper::DumpGraph(
*root_computation, /*label=*/"an_empty_graph", DebugOptions());
EXPECT_THAT(graph, HasSubstr("an_empty_graph"));
builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar"));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
EXPECT_THAT(add->operands(), UnorderedElementsAre(foo, bar));
EXPECT_THAT(foo->users(), UnorderedElementsAre(add));
HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
EXPECT_EQ(3, foo->user_count());
EXPECT_EQ(1, bar->user_count());
builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo"));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
EXPECT_EQ(1, foo->user_count());
HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c0, param1));
auto addtotal = builder.AddInstruction(
HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, addleft, addright));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
OpAndUserCollectingVisitor visitor;
ASSERT_IS_OK(addtotal->Accept(&visitor));
HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, addleft, addright));
auto neg2 = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, addtotal));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
OpAndUserCollectingVisitor visitor;
ASSERT_IS_OK(neg2->Accept(&visitor));
//
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
Shape f32a100x10 = ShapeUtil::MakeShape(F32, {100, 10});
- HloModule module(TestName());
+ auto module = CreateNewModule();
// Builds an x+1.0 computation to use in a Map.
auto embedded_builder = HloComputation::Builder("f32+1");
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
embedded_builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param, value));
- auto add_f32 = module.AddEmbeddedComputation(embedded_builder.Build());
+ auto add_f32 = module->AddEmbeddedComputation(embedded_builder.Build());
// Builds a parameter and feeds it to the map.
HloComputation::Builder builder(TestName());
HloInstruction::CreateParameter(0, f32a100x10, ""));
auto map = builder.AddInstruction(
HloInstruction::CreateMap(f32a100x10, {param0}, add_f32));
- module.AddEntryComputation(builder.Build());
+ module->AddEntryComputation(builder.Build());
OpAndUserCollectingVisitor visitor;
ASSERT_IS_OK(map->Accept(&visitor));
HloInstruction::CreateParameter(1, r0f32, "y"));
embedded_builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, paramx, paramy));
- HloModule module(TestName());
- auto add_f32 = module.AddEmbeddedComputation(embedded_builder.Build());
+ auto module = CreateNewModule();
+ auto add_f32 = module->AddEmbeddedComputation(embedded_builder.Build());
// Builds a parameter and an initial value and feeds them to the reduce.
HloComputation::Builder builder(TestName());
auto reduce = builder.AddInstruction(
HloInstruction::CreateReduce(f32v100, param0, const0,
/*dimensions_to_reduce=*/{1}, add_f32));
- module.AddEntryComputation(builder.Build());
+ module->AddEntryComputation(builder.Build());
OpAndUserCollectingVisitor visitor;
ASSERT_IS_OK(reduce->Accept(&visitor));
HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo));
builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd,
add_foobar, add_foofoo));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
EXPECT_EQ(2, foo->user_count());
EXPECT_EQ(1, bar->user_count());
builder.AddInstruction(HloInstruction::CreateTuple({foo, bar, baz, foo}));
auto add_foobar = builder.AddInstruction(
HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
EXPECT_EQ(2, foo->user_count());
EXPECT_THAT(foo->users(), UnorderedElementsAre(tuple, add_foobar));
HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo));
auto log = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kLog, foo));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
EXPECT_EQ(2, foo->user_count());
EXPECT_THAT(foo->users(), UnorderedElementsAre(exp, log));
HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo));
builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd,
add_foobar, add_foofoo));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
EXPECT_EQ(2, foo->user_count());
EXPECT_EQ(1, bar->user_count());
auto exp = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo));
auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({foo, bar}));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
EXPECT_EQ(3, foo->user_count());
EXPECT_EQ(2, bar->user_count());
HloInstruction::CreateUnary(r0f32_, HloOpcode::kLog, foo));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, exp, log));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
NodeCollectorAndPostProcessor visitor;
ASSERT_IS_OK(add->Accept(&visitor));
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
auto exp = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant));
- HloModule module(TestName());
- auto* computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto* computation = module->AddEntryComputation(builder.Build());
auto* fusion = computation->CreateFusionInstruction(
{exp}, HloInstruction::FusionKind::kLoop);
HloInstruction::CreateConstant(Literal::CreateR0<float>(42.1f)));
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
r0f32_, HloOpcode::kAdd, constant1, constant2));
- HloModule module(TestName());
- auto* computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto* computation = module->AddEntryComputation(builder.Build());
auto* fusion = computation->CreateFusionInstruction(
{add}, HloInstruction::FusionKind::kLoop);
auto exp3 = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp2));
- HloModule module(TestName());
- auto* computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto* computation = module->AddEntryComputation(builder.Build());
auto* fusion = computation->CreateFusionInstruction(
{exp3, exp2, exp1}, HloInstruction::FusionKind::kLoop);
exp1->set_metadata(metadata);
exp2->set_metadata(metadata);
- HloModule module(TestName());
- auto* computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto* computation = module->AddEntryComputation(builder.Build());
auto* fusion = computation->CreateFusionInstruction(
{exp2, exp1}, HloInstruction::FusionKind::kLoop);
TEST_F(HloInstructionTest, FusionOpWithCalledComputations) {
// Create a fusion instruction containing a single unary operation.
const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
- HloModule module(TestName());
+ auto module = CreateNewModule();
auto make_map_computation = [&]() {
auto builder = HloComputation::Builder("FusionMap");
builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape, "param"));
- return module.AddEmbeddedComputation(builder.Build());
+ return module->AddEmbeddedComputation(builder.Build());
};
HloComputation* computation_x = make_map_computation();
scalar_shape, {map_1_x}, computation_x, /*static_operands=*/{}));
auto map_3_y = builder.AddInstruction(HloInstruction::CreateMap(
scalar_shape, {map_2_x}, computation_y, /*static_operands=*/{}));
- auto* computation = module.AddEntryComputation(builder.Build());
+ auto* computation = module->AddEntryComputation(builder.Build());
auto* fusion = computation->CreateFusionInstruction(
{map_3_y}, HloInstruction::FusionKind::kLoop);
auto tuple =
builder.AddInstruction(HloInstruction::CreateTuple({sub, sub, mul, c1}));
- HloModule module(TestName());
- auto* computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto* computation = module->AddEntryComputation(builder.Build());
auto* fusion = computation->CreateFusionInstruction(
{tuple, sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop);
HloInstruction::CreateUnary(f32, HloOpcode::kExp, param));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(f32, HloOpcode::kAdd, negate, exp));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
int visit_num = 0;
std::unordered_map<HloInstruction*, int> visit_order;
builder.AddInstruction(HloInstruction::CreateParameter(1, r1f32, "y"));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, x, y));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
EXPECT_TRUE(add->IsElementwise());
for (int i = 0; i < add->operand_count(); ++i) {
HloInstruction* max = builder.AddInstruction(
HloInstruction::CreateBinary(r2f32, HloOpcode::kMaximum, div, broadcast));
- HloModule module(TestName());
- auto* computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto* computation = module->AddEntryComputation(builder.Build());
HloInstruction* fusion = computation->CreateFusionInstruction(
{max, broadcast, div, mul}, HloInstruction::FusionKind::kLoop);
EXPECT_FALSE(fusion->IsElementwise());
HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary(
r1f32, HloOpcode::kSubtract, min, broadcast));
- HloModule module(TestName());
- auto* computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto* computation = module->AddEntryComputation(builder.Build());
HloInstruction* fusion = computation->CreateFusionInstruction(
{sub, broadcast, min}, HloInstruction::FusionKind::kLoop);
EXPECT_FALSE(fusion->IsElementwise());
HloInstruction* dot = builder.AddInstruction(
HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
- HloModule module(TestName());
- auto* computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto* computation = module->AddEntryComputation(builder.Build());
HloInstruction* fusion = computation->CreateFusionInstruction(
{dot, reshape}, HloInstruction::FusionKind::kTransposeDot);
}
TEST_F(HloInstructionTest, FusionEquality) {
- HloModule module(TestName());
+ auto module = CreateNewModule();
HloComputation::Builder builder(TestName());
// Create two fusion instructions containing a single unary operation.
HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, parameter));
auto neg = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, parameter));
- auto* computation = module.AddEntryComputation(builder.Build());
+ auto* computation = module->AddEntryComputation(builder.Build());
auto* fusion = computation->CreateFusionInstruction(
{exp}, HloInstruction::FusionKind::kLoop);
auto* fusion2 = computation->CreateFusionInstruction(
}
TEST_F(HloInstructionTest, NestedFusionEquality) {
- HloModule module(TestName());
+ auto module = CreateNewModule();
HloComputation::Builder builder(TestName());
// Build a nested fusion computation.
data_shape, HloOpcode::kSubtract, dot, add_operand));
builder.AddInstruction(
HloInstruction::CreateBinary(data_shape, HloOpcode::kMultiply, add, sub));
- auto computation = module.AddEntryComputation(builder.Build());
+ auto computation = module->AddEntryComputation(builder.Build());
auto nested_fusion = computation->CreateFusionInstruction(
{dot, b_t}, HloInstruction::FusionKind::kTransposeDot);
"%dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} "
"%transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}");
- HloModule module(TestName());
- auto* computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto* computation = module->AddEntryComputation(builder.Build());
HloInstruction* fusion = computation->CreateFusionInstruction(
{dot, reshape}, HloInstruction::FusionKind::kTransposeDot);
/*index_vector_dim=*/4),
/*window_bounds=*/{30, 29, 28, 27, 26}));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
EXPECT_EQ(gather_instruction->ToString(),
"%gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} "
/*index_vector_dim=*/2),
/*window_bounds=*/{30, 29, 28, 27, 26}));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
EXPECT_EQ(gather_instruction->ToString(),
"%gather = f32[10,9,7,6,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} "
entry_computation_handle_(entry_computation_handle),
unique_id_(next_unique_module_id_++) {}
-HloModule::HloModule(const string& name)
- : name_(NameUniquer::GetSanitizedName(name)),
- unique_id_(next_unique_module_id_++) {}
HloModule::HloModule(const string& name, const HloModuleConfig& config)
: name_(NameUniquer::GetSanitizedName(name)),
config_(config),
std::unique_ptr<HloModule> HloModule::Clone(const string& suffix) const {
VLOG(1) << "Cloning module :" << name_ << " --> " << suffix << "\n";
- auto module = MakeUnique<HloModule>(name_ + "-" + suffix);
- module->config_ = config_;
+ auto module = MakeUnique<HloModule>(name_ + "-" + suffix, config_);
module->entry_computation_handle_ = entry_computation_handle_;
module->has_entry_computation_handle_ = has_entry_computation_handle_;
// only be used for HloModules used outside of the XLA service (eg
// tests). The versioned handle is used by the service in the compilation
// cache. A default configuration is created for this module.
- explicit HloModule(const string& name);
explicit HloModule(const string& name, const HloModuleConfig& config);
// Adds an entry computation to the module. A module can only have one entry
HloInstruction::CreateDot(ShapeUtil::MakeShape(F32, {2, 2}), /*lhs=*/x,
/*rhs=*/transpose_y, dot_dnums));
- HloModule module("test_module");
+ auto module = CreateNewModule("test_module");
HloComputation* entry_computation =
- module.AddEntryComputation(builder.Build(dot));
- FoldTranspose(&module);
+ module->AddEntryComputation(builder.Build(dot));
+ FoldTranspose(module.get());
// Instructions after folding: x, y, and the fusion.
std::unordered_set<HloInstruction*> instruction_set(
ShapeUtil::MakeShape(F32, {1, 3}),
/*lhs=*/transpose0, /*rhs=*/transpose1, dot_dnums));
- HloModule module("test_module");
+ auto module = CreateNewModule("test_module");
HloComputation* entry_computation =
- module.AddEntryComputation(builder.Build(dot));
- FoldTranspose(&module);
+ module->AddEntryComputation(builder.Build(dot));
+ FoldTranspose(module.get());
for (auto* instruction : entry_computation->instructions()) {
if (instruction->opcode() == HloOpcode::kFusion) {
HloInstruction* mul = builder.AddInstruction(HloInstruction::CreateBinary(
add->shape(), HloOpcode::kMultiply, add, sub));
- HloModule module("fuse_with_constant_operands");
+ auto module = CreateNewModule("fuse_with_constant_operands");
HloComputation* entry_computation =
- module.AddEntryComputation(builder.Build(mul));
- HloInstruction* call = module.OutlineExpressionFromComputation(
+ module->AddEntryComputation(builder.Build(mul));
+ HloInstruction* call = module->OutlineExpressionFromComputation(
{add, sub, mul}, "", entry_computation);
EXPECT_EQ(call, entry_computation->root_instruction());
HloComputation* callee_computation = call->to_apply();
HloInstruction::CreateDot(ShapeUtil::MakeShape(F32, {2, 2}), /*lhs=*/x,
/*rhs=*/transpose_y, dot_dnums));
- HloModule module("test_module");
+ auto module = CreateNewModule("test_module");
HloComputation* entry_computation =
- module.AddEntryComputation(builder.Build(dot));
+ module->AddEntryComputation(builder.Build(dot));
- HloInstruction* call = module.OutlineExpressionFromComputation(
+ HloInstruction* call = module->OutlineExpressionFromComputation(
{transpose_y, dot}, "outlined", entry_computation);
- FoldTranspose(&module);
+ FoldTranspose(module.get());
// Instructions after folding: x, y, and the fusion.
std::unordered_set<HloInstruction*> instruction_set(
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
conv_shape.ValueOrDie(), x, transpose_y, window, dnums));
- HloModule module("test_module");
+ auto module = CreateNewModule("test_module");
HloComputation* entry_computation =
- module.AddEntryComputation(builder.Build(conv));
- FoldTranspose(&module);
+ module->AddEntryComputation(builder.Build(conv));
+ FoldTranspose(module.get());
// Instructions after folding: x, y, and the convolution.
std::unordered_set<HloInstruction*> instruction_set(
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
conv_shape.ValueOrDie(), x, transpose_y, window, dnums));
- HloModule module("test_module");
+ auto module = CreateNewModule("test_module");
HloComputation* entry_computation =
- module.AddEntryComputation(builder.Build(conv));
- FoldTranspose(&module);
+ module->AddEntryComputation(builder.Build(conv));
+ FoldTranspose(module.get());
// Instructions after folding: x, y, and the convolution.
std::unordered_set<HloInstruction*> instruction_set(
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
conv_shape.ValueOrDie(), transpose_x, y, window, dnums));
- HloModule module("test_module");
+ auto module = CreateNewModule("test_module");
HloComputation* entry_computation =
- module.AddEntryComputation(builder.Build(conv));
- FoldTranspose(&module);
+ module->AddEntryComputation(builder.Build(conv));
+ FoldTranspose(module.get());
// Instructions after folding: x, y, and the convolution.
std::unordered_set<HloInstruction*> instruction_set(
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
conv_shape.ValueOrDie(), transpose_x, y, window, dnums));
- HloModule module("test_module");
+ auto module = CreateNewModule("test_module");
HloComputation* entry_computation =
- module.AddEntryComputation(builder.Build(conv));
- FoldTranspose(&module);
+ module->AddEntryComputation(builder.Build(conv));
+ FoldTranspose(module.get());
// Instructions after folding: x, y, and the convolution.
std::unordered_set<HloInstruction*> instruction_set(
0, ShapeUtil::MakeShape(F32, {3, 0}), "zero sized param"))) {}
StatusOr<bool> RunZeroSizedElimination() {
- HloModule module("zero_sized_elimination_test_module");
- module.AddEntryComputation(builder_.Build());
- return ZeroSizedHloElimination{}.Run(&module);
+ auto module = CreateNewModule("zero_sized_elimination_test_module");
+ module->AddEntryComputation(builder_.Build());
+ return ZeroSizedHloElimination{}.Run(module.get());
}
HloComputation::Builder builder_;
}
/* static */
-std::unique_ptr<HloModule> HloTestBase::CreateNewModule() {
+std::unique_ptr<HloModule> HloTestBase::CreateNewModule(const string& name) {
HloModuleConfig config;
config.set_debug_options(GetDebugOptionsForTest());
- return MakeUnique<HloModule>(TestName(), VersionedComputationHandle(),
- config);
+ return MakeUnique<HloModule>(name, VersionedComputationHandle(), config);
}
/*static*/ DebugOptions HloTestBase::GetDebugOptionsForTest() {
// options from command-line flags. If you want a fresh HloModule object and
// then add HloComputations to it, it's recommended to use this method in your
// tests.
- static std::unique_ptr<HloModule> CreateNewModule();
+ static std::unique_ptr<HloModule> CreateNewModule(
+ const string& name = TestName());
// Populates debug options from command-line flags and adjusts the options for
// testing. It is recommended to use this when you need to pass in