}
}
+std::unique_ptr<HloComputation> MakeBenchmarkWhileBody(
+ const int num_tuple_inputs) {
+ auto builder = HloComputation::Builder("benchmark_loop_body");
+ const Shape element_shape = ShapeUtil::MakeShape(F32, {});
+ std::vector<Shape> input_shape(num_tuple_inputs, element_shape);
+ const Shape loop_state_shape = ShapeUtil::MakeTupleShape(input_shape);
+ HloInstruction* param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
+ std::vector<HloInstruction*> gte_nodes(num_tuple_inputs);
+ for (int i = 0; i < num_tuple_inputs; ++i) {
+ gte_nodes[i] = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(element_shape, param, i));
+ }
+ builder.AddInstruction(HloInstruction::CreateTuple(gte_nodes));
+ return builder.Build();
+}
+
+void BM_ManyElementTuple(int num_iters, const int num_tuple_inputs) {
+ tensorflow::testing::StopTiming();
+ HloModuleConfig config;
+ config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
+ CopyInsertion copy_insertion;
+ const Shape element_shape = ShapeUtil::MakeShape(F32, {});
+ std::vector<HloInstruction*> tuple_params(num_tuple_inputs);
+ for (int i = 0; i < num_iters; ++i) {
+ auto builder = HloComputation::Builder("BM_ParallelWhiles");
+ HloModule module("BM_ManyElementTuple", VersionedComputationHandle(),
+ config);
+ for (int j = 0; j < num_tuple_inputs; ++j) {
+ tuple_params[j] = builder.AddInstruction(
+ HloInstruction::CreateParameter(j, element_shape, ""));
+ }
+ HloInstruction* init =
+ builder.AddInstruction(HloInstruction::CreateTuple(tuple_params));
+ HloComputation* condition =
+ module.AddEmbeddedComputation(MakeTrivialCondition(init->shape()));
+ HloComputation* body =
+ module.AddEmbeddedComputation(MakeBenchmarkWhileBody(num_tuple_inputs));
+ HloInstruction* xla_while = builder.AddInstruction(
+ HloInstruction::CreateWhile(init->shape(), condition, body, init));
+ builder.AddInstruction(HloInstruction::CreateGetTupleElement(
+ ShapeUtil::MakeShape(F32, {}), xla_while, 0));
+ module.AddEntryComputation(builder.Build());
+ tensorflow::testing::StartTiming();
+ ASSERT_IS_OK(copy_insertion.Run(&module).status());
+ tensorflow::testing::StopTiming();
+ }
+}
+
BENCHMARK(BM_SequentialWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096);
BENCHMARK(BM_ParallelWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096);
+BENCHMARK(BM_ManyElementTuple)->Arg(1024)->Arg(12288);
TEST_F(CopyInsertionTest, SimpleControlFlowTest) {
const string& hlo_string = R"(
void HloDataflowAnalysis::Propagate() {
std::queue<HloInstruction*> worklist;
+ tensorflow::gtl::FlatSet<HloInstruction*> workset;
+ auto add_to_worklist = [&worklist, &workset](HloInstruction* instruction) {
+ if (workset.insert(instruction).second) {
+ worklist.push(instruction);
+ }
+ };
for (HloComputation* computation : module_->computations()) {
for (HloInstruction* instruction : computation->instructions()) {
- worklist.push(instruction);
+ add_to_worklist(instruction);
}
}
while (!worklist.empty()) {
HloInstruction* instruction = worklist.front();
worklist.pop();
+ workset.erase(workset.find(instruction));
VLOG(3) << "Worklist top: " << instruction->name();
VLOG(3) << ToString();
VLOG(4) << "New value set for " << instruction->name() << ": "
<< GetInstructionValueSet(instruction);
- // Instruction value was updated. Add users to work list.
+ // Instruction value was updated. Add users to work list if we haven't
+ // already.
for (HloInstruction* user : instruction->users()) {
- worklist.push(user);
+ add_to_worklist(user);
// If user sequentially calls a computation, then the respective
// parameter(s) of the computation need to be updated.
// Note that the same instruction can be used in both operand 1 and
// operand 2.
if (user->operand(1) == instruction) {
- worklist.push(user->true_computation()->parameter_instruction(0));
+ add_to_worklist(user->true_computation()->parameter_instruction(0));
}
if (user->operand(2) == instruction) {
- worklist.push(user->false_computation()->parameter_instruction(0));
+ add_to_worklist(user->false_computation()->parameter_instruction(0));
}
} else {
for (HloComputation* called_computation : user->called_computations()) {
call_graph_->GetNode(called_computation);
if (call_graph_node.context() == CallContext::kSequential) {
for (int64 operand_number : user->OperandIndices(instruction)) {
- worklist.push(
+ add_to_worklist(
called_computation->parameter_instruction(operand_number));
}
}
for (const CallSite& callsite : call_graph_node.caller_callsites()) {
if ((callsite.instruction()->opcode() == HloOpcode::kCall) ||
(callsite.instruction()->opcode() == HloOpcode::kConditional)) {
- worklist.push(callsite.instruction());
+ add_to_worklist(callsite.instruction());
} else if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
// Add the while itself, and the body and condition parameters.
- worklist.push(callsite.instruction());
- worklist.push(
+ add_to_worklist(callsite.instruction());
+ add_to_worklist(
callsite.instruction()->while_body()->parameter_instruction(0));
- worklist.push(
+ add_to_worklist(
callsite.instruction()->while_condition()->parameter_instruction(
0));
}