[TF:XLA] Add compiler fuel to mark_for_compilation_pass.
authorJustin Lebar <jlebar@google.com>
Wed, 14 Mar 2018 10:46:37 +0000 (03:46 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 14 Mar 2018 10:51:23 +0000 (03:51 -0700)
This is a useful debugging tool.

PiperOrigin-RevId: 189007771

tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc
tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h
tensorflow/compiler/jit/mark_for_compilation_pass.cc

index 4bc209b..51384ac 100644 (file)
@@ -40,6 +40,7 @@ static void AllocateFlags() {
   flags->tf_xla_max_cluster_size = std::numeric_limits<int32>::max();
   flags->tf_xla_clustering_debug = false;
   flags->tf_xla_cpu_global_jit = false;
+  flags->tf_xla_clustering_fuel = std::numeric_limits<int64>::max();
   flag_list = new std::vector<Flag>(
       {Flag("tf_xla_auto_jit", &flags->tf_xla_auto_jit,
             "Control compilation of operators into XLA computations on CPU and "
@@ -55,7 +56,10 @@ static void AllocateFlags() {
        Flag("tf_xla_clustering_debug", &flags->tf_xla_clustering_debug,
             "Dump graphs during XLA compilation."),
        Flag("tf_xla_cpu_global_jit", &flags->tf_xla_cpu_global_jit,
-            "Enables global JIT compilation for CPU via SessionOptions.")});
+            "Enables global JIT compilation for CPU via SessionOptions."),
+       Flag("tf_xla_clustering_fuel", &flags->tf_xla_clustering_fuel,
+            "Places an artificial limit on the number of ops marked as "
+            "eligible for clustering.")});
   xla::legacy_flags::ParseFlagsFromEnv(*flag_list);
 }
 
index e1ccd7d..170b89c 100644 (file)
@@ -48,6 +48,9 @@ typedef struct {
   bool tf_xla_clustering_debug;   // Dump graphs during XLA compilation.
   bool tf_xla_cpu_global_jit;     // Enables global JIT compilation for CPU
                                   // via SessionOptions.
+  int64 tf_xla_clustering_fuel;   // "Compiler fuel" for clustering.  Only this
+                                  // many ops will be marked as eligible for
+                                  // clustering.
 } MarkForCompilationPassFlags;
 
 // Return a pointer to the MarkForCompilationPassFlags struct;
index e145a21..57fb8d2 100644 (file)
@@ -191,7 +191,27 @@ Status FindCompilationCandidates(
   FunctionLibraryRuntime* lib_runtime =
       pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
 
+  int64& fuel =
+      legacy_flags::GetMarkForCompilationPassFlags()->tf_xla_clustering_fuel;
+
+  // Iterate over nodes in sorted order so that compiler fuel is deterministic.
+  // We can't simply pass op_nodes().begin() and op_nodes().end to the
+  // std::vector constructor because they're not proper iterators, with
+  // iterator_traits defined and so on.
+  std::vector<Node*> sorted_nodes;
   for (Node* node : graph.op_nodes()) {
+    sorted_nodes.push_back(node);
+  }
+  std::sort(sorted_nodes.begin(), sorted_nodes.end(), NodeCompare());
+
+  for (Node* node : sorted_nodes) {
+    VLOG(2) << "Fuel: " << fuel;
+    if (fuel <= 0) {
+      VLOG(2)
+          << "Hit fuel limit; not marking any remaining ops as clusterable.";
+      break;
+    }
+
     VLOG(2) << "FindCompilationCandidates(): Processing "
             << node->DebugString();
 
@@ -236,7 +256,9 @@ Status FindCompilationCandidates(
       continue;
     }
     candidates->insert(node);
+    --fuel;
   }
+  VLOG(2) << "candidates->size() = " << candidates->size();
   return Status::OK();
 }