Improved the performance of the function optimizer by lazily constructing the function-library runtime (DeviceMgr, FunctionLibraryDefinition, ProcessFunctionLibraryRuntime) once per graph in a new SymbolicGradientEnv, instead of rebuilding it on every InlineSymbolicGradient call — and skipping the setup entirely when the graph contains no SymbolicGradient nodes.
authorBenoit Steiner <bsteiner@google.com>
Tue, 20 Mar 2018 20:34:02 +0000 (13:34 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 20 Mar 2018 20:38:35 +0000 (13:38 -0700)
PiperOrigin-RevId: 189799697

tensorflow/core/grappler/optimizers/function_optimizer.cc

index 97effae..2a6b8a3 100644 (file)
@@ -140,19 +140,53 @@ class FakeCPUDevice : public Device {
   Status Sync() override { return Status::OK(); }
 };
 
-Status InlineSymbolicGradient(const NodeDef& node,
-                              const FunctionDefLibrary& library,
-                              GraphDef* inlined_graph) {
-  Env* env = Env::Default();
-  DeviceAttributes attr;
-  attr.set_name("/device:CPU:0");
-  attr.set_device_type("CPU");
-  FakeCPUDevice* dev = new FakeCPUDevice(env, attr);
-  std::vector<Device*> devices;
-  devices.push_back(dev);
-  DeviceMgr dvc_mgr(devices);
-  FunctionLibraryDefinition function_library(OpRegistry::Global(), library);
+class SymbolicGradientEnv {
+ public:
+  SymbolicGradientEnv(int graph_version, const FunctionDefLibrary& library)
+      : graph_version_(graph_version), library_(library) {}
+
+  FunctionLibraryDefinition* function_library() {
+    InitializeIfNeeded();
+    return fld_.get();
+  }
+  FunctionLibraryRuntime* function_library_runtime() {
+    InitializeIfNeeded();
+    return flr_;
+  }
+
+ private:
+  // This initialization is expensive. Do it lazily to avoid paying for it
+  // unless it's needed.
+  void InitializeIfNeeded() {
+    if (flr_) {
+      return;
+    }
+    Env* env = Env::Default();
+    DeviceAttributes attr;
+    attr.set_name("/device:CPU:0");
+    attr.set_device_type("CPU");
+    FakeCPUDevice* dev = new FakeCPUDevice(env, attr);
+    std::vector<Device*> devices;
+    devices.push_back(dev);
+    dvc_mgr_.reset(new DeviceMgr(devices));
+    fld_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), library_));
+    OptimizerOptions optimizer_opts;
+    optimizer_opts.set_do_function_inlining(true);
+    pflr_.reset(new ProcessFunctionLibraryRuntime(
+        dvc_mgr_.get(), env, graph_version_, fld_.get(), optimizer_opts));
+    flr_ = pflr_->GetFLR(dev->name());
+  }
+
+  const int graph_version_;
+  const FunctionDefLibrary& library_;
+  std::unique_ptr<DeviceMgr> dvc_mgr_;
+  std::unique_ptr<FunctionLibraryDefinition> fld_;
+  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
+  FunctionLibraryRuntime* flr_ = nullptr;
+};
 
+Status InlineSymbolicGradient(const NodeDef& node, SymbolicGradientEnv* env,
+                              GraphDef* inlined_graph) {
   GraphDef graph_def;
 
   // Create a node to anchor the gradient inputs
@@ -186,24 +220,18 @@ Status InlineSymbolicGradient(const NodeDef& node,
   }
 
   // Convert the graphdef to a graph
-  OptimizerOptions optimizer_opts;
-  optimizer_opts.set_do_function_inlining(true);
-  ProcessFunctionLibraryRuntime pflr(&dvc_mgr, env,
-                                     inlined_graph->versions().producer(),
-                                     &function_library, optimizer_opts);
-  FunctionLibraryRuntime* flr = pflr.GetFLR(dev->name());
-  CHECK(flr);
   GraphConstructorOptions graph_ctor_opts;
   graph_ctor_opts.allow_internal_ops = true;
   graph_ctor_opts.expect_device_spec = false;
-  Graph graph(function_library);
+  Graph graph(env->function_library());
   TF_RETURN_IF_ERROR(
       ConvertGraphDefToGraph(graph_ctor_opts, graph_def, &graph));
 
   // Recursively inline the functions until there is nothing more to inline. We
   // should at least expand one function.
   int counter = 0;
-  while (counter < 50 && ExpandInlineFunctions(flr, &graph)) {
+  while (counter < 50 &&
+         ExpandInlineFunctions(env->function_library_runtime(), &graph)) {
     ++counter;
   }
 
@@ -279,11 +307,12 @@ Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
     return Status::OK();
   }
 
-  *optimized_graph->mutable_versions() = item.graph.versions();
+  SymbolicGradientEnv env(item.graph.versions().producer(),
+                          item.graph.library());
+
   for (const NodeDef& node : item.graph.node()) {
     if (node.op() == "SymbolicGradient") {
-      TF_RETURN_IF_ERROR(
-          InlineSymbolicGradient(node, item.graph.library(), optimized_graph));
+      TF_RETURN_IF_ERROR(InlineSymbolicGradient(node, &env, optimized_graph));
       continue;
     }
     auto it = functions.find(node.op());
@@ -299,6 +328,7 @@ Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
   // inlined based on the context in which they're instantiated.
 
   // TODO(bsteiner): trim the library to remove unused function definitions
+  *optimized_graph->mutable_versions() = item.graph.versions();
   *optimized_graph->mutable_library() = item.graph.library();
 
   return Status::OK();