flags->tf_xla_max_cluster_size = std::numeric_limits<int32>::max();
flags->tf_xla_clustering_debug = false;
flags->tf_xla_cpu_global_jit = false;
+ flags->tf_xla_clustering_fuel = std::numeric_limits<int64>::max();
flag_list = new std::vector<Flag>(
{Flag("tf_xla_auto_jit", &flags->tf_xla_auto_jit,
"Control compilation of operators into XLA computations on CPU and "
Flag("tf_xla_clustering_debug", &flags->tf_xla_clustering_debug,
"Dump graphs during XLA compilation."),
Flag("tf_xla_cpu_global_jit", &flags->tf_xla_cpu_global_jit,
- "Enables global JIT compilation for CPU via SessionOptions.")});
+ "Enables global JIT compilation for CPU via SessionOptions."),
+ Flag("tf_xla_clustering_fuel", &flags->tf_xla_clustering_fuel,
+ "Places an artificial limit on the number of ops marked as "
+ "eligible for clustering.")});
xla::legacy_flags::ParseFlagsFromEnv(*flag_list);
}
bool tf_xla_clustering_debug; // Dump graphs during XLA compilation.
bool tf_xla_cpu_global_jit; // Enables global JIT compilation for CPU
// via SessionOptions.
+ int64 tf_xla_clustering_fuel; // "Compiler fuel" for clustering. Only this
+ // many ops will be marked as eligible for
+ // clustering.
} MarkForCompilationPassFlags;
// Return a pointer to the MarkForCompilationPassFlags struct;
FunctionLibraryRuntime* lib_runtime =
pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
+ int64& fuel =
+ legacy_flags::GetMarkForCompilationPassFlags()->tf_xla_clustering_fuel;
+
+ // Iterate over nodes in sorted order so that compiler fuel is deterministic.
+ // We can't simply pass op_nodes().begin() and op_nodes().end to the
+ // std::vector constructor because they're not proper iterators, with
+ // iterator_traits defined and so on.
+ std::vector<Node*> sorted_nodes;
for (Node* node : graph.op_nodes()) {
+ sorted_nodes.push_back(node);
+ }
+ std::sort(sorted_nodes.begin(), sorted_nodes.end(), NodeCompare());
+
+ for (Node* node : sorted_nodes) {
+ VLOG(2) << "Fuel: " << fuel;
+ if (fuel <= 0) {
+ VLOG(2)
+ << "Hit fuel limit; not marking any remaining ops as clusterable.";
+ break;
+ }
+
VLOG(2) << "FindCompilationCandidates(): Processing "
<< node->DebugString();
continue;
}
candidates->insert(node);
+ --fuel;
}
+ VLOG(2) << "candidates->size() = " << candidates->size();
return Status::OK();
}