Inline nested function calls.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 4 Apr 2018 21:31:29 +0000 (14:31 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 4 Apr 2018 21:37:16 +0000 (14:37 -0700)
PiperOrigin-RevId: 191647899

tensorflow/core/grappler/optimizers/function_optimizer.cc
tensorflow/core/grappler/optimizers/function_optimizer_test.cc

index 2a6b8a3..f1da469 100644 (file)
@@ -32,16 +32,129 @@ limitations under the License.
 
 namespace tensorflow {
 namespace grappler {
+namespace {
+
+class FunctionInliningContext {
+ public:
+  explicit FunctionInliningContext(const GrapplerItem& item)
+      : library_(&item.graph.library()), functions_(InliningCandidates(item)) {}
+
+  const FunctionDefLibrary& Library() const { return *library_; }
+
+  bool HasInlinedFunctions() const { return !functions_.empty(); }
+
+  // Find inlining candidate by name. Return nullptr if not found.
+  const FunctionDef* FindInlinedFunction(const string& name) const {
+    auto it = functions_.find(name);
+    if (it != functions_.end()) {
+      return it->second;
+    } else {
+      return nullptr;
+    }
+  }
+
+ private:
+  std::unordered_map<string, const FunctionDef*> InliningCandidates(
+      const GrapplerItem& item) const {
+    std::unordered_map<string, const FunctionDef*> functions;
+    for (const FunctionDef& func : item.graph.library().function()) {
+      // Don't inline functions marked as noinline
+      if (func.attr().count("_noinline") != 0) {
+        continue;
+      }
+      // Don't touch anything marked XLA to prevent XLA failures further down
+      // the road.
+      if (func.attr().count("_XlaCompile") > 0 &&
+          func.attr().at("_XlaCompile").b()) {
+        continue;
+      }
+      // Can't create IdentityN nodes with no input or output: skip these
+      // functions for now.
+      if (func.signature().input_arg_size() == 0 ||
+          func.signature().output_arg_size() == 0) {
+        continue;
+      }
+      functions[func.signature().name()] = &func;
+    }
+    return functions;
+  }
+
+  const FunctionDefLibrary* library_;
+  std::unordered_map<string, const FunctionDef*> functions_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(FunctionInliningContext);
+};
+
+// Copy input/output argument type to the type_list. Return error if argument
+// type is not explicitly defined, and not specified in function attributes.
+Status CopyArgType(const NodeDef& func_node,
+                   const std::unordered_map<string, AttrValue>& func_attr,
+                   const string& arg_kind, const OpDef::ArgDef& arg,
+                   AttrValue::ListValue* type_list) {
+  if (arg.type() != DT_INVALID) {
+    type_list->add_type(arg.type());
+  } else {
+    auto it = func_attr.find(arg.type_attr());
+    if (it == func_attr.end() || it->second.type() == DT_INVALID) {
+      return errors::InvalidArgument(
+          "Invalid ", arg_kind, " argument ", arg.name(), " for function ",
+          func_node.op(), " instantiated by ", func_node.name());
+    }
+    type_list->add_type(it->second.type());
+  }
+  return Status::OK();
+}
+
+// Add an IdentityN op to hook the function inputs to: this ensures that
+// they're all evaluated before the evaluation of the function body starts.
+Status HookInlinedFunctionInputs(
+    const NodeDef& func_node, const FunctionDef& func,
+    const std::unordered_map<string, AttrValue>& func_attr, NodeDef* inputs) {
+  inputs->set_name(strings::StrCat(func_node.name(), "/", "inlined_inputs"));
+  inputs->set_op("IdentityN");
+  inputs->set_device(func_node.device());
+  *inputs->mutable_input() = func_node.input();
+  AttrValue::ListValue* type_list =
+      (*inputs->mutable_attr())["T"].mutable_list();
+  for (const OpDef::ArgDef& arg : func.signature().input_arg()) {
+    TF_RETURN_IF_ERROR(
+        CopyArgType(func_node, func_attr, "input", arg, type_list));
+  }
+  return Status::OK();
+}
+
+// Add an IdentityN op to hook the function outputs to: this ensures that the
+// function body is fully evaluated before its fanout gets scheduled.
+Status HookInlinedFunctionOutputs(
+    const NodeDef& func_node, const FunctionDef& func,
+    const std::unordered_map<string, AttrValue>& func_attr,
+    const gtl::ArraySlice<string> fetch, NodeDef* outputs) {
+  outputs->set_name(func_node.name());
+  outputs->set_op("IdentityN");
+  outputs->set_device(func_node.device());
+  AttrValue::ListValue* type_list =
+      (*outputs->mutable_attr())["T"].mutable_list();
+  for (int i = 0; i < func.signature().output_arg_size(); ++i) {
+    const OpDef::ArgDef& arg = func.signature().output_arg(i);
+    TF_RETURN_IF_ERROR(
+        CopyArgType(func_node, func_attr, "output", arg, type_list));
+    // Use the fetch names since they take into account the output mapping.
+    outputs->add_input(strings::StrCat(func_node.name(), "/", fetch[i]));
+  }
+  return Status::OK();
+}
+
+Status InlineFunction(const NodeDef& func_node, const FunctionDef& func,
+                      const FunctionInliningContext& ctx,
+                      GraphDef* optimized_graph) {
+  const std::unordered_map<string, AttrValue> func_attr(
+      func_node.attr().begin(), func_node.attr().end());
 
-Status InlineFunction(const NodeDef& node, const FunctionDef& func,
-                      const FunctionDefLibrary& library, GraphDef* graph) {
-  const std::unordered_map<string, AttrValue> attr(node.attr().begin(),
-                                                   node.attr().end());
   std::unique_ptr<GrapplerItem> item =
-      GrapplerItemFromFunctionDef(func, attr, library);
+      GrapplerItemFromFunctionDef(func, func_attr, ctx.Library());
   if (!item) {
-    return errors::InvalidArgument("Failed to inline function ", node.op(),
-                                   " instantiated by ", node.name());
+    return errors::InvalidArgument("Failed to inline function ", func_node.op(),
+                                   " instantiated by ", func_node.name());
   }
 
   std::unordered_map<string, int> input_nodes;
@@ -50,43 +163,25 @@ Status InlineFunction(const NodeDef& node, const FunctionDef& func,
     input_nodes[arg.name()] = i;
   }
 
-  // Add an IdentityN op to hook the function inputs to: this ensures that
-  // they're all evaluated before the evaluation of the function body starts.
-  NodeDef* func_inputs = graph->add_node();
-  func_inputs->set_name(strings::StrCat(node.name(), "/", "inlined_inputs"));
-  func_inputs->set_op("IdentityN");
-  func_inputs->set_device(node.device());
-  *func_inputs->mutable_input() = node.input();
-  AttrValue::ListValue* type_list =
-      (*func_inputs->mutable_attr())["T"].mutable_list();
-  for (const OpDef::ArgDef& arg : func.signature().input_arg()) {
-    if (arg.type() != DT_INVALID) {
-      type_list->add_type(arg.type());
-    } else {
-      auto it = attr.find(arg.type_attr());
-      if (it == attr.end()) {
-        return errors::InvalidArgument("Invalid input argument ", arg.name(),
-                                       " for function ", node.op(),
-                                       " instantiated by ", node.name());
-      }
-      type_list->add_type(it->second.type());
-    }
-  }
+  // Hook inlined function inputs to IdentityN node
+  NodeDef* func_inputs = optimized_graph->add_node();
+  TF_RETURN_IF_ERROR(
+      HookInlinedFunctionInputs(func_node, func, func_attr, func_inputs));
 
   for (NodeDef& func_body_node : *item->graph.mutable_node()) {
     if (input_nodes.find(func_body_node.name()) != input_nodes.end()) {
+      CHECK_EQ(0, func_body_node.input_size());
       // Turn input placeholders into identity nodes
       if (IsPlaceholder(func_body_node)) {
         func_body_node.set_op("Identity");
       }
-      CHECK_EQ(0, func_body_node.input_size());
       int input_id = input_nodes[func_body_node.name()];
       func_body_node.add_input(
           strings::StrCat(func_inputs->name(), ":", input_id));
     } else {
       // Update the input names if any.
       for (string& input : *func_body_node.mutable_input()) {
-        input = AddPrefixToNodeName(input, node.name());
+        input = AddPrefixToNodeName(input, /*prefix=*/func_node.name());
       }
       // If the node has no input, make hook it up to the func_inputs node to
       // ensure it runs in the same frame as the other nodes of the function
@@ -98,39 +193,29 @@ Status InlineFunction(const NodeDef& node, const FunctionDef& func,
 
     // Add the node name as a prefix to avoid collisions after inlining
     func_body_node.set_name(
-        strings::StrCat(node.name(), "/", func_body_node.name()));
+        strings::StrCat(func_node.name(), "/", func_body_node.name()));
 
     // Make sure the node is placed
-    func_body_node.set_device(node.device());
-
-    // Move the node to the main graph
-    graph->add_node()->Swap(&func_body_node);
-  }
-
-  // Add an IdentityN op to hook the function outputs to: this ensures that the
-  // function body is fully evaluated before its fanout gets scheduled.
-  NodeDef* func_outputs = graph->add_node();
-  func_outputs->set_name(node.name());
-  func_outputs->set_op("IdentityN");
-  func_outputs->set_device(node.device());
-  type_list = (*func_outputs->mutable_attr())["T"].mutable_list();
-  for (int i = 0; i < func.signature().output_arg_size(); ++i) {
-    const OpDef::ArgDef& arg = func.signature().output_arg(i);
-    if (arg.type() != DT_INVALID) {
-      type_list->add_type(arg.type());
+    func_body_node.set_device(func_node.device());
+
+    // Check if a body node is itself a function
+    const FunctionDef* func_body_node_func =
+        ctx.FindInlinedFunction(func_body_node.op());
+    if (func_body_node_func != nullptr) {
+      // Recursively inline function calls
+      TF_RETURN_IF_ERROR(InlineFunction(func_body_node, *func_body_node_func,
+                                        ctx, optimized_graph));
     } else {
-      auto it = attr.find(arg.type_attr());
-      if (it == attr.end()) {
-        return errors::InvalidArgument("Invalid output argument ", arg.name(),
-                                       " for function ", node.op(),
-                                       " instantiated by ", node.name());
-      }
-      type_list->add_type(it->second.type());
+      // Move the node to the main graph
+      optimized_graph->add_node()->Swap(&func_body_node);
     }
-    // Use the fetch names since they take into account the output mapping.
-    func_outputs->add_input(strings::StrCat(node.name(), "/", item->fetch[i]));
   }
 
+  // Hook inlined function outputs to IdentityN node
+  NodeDef* func_outputs = optimized_graph->add_node();
+  TF_RETURN_IF_ERROR(HookInlinedFunctionOutputs(func_node, func, func_attr,
+                                                item->fetch, func_outputs));
+
   return Status::OK();
 }
 
@@ -278,31 +363,14 @@ Status InlineSymbolicGradient(const NodeDef& node, SymbolicGradientEnv* env,
   return Status::OK();
 }
 
+}  // namespace
+
 Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                                    GraphDef* optimized_graph) {
-  std::unordered_map<string, const FunctionDef*> functions;
-  for (const FunctionDef& func : item.graph.library().function()) {
-    // Don't inline functions marked as noinline
-    if (func.attr().count("_noinline") != 0) {
-      continue;
-    }
-    // Don't touch anything marked XLA to prevent XLA failures further down the
-    // road.
-    if (func.attr().count("_XlaCompile") > 0 &&
-        func.attr().at("_XlaCompile").b()) {
-      continue;
-    }
-    // Can't create IdentityN nodes with no input or output: skip these
-    // functions for now.
-    if (func.signature().input_arg_size() == 0 ||
-        func.signature().output_arg_size() == 0) {
-      continue;
-    }
-    functions[func.signature().name()] = &func;
-  }
+  FunctionInliningContext function_inlining_ctx(item);
 
-  // Nothing to do.
-  if (functions.empty()) {
+  // Nothing to do here.
+  if (!function_inlining_ctx.HasInlinedFunctions()) {
     *optimized_graph = item.graph;
     return Status::OK();
   }
@@ -315,12 +383,14 @@ Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
       TF_RETURN_IF_ERROR(InlineSymbolicGradient(node, &env, optimized_graph));
       continue;
     }
-    auto it = functions.find(node.op());
-    if (it == functions.end()) {
-      *optimized_graph->add_node() = node;
+
+    const FunctionDef* func =
+        function_inlining_ctx.FindInlinedFunction(node.op());
+    if (func != nullptr) {
+      TF_RETURN_IF_ERROR(
+          InlineFunction(node, *func, function_inlining_ctx, optimized_graph));
     } else {
-      TF_RETURN_IF_ERROR(InlineFunction(node, *it->second, item.graph.library(),
-                                        optimized_graph));
+      *optimized_graph->add_node() = node;
     }
   }
 
index deb2fab..c804d75 100644 (file)
@@ -26,7 +26,22 @@ namespace tensorflow {
 namespace grappler {
 namespace {
 
-class FunctionOptimizerTest : public GrapplerTest {};
+constexpr char kDevice[] = "/device:CPU:0";
+
+class FunctionOptimizerTest : public GrapplerTest {
+ protected:
+  Tensor MakeScalarTensor(float value) {
+    Tensor tensor(DT_FLOAT, {});
+    tensor.scalar<float>()() = value;
+    return tensor;
+  }
+
+  Tensor MakeScalarTensor(int value) {
+    Tensor tensor(DT_INT32, {});
+    tensor.scalar<int>()() = value;
+    return tensor;
+  }
+};
 
 TEST_F(FunctionOptimizerTest, SimpleFunction) {
   // Build a graph to compute y = XTimesTwo(x)
@@ -94,9 +109,8 @@ TEST_F(FunctionOptimizerTest, SimpleFunction) {
   }
   EXPECT_EQ(7, count);
 
+  Tensor pi = MakeScalarTensor(3.14f);
   item.fetch = {"z"};
-  Tensor pi(DT_FLOAT, {});
-  pi.flat<float>()(0) = 3.14f;
   item.feed.emplace_back("x", pi);
   auto tensors_expected = EvaluateFetchNodes(item);
   GrapplerItem optimized(item, std::move(output));
@@ -183,9 +197,8 @@ TEST_F(FunctionOptimizerTest, FixedTypeFunction) {
   }
   EXPECT_EQ(6, count);
 
+  Tensor pi = MakeScalarTensor(3.14f);
   item.fetch = {"z"};
-  Tensor pi(DT_FLOAT, {});
-  pi.flat<float>()(0) = 3.14f;
   item.feed.emplace_back("x", pi);
   auto tensors_expected = EvaluateFetchNodes(item);
   GrapplerItem optimized(item, std::move(output));
@@ -268,9 +281,8 @@ TEST_F(FunctionOptimizerTest, FunctionWithOutputMapping) {
   }
   EXPECT_EQ(6, count);
 
+  Tensor pi = MakeScalarTensor(3.14f);
   item.fetch = {"z"};
-  Tensor pi(DT_FLOAT, {});
-  pi.flat<float>()(0) = 3.14f;
   item.feed.emplace_back("x", pi);
   auto tensors_expected = EvaluateFetchNodes(item);
   GrapplerItem optimized(item, std::move(output));
@@ -325,18 +337,11 @@ TEST_F(FunctionOptimizerTest, FunctionWithInputForwarding) {
   TF_EXPECT_OK(status);
 
   item.fetch = {"z0", "z1", "z2"};
-  Tensor in(DT_FLOAT, {});
-  in.flat<float>()(0) = 3.14f;
-  item.feed.emplace_back("x0", in);
-  in.flat<float>()(0) = 2.7f;
-  item.feed.emplace_back("x1", in);
-  in.flat<float>()(0) = 1.0f;
-  item.feed.emplace_back("x2", in);
-  in.flat<float>()(0) = -1.0f;
-  item.feed.emplace_back("x4", in);
-  Tensor in_int(DT_INT32, {});
-  in_int.flat<int>()(0) = 1234;
-  item.feed.emplace_back("x3", in_int);
+  item.feed.emplace_back("x0", MakeScalarTensor(3.14f));
+  item.feed.emplace_back("x1", MakeScalarTensor(2.7f));
+  item.feed.emplace_back("x2", MakeScalarTensor(1.0f));
+  item.feed.emplace_back("x4", MakeScalarTensor(-1.0f));
+  item.feed.emplace_back("x3", MakeScalarTensor(1234));
   auto tensors_expected = EvaluateFetchNodes(item);
   GrapplerItem optimized(item, std::move(output));
   auto tensors = EvaluateFetchNodes(optimized);
@@ -379,6 +384,100 @@ TEST_F(FunctionOptimizerTest, FunctionWithoutInput) {
   EXPECT_EQ(item.graph.DebugString(), output.DebugString());
 }
 
+TEST_F(FunctionOptimizerTest, InlineFunctionWithNestedFunctionCall) {
+  // Define square via function library:
+  //   MySquare(x) = MyMul(x, x)
+
+  FunctionDef mul_func = FunctionDefHelper::Create(
+      "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, double}"},
+      {{{"output"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
+      /* Mapping between function returns and function node outputs. */
+      {{"z", "output:z:0"}});
+
+  FunctionDef square_func = FunctionDefHelper::Create(
+      "MySquare", {"x:T"}, {"z:T"}, {"T: {float, double}"},
+      {{{"output"}, "MyMul", {"x", "x"}, {{"T", "$T"}}}},
+      /* Mapping between function returns and function node outputs. */
+      {{"z", "output:z:0"}});
+
+  GrapplerItem item;
+  item.graph = test::function::GDef(
+      {test::function::NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}},
+                            kDevice),
+       test::function::NDef("square", "MySquare", {"a"}, {{"T", DT_FLOAT}},
+                            kDevice),
+       test::function::NDef("outputs", "Identity", {"square:0"},
+                            {{"T", DT_FLOAT}}, kDevice)},
+      // FunctionLib
+      {mul_func, square_func});
+
+  GraphDef output;
+  FunctionOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+  int count = 0;
+  for (const NodeDef& node : output.node()) {
+    if (node.name() == "square/inlined_inputs" && count++) {
+      EXPECT_EQ("IdentityN", node.op());
+      EXPECT_EQ(kDevice, node.device());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("a", node.input(0));
+    } else if (node.name() == "square/x" && count++) {
+      EXPECT_EQ("Identity", node.op());
+      EXPECT_EQ(kDevice, node.device());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("square/inlined_inputs:0", node.input(0));
+    } else if (node.name() == "square/output/inlined_inputs" && count++) {
+      EXPECT_EQ("IdentityN", node.op());
+      EXPECT_EQ(kDevice, node.device());
+      EXPECT_EQ(2, node.input_size());
+      EXPECT_EQ("square/x", node.input(0));
+      EXPECT_EQ("square/x", node.input(1));
+    } else if (node.name() == "square/output/x" && count++) {
+      EXPECT_EQ("Identity", node.op());
+      EXPECT_EQ(kDevice, node.device());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("square/output/inlined_inputs:0", node.input(0));
+    } else if (node.name() == "square/output/y" && count++) {
+      EXPECT_EQ("Identity", node.op());
+      EXPECT_EQ(kDevice, node.device());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("square/output/inlined_inputs:1", node.input(0));
+    } else if (node.name() == "square/output/output" && count++) {
+      EXPECT_EQ("Mul", node.op());
+      EXPECT_EQ(kDevice, node.device());
+      EXPECT_EQ(2, node.input_size());
+      EXPECT_EQ("square/output/x", node.input(0));
+      EXPECT_EQ("square/output/y", node.input(1));
+    } else if (node.name() == "square/output" && count++) {
+      EXPECT_EQ("IdentityN", node.op());
+      EXPECT_EQ(kDevice, node.device());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("square/output/output:0", node.input(0));
+    } else if (node.name() == "square" && count++) {
+      EXPECT_EQ("IdentityN", node.op());
+      EXPECT_EQ(kDevice, node.device());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("square/output:0", node.input(0));
+    } else if (node.name() == "outputs" && count++) {
+      EXPECT_EQ("Identity", node.op());
+      EXPECT_EQ(kDevice, node.device());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("square:0", node.input(0));
+    }
+  }
+  EXPECT_EQ(9, count);
+
+  item.fetch = {"outputs"};
+  item.feed.emplace_back("a", MakeScalarTensor(2.0f));
+  auto tensors_expected = EvaluateFetchNodes(item);
+
+  GrapplerItem optimized(item, std::move(output));
+  auto tensors = EvaluateFetchNodes(optimized);
+
+  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
+}
+
 TEST_F(FunctionOptimizerTest, SymbolicGradients) {
   tensorflow::Scope scope = tensorflow::Scope::NewRootScope();