From 13ae129449cdeb7afbad98bc8a00ad5c82a0ca31 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Tue, 20 Mar 2018 13:34:02 -0700 Subject: [PATCH] Improved the performance of the function optimizer. PiperOrigin-RevId: 189799697 --- .../core/grappler/optimizers/function_optimizer.cc | 78 +++++++++++++++------- 1 file changed, 54 insertions(+), 24 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.cc b/tensorflow/core/grappler/optimizers/function_optimizer.cc index 97effae..2a6b8a3 100644 --- a/tensorflow/core/grappler/optimizers/function_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/function_optimizer.cc @@ -140,19 +140,53 @@ class FakeCPUDevice : public Device { Status Sync() override { return Status::OK(); } }; -Status InlineSymbolicGradient(const NodeDef& node, - const FunctionDefLibrary& library, - GraphDef* inlined_graph) { - Env* env = Env::Default(); - DeviceAttributes attr; - attr.set_name("/device:CPU:0"); - attr.set_device_type("CPU"); - FakeCPUDevice* dev = new FakeCPUDevice(env, attr); - std::vector devices; - devices.push_back(dev); - DeviceMgr dvc_mgr(devices); - FunctionLibraryDefinition function_library(OpRegistry::Global(), library); +class SymbolicGradientEnv { + public: + SymbolicGradientEnv(int graph_version, const FunctionDefLibrary& library) + : graph_version_(graph_version), library_(library) {} + + FunctionLibraryDefinition* function_library() { + InitializeIfNeeded(); + return fld_.get(); + } + FunctionLibraryRuntime* function_library_runtime() { + InitializeIfNeeded(); + return flr_; + } + + private: + // This initialization is expensive. Do it lazily to avoid paying for it + // unless it's needed. + void InitializeIfNeeded() { + if (flr_) { + return; + } + Env* env = Env::Default(); + DeviceAttributes attr; + attr.set_name("/device:CPU:0"); + attr.set_device_type("CPU"); + FakeCPUDevice* dev = new FakeCPUDevice(env, attr); + std::vector devices; + devices.push_back(dev); + dvc_mgr_.reset(new DeviceMgr(devices)); + fld_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), library_)); + OptimizerOptions optimizer_opts; + optimizer_opts.set_do_function_inlining(true); + pflr_.reset(new ProcessFunctionLibraryRuntime( + dvc_mgr_.get(), env, graph_version_, fld_.get(), optimizer_opts)); + flr_ = pflr_->GetFLR(dev->name()); + } + + const int graph_version_; + const FunctionDefLibrary& library_; + std::unique_ptr dvc_mgr_; + std::unique_ptr fld_; + std::unique_ptr pflr_; + FunctionLibraryRuntime* flr_ = nullptr; +}; +Status InlineSymbolicGradient(const NodeDef& node, SymbolicGradientEnv* env, + GraphDef* inlined_graph) { GraphDef graph_def; // Create a node to anchor the gradient inputs @@ -186,24 +220,18 @@ Status InlineSymbolicGradient(const NodeDef& node, } // Convert the graphdef to a graph - OptimizerOptions optimizer_opts; - optimizer_opts.set_do_function_inlining(true); - ProcessFunctionLibraryRuntime pflr(&dvc_mgr, env, - inlined_graph->versions().producer(), - &function_library, optimizer_opts); - FunctionLibraryRuntime* flr = pflr.GetFLR(dev->name()); - CHECK(flr); GraphConstructorOptions graph_ctor_opts; graph_ctor_opts.allow_internal_ops = true; graph_ctor_opts.expect_device_spec = false; - Graph graph(function_library); + Graph graph(env->function_library()); TF_RETURN_IF_ERROR( ConvertGraphDefToGraph(graph_ctor_opts, graph_def, &graph)); // Recursively inline the functions until there is nothing more to inline. We // should at least expand one function. int counter = 0; - while (counter < 50 && ExpandInlineFunctions(flr, &graph)) { + while (counter < 50 && + ExpandInlineFunctions(env->function_library_runtime(), &graph)) { ++counter; } @@ -279,11 +307,12 @@ Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, return Status::OK(); } - *optimized_graph->mutable_versions() = item.graph.versions(); + SymbolicGradientEnv env(item.graph.versions().producer(), + item.graph.library()); + for (const NodeDef& node : item.graph.node()) { if (node.op() == "SymbolicGradient") { - TF_RETURN_IF_ERROR( - InlineSymbolicGradient(node, item.graph.library(), optimized_graph)); + TF_RETURN_IF_ERROR(InlineSymbolicGradient(node, &env, optimized_graph)); continue; } auto it = functions.find(node.op()); @@ -299,6 +328,7 @@ Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, // inlined based on the context in which they're instantiated. // TODO(bsteiner): trim the library to remove unused function definitions + *optimized_graph->mutable_versions() = item.graph.versions(); *optimized_graph->mutable_library() = item.graph.library(); return Status::OK(); -- 2.7.4