Optimize functions in the function library.
author A. Unique TensorFlower <gardener@tensorflow.org>
Thu, 26 Apr 2018 19:12:06 +0000 (12:12 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Thu, 26 Apr 2018 19:15:34 +0000 (12:15 -0700)
PiperOrigin-RevId: 194434546

tensorflow/core/common_runtime/graph_execution_state.cc
tensorflow/core/grappler/optimizers/BUILD
tensorflow/core/grappler/optimizers/function_optimizer.cc
tensorflow/core/grappler/optimizers/meta_optimizer.cc
tensorflow/core/grappler/optimizers/meta_optimizer_test.cc

index 642d91e..49b1df3 100644 (file)
@@ -76,7 +76,7 @@ GraphExecutionState::~GraphExecutionState() {
     GraphDef* graph_def, const GraphExecutionStateOptions& options,
     std::unique_ptr<GraphExecutionState>* out_state) {
 #ifndef __ANDROID__
-  VLOG(1) << "Graph proto is " << graph_def->DebugString();
+  VLOG(4) << "Graph proto is " << graph_def->DebugString();
 #endif  // __ANDROID__
 
   std::unique_ptr<GraphExecutionState> ret(
@@ -497,11 +497,24 @@ Status GraphExecutionState::OptimizeGraph(
 
     // Merge optimized graph function library with an original library.
     // Optimized graph might have new functions specialized for its
-    // instantiation context (see Grappler function optimizer).
+    // instantiation context (see Grappler function optimizer), and modified
+    // function body for the existing functions.
+    optimized_flib->reset(new FunctionLibraryDefinition(*flib_def_));
+
+    for (const FunctionDef& fdef : new_graph.library().function()) {
+      const string& func_name = fdef.signature().name();
+
+      if ((*optimized_flib)->Find(func_name)) {
+        VLOG(3) << "Replace function: name=" << func_name;
+        TF_RETURN_IF_ERROR((*optimized_flib)->RemoveFunction(func_name));
+        TF_RETURN_IF_ERROR((*optimized_flib)->AddFunctionDef(fdef));
+      } else {
+        VLOG(3) << "Add new function: name=" << func_name;
+        TF_RETURN_IF_ERROR((*optimized_flib)->AddFunctionDef(fdef));
+      }
+    }
+
     optimized_graph->reset(new Graph(OpRegistry::Global()));
-    optimized_flib->reset(new FunctionLibraryDefinition(OpRegistry::Global(),
-                                                        new_graph.library()));
-    TF_RETURN_IF_ERROR((*optimized_flib)->AddLibrary(*flib_def_));
 
     GraphConstructorOptions opts;
     opts.allow_internal_ops = true;
@@ -540,6 +553,7 @@ Status GraphExecutionState::BuildGraph(const BuildGraphOptions& options,
 
   Status s = OptimizeGraph(options, &optimized_graph, &optimized_flib);
   if (!s.ok()) {
+    VLOG(2) << "Grappler optimization failed. Error: " << s.error_message();
     // Simply copy the original graph and the function library if we couldn't
     // optimize it.
     optimized_graph.reset(new Graph(flib_def_.get()));
index ad2db68..5b5e1e0 100644 (file)
@@ -518,11 +518,13 @@ cc_library(
         ":loop_optimizer",
         ":memory_optimizer",
         ":model_pruner",
+        "//tensorflow/core:core_cpu_base",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/grappler:grappler_item",
         "//tensorflow/core/grappler/utils:colocation",
+        "//tensorflow/core/grappler/utils:functions",
         "//tensorflow/core/grappler/utils:topological_sort",
     ],
 )
@@ -539,9 +541,11 @@ tf_cuda_cc_test(
         "//tensorflow/core:tensorflow",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
         "//tensorflow/core/grappler:grappler_item",
         "//tensorflow/core/grappler:utils",
         "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
+        "//tensorflow/core/grappler/utils:grappler_test",
     ],
 )
 
index 47e7dc0..3a6de9e 100644 (file)
@@ -579,7 +579,10 @@ Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
         continue;
       }
 
-      if (specialize_func && IsParametrized(*func)) {
+      // Do not specialize if function has custom gradient.
+      const string grad_func = ctx.function_library().FindGradient(func_name);
+
+      if (specialize_func && grad_func.empty() && IsParametrized(*func)) {
         // TODO(ezhulenev): Specialize function call if input is a Const or has
         // a known shape. Const input tensors can be pushed into the function
         // body and removed from function inputs.
index c98eef1..c42d614 100644 (file)
@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
+#include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/framework/function.pb.h"
 #include "tensorflow/core/framework/versions.pb.h"
 #include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
@@ -29,6 +30,7 @@ limitations under the License.
 #include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
 #include "tensorflow/core/grappler/optimizers/model_pruner.h"
 #include "tensorflow/core/grappler/utils/colocation.h"
+#include "tensorflow/core/grappler/utils/functions.h"
 #include "tensorflow/core/grappler/utils/topological_sort.h"
 #include "tensorflow/core/lib/core/status.h"
 
@@ -235,7 +237,75 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
 Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                                GraphDef* optimized_graph) {
   optimization_results_.clear();
+
+  // 1. Optimize main graph
   TF_RETURN_IF_ERROR(OptimizeGraph(cluster, item, optimized_graph));
+
+  // 2. Optimize function library
+  FunctionLibraryDefinition flib(OpRegistry::Global(),
+                                 optimized_graph->library());
+
+  // Optimize each function only once.
+  std::unordered_set<string> optimized_funcs;
+  bool optimize_function_library = true;
+
+  while (optimize_function_library) {
+    optimize_function_library = false;
+
+    for (const FunctionDef& func : optimized_graph->library().function()) {
+      const string& func_name = func.signature().name();
+
+      // Skip already optimized functions.
+      if (optimized_funcs.find(func_name) != optimized_funcs.end()) continue;
+
+      // Skip parametrized functions (function type or body is defined only at
+      // function call time by caller node attributes).
+      if (IsParametrized(func)) continue;
+
+      VLOG(3) << "Optimize function: function=" << func_name;
+
+      // Function optimization might specialize nested function calls, so we
+      // have to reset the flag and do at least one more pass over the library.
+      optimize_function_library = true;
+      optimized_funcs.insert(func_name);
+
+      // Make a GrapplerItem from a FunctionDef.
+      GrapplerFunctionItem func_item;
+      TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(func, flib, &func_item));
+
+      // Optimize function body graph.
+      GraphDef optimized_func_graph;
+      TF_RETURN_IF_ERROR(
+          OptimizeGraph(cluster, func_item, &optimized_func_graph));
+
+      // Function body optimization might have created new specialized
+      // functions for each instantiation context. Add them to the library.
+      for (const FunctionDef& func_def :
+           optimized_func_graph.library().function()) {
+        if (flib.Find(func_def.signature().name()) == nullptr) {
+          TF_RETURN_IF_ERROR(flib.AddFunctionDef(func_def));
+        }
+      }
+
+      // Convert optimized graph back to FunctionDef.
+      FunctionDef optimized_func;
+      func_item.SwapFunctionBody(std::move(optimized_func_graph));
+      TF_RETURN_IF_ERROR(MakeFunctionDef(func_item, flib, &optimized_func));
+
+      // Replace optimized function with a new FunctionDef.
+      TF_RETURN_IF_ERROR(flib.RemoveFunction(func_name));
+      TF_RETURN_IF_ERROR(flib.AddFunctionDef(optimized_func));
+    }
+
+    // If optimized at least one function, update the graph library.
+    if (optimize_function_library) {
+      *optimized_graph->mutable_library() = flib.ToProto();
+    }
+  }
+
+  VLOG(3) << "Optimized " << optimized_funcs.size()
+          << " functions: " << str_util::Join(optimized_funcs, ", ");
+
   return Status::OK();
 }
 
index 9fcf076..887a988 100644 (file)
@@ -16,11 +16,14 @@ limitations under the License.
 #include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
 
 #include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/grappler/grappler_item.h"
 #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
 #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
 #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
 #include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/grappler_test.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/test.h"
 
@@ -28,6 +31,8 @@ namespace tensorflow {
 namespace grappler {
 namespace {
 
+constexpr char kDevice[] = "/device:CPU:0";
+
 class TestOptimizer : public CustomGraphOptimizer {
  public:
   static void SetOptimized(const bool flag_value) { optimized_ = flag_value; }
@@ -59,7 +64,9 @@ bool TestOptimizer::optimized_;
 
 REGISTER_GRAPH_OPTIMIZER(TestOptimizer);
 
-TEST(MetaOptimizerTest, RunsCustomOptimizer) {
+class MetaOptimizerTest : public GrapplerTest {};
+
+TEST_F(MetaOptimizerTest, RunsCustomOptimizer) {
   TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
   GrapplerItem item;
   CHECK(fake_input.NextItem(&item));
@@ -75,7 +82,7 @@ TEST(MetaOptimizerTest, RunsCustomOptimizer) {
   EXPECT_TRUE(TestOptimizer::IsOptimized());
 }
 
-TEST(MetaOptimizerTest, RunOptimizersTwice) {
+TEST_F(MetaOptimizerTest, RunOptimizersTwice) {
   TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
   GrapplerItem item;
   CHECK(fake_input.NextItem(&item));
@@ -89,6 +96,167 @@ TEST(MetaOptimizerTest, RunOptimizersTwice) {
   TF_EXPECT_OK(status);
 }
 
+TEST_F(MetaOptimizerTest, OptimizeFunctionLibrary) {
+  using test::function::NDef;
+
+  // Enable only function optimization.
+  RewriterConfig rewriter_config;
+  rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);
+  rewriter_config.set_function_optimization(RewriterConfig::ON);
+  rewriter_config.add_optimizers("function");
+
+  MetaOptimizer optimizer(nullptr, rewriter_config);
+
+  // Define function library:
+  //
+  //   MyMul(x, y)    = x * y
+  //  *MySquare(x)    = MyMul(x, x)
+  //  *MyQuadratic(x) = MySquare(MySquare(x))
+  //
+  //  * - marked as noinline
+
+  FunctionDef mul_func = FunctionDefHelper::Create(
+      "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, double}"},
+      {{{"mul"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
+      /* Mapping between function returns and function node outputs. */
+      {{"z", "mul:z:0"}});
+
+  FunctionDef square_func = FunctionDefHelper::Create(
+      "MySquare", {"x:T"}, {"z:T"}, {"T: {float, double}"},
+      {{{"my_mul"}, "MyMul", {"x", "x"}, {{"T", "$T"}}}},
+      /* Mapping between function returns and function node outputs. */
+      {{"z", "my_mul:z:0"}});
+  (*square_func.mutable_attr())["_noinline"].set_b(true);
+
+  FunctionDef quadratic_func = FunctionDefHelper::Create(
+      "MyQuadratic", {"x:T"}, {"z:T"}, {"T: {float, double}"},
+      {{{"square"}, "MySquare", {"x"}, {{"T", "$T"}}},
+       {{"quadratic"}, "MySquare", {"square:z"}, {{"T", "$T"}}}},
+      /* Mapping between function returns and function node outputs. */
+      {{"z", "quadratic:z:0"}});
+  (*quadratic_func.mutable_attr())["_noinline"].set_b(true);
+
+  // Tensorflow graph:
+  //
+  //   a = tf.Placeholder(tf.float);
+  //   b = tf.Placeholder(tf.int32);
+  //
+  //   square = MySquare(a);        // a^2
+  //   quadratic = MyQuadratic(b);  // b^4
+  GrapplerItem item;
+  item.graph = test::function::GDef(
+      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
+       NDef("b", "Placeholder", {}, {{"dtype", DT_INT32}}, kDevice),
+       // Calls into function library
+       NDef("square", "MySquare", {"a"}, {{"T", DT_FLOAT}}, kDevice),
+       NDef("quadratic", "MyQuadratic", {"b"}, {{"T", DT_INT32}}, kDevice),
+       // Forward outputs
+       NDef("out_s", "Identity", {"square:0"}, {{"T", DT_FLOAT}}, kDevice),
+       NDef("out_q", "Identity", {"quadratic:0"}, {{"T", DT_INT32}}, kDevice)},
+      // FunctionLib
+      {mul_func, square_func, quadratic_func});
+
+  GraphDef output;
+  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+  FunctionLibraryDefinition optimized_flib(OpRegistry::Global(),
+                                           output.library());
+
+  // Specialized and optimized functions should be added to the graph.
+  EXPECT_EQ(6, optimized_flib.num_functions());
+
+  // MyQuadratic should be specialized once:
+  //   0. 'quadratic' node in the main graph
+  const string optimized_0 = "MyQuadratic_specialized_for_quadratic";
+
+  // MySquare should be specialized and optimized for 3 instantiations:
+  //   1. 'square' node in the main graph
+  //   2. 'square' node in the MyQuadratic specialization
+  //   3. 'quadratic' node in the MyQuadratic specialization
+
+  const string optimized_1 = "MySquare_specialized_for_square";
+  const string optimized_2 = "MySquare_specialized_for_square_1";
+  const string optimized_3 = "MySquare_specialized_for_quadratic";
+
+  const FunctionDef* optimized_func_0 = optimized_flib.Find(optimized_0);
+  const FunctionDef* optimized_func_1 = optimized_flib.Find(optimized_1);
+  const FunctionDef* optimized_func_2 = optimized_flib.Find(optimized_2);
+  const FunctionDef* optimized_func_3 = optimized_flib.Find(optimized_3);
+
+  ASSERT_NE(optimized_func_0, nullptr);
+  ASSERT_NE(optimized_func_1, nullptr);
+  ASSERT_NE(optimized_func_2, nullptr);
+  ASSERT_NE(optimized_func_3, nullptr);
+
+  // Graph should call optimized function.
+  int count = 0;
+  for (const NodeDef& node : output.node()) {
+    if (node.name() == "square" && count++) {
+      EXPECT_EQ("MySquare_specialized_for_square", node.op());
+    } else if (node.name() == "quadratic" && count++) {
+      EXPECT_EQ("MyQuadratic_specialized_for_quadratic", node.op());
+    }
+  }
+  EXPECT_EQ(2, count);
+
+  // Specialized MySquare should call specialized functions.
+  count = 0;
+  for (const NodeDef& node : optimized_func_0->node_def()) {
+    if (node.name() == "square" && count++) {
+      EXPECT_EQ(optimized_2, node.op());
+    } else if (node.name() == "quadratic" && count++) {
+      EXPECT_EQ(optimized_3, node.op());
+    }
+  }
+  EXPECT_EQ(2, count);
+
+  const std::vector<const FunctionDef*> optimized_funcs = {
+      optimized_func_1, optimized_func_1, optimized_func_3};
+
+  // MyMul should be inlined into all optimized versions of MySquare.
+  for (const FunctionDef* optimized_func : optimized_funcs) {
+    count = 0;
+    for (const NodeDef& node : optimized_func->node_def()) {
+      if (node.name() == "my_mul/inlined_inputs" && count++) {
+        EXPECT_EQ("IdentityN", node.op());
+        EXPECT_EQ(2, node.input_size());
+        EXPECT_EQ("x:0", node.input(0));
+        EXPECT_EQ("x:0", node.input(1));
+      } else if (node.name() == "my_mul/x" && count++) {
+        EXPECT_EQ("Identity", node.op());
+        EXPECT_EQ(1, node.input_size());
+        EXPECT_EQ("my_mul/inlined_inputs:output:0", node.input(0));
+      } else if (node.name() == "my_mul/y" && count++) {
+        EXPECT_EQ("Identity", node.op());
+        EXPECT_EQ(1, node.input_size());
+        EXPECT_EQ("my_mul/inlined_inputs:output:1", node.input(0));
+      } else if (node.name() == "my_mul/mul" && count++) {
+        EXPECT_EQ("Mul", node.op());
+        EXPECT_EQ(2, node.input_size());
+        EXPECT_EQ("my_mul/x:output:0", node.input(0));
+        EXPECT_EQ("my_mul/y:output:0", node.input(1));
+      } else if (node.name() == "my_mul" && count++) {
+        EXPECT_EQ("IdentityN", node.op());
+        EXPECT_EQ(1, node.input_size());
+        EXPECT_EQ("my_mul/mul:z:0", node.input(0));
+      }
+      EXPECT_TRUE(node.device().empty());
+    }
+    EXPECT_EQ(5, count);
+  }
+
+  item.fetch = {"out_s", "out_q"};
+  item.feed.emplace_back("a", test::AsScalar<float>(2.0f));
+  item.feed.emplace_back("b", test::AsScalar<int>(4));
+  auto tensors_expected = EvaluateFetchNodes(item);
+
+  GrapplerItem optimized(item, std::move(output));
+  auto tensors = EvaluateFetchNodes(optimized);
+
+  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
+  test::ExpectTensorEqual<int>(tensors_expected[1], tensors[1]);
+}
+
 }  // namespace
 }  // namespace grappler
 }  // namespace tensorflow