Convert GrapplerFunctionItem to (Specialized)FunctionDef.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 13 Apr 2018 01:29:05 +0000 (18:29 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 13 Apr 2018 01:31:28 +0000 (18:31 -0700)
PiperOrigin-RevId: 192704808

tensorflow/core/grappler/utils/BUILD
tensorflow/core/grappler/utils/functions.cc
tensorflow/core/grappler/utils/functions.h
tensorflow/core/grappler/utils/functions_test.cc

index 05d9cba..b473f32 100644 (file)
@@ -165,6 +165,7 @@ cc_library(
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/grappler:grappler_item",
+        "//tensorflow/core/grappler:op_types",
         "//tensorflow/core/grappler:utils",
     ],
 )
@@ -177,6 +178,8 @@ tf_cc_test(
         "//tensorflow/cc:cc_ops",
         "//tensorflow/core:all_kernels",
         "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
index dd0d918..e8d423a 100644 (file)
@@ -23,27 +23,82 @@ limitations under the License.
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/grappler/op_types.h"
 #include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/lib/strings/scanner.h"
 
 namespace tensorflow {
 namespace grappler {
 
+namespace {
+
+Status OutputNameRange(const FunctionLibraryDefinition& flib,
+                       const NodeDef& node,
+                       tensorflow::NameRangeMap* outputs_range_map) {
+  const OpRegistrationData* registration;
+  TF_RETURN_IF_ERROR(flib.LookUp(node.op(), &registration));
+  TF_RETURN_IF_ERROR(tensorflow::NameRangesForNode(node, registration->op_def,
+                                                   nullptr, outputs_range_map));
+  return Status::OK();
+}
+
+Status RegisterFunctionBodyOutputs(const FunctionLibraryDefinition& flib,
+                                   const NodeDef& node,
+                                   GrapplerFunctionConnectivity* connectivity) {
+  tensorflow::NameRangeMap outputs_range_map;
+  TF_RETURN_IF_ERROR(OutputNameRange(flib, node, &outputs_range_map));
+  connectivity->RegisterFunctionBodyOutputs(node.name(), outputs_range_map);
+  return Status::OK();
+}
+
+// Replace the placeholder attribute values with the values specified in
+// instantiation attributes.
+Status ResolveFunctionBodyNodeAttrPlaceholders(
+    const AttrValueMap& func_instantiation_attr, NodeDef* node) {
+  for (auto& attr : *node->mutable_attr()) {
+    const string& placeholder = attr.second.placeholder();
+    if (placeholder.empty()) continue;
+
+    auto it = func_instantiation_attr.find(placeholder);
+    if (it != func_instantiation_attr.end()) {
+      attr.second = it->second;
+    } else {
+      return errors::InvalidArgument("Can't resolve placeholder: ",
+                                     placeholder);
+    }
+  }
+  return Status::OK();
+}
+
+}  // namespace
+
 void GrapplerFunctionConnectivity::RegisterInputArgExpansion(
     const InputArgExpansion& input_arg_expansion) {
-  input_arg_expansions_.insert(
-      {input_arg_expansion.input_name, input_arg_expansion});
+  const auto& input_name = input_arg_expansion.input_name;
+  const auto& placeholders = input_arg_expansion.placeholders;
+  input_arg_expansions_.emplace(input_name, input_arg_expansion);
+  for (int i = 0; i < placeholders.size(); ++i) {
+    const string& placeholder = input_arg_expansion.placeholders[i];
+    input_arg_placeholders_.emplace(
+        placeholder, InputArgPlaceholder{input_name, /*position=*/i});
+  }
 }
 
 void GrapplerFunctionConnectivity::RegisterFunctionBodyOutputs(
     const string& node_name, const tensorflow::NameRangeMap& outputs) {
-  function_body_outputs_.insert({node_name, outputs});
+  function_body_outputs_[node_name] = outputs;
 }
 
 Status GrapplerFunctionConnectivity::ExpandFunctionDefInput(
     const string& func_def_input, std::vector<string>* graph_def_inputs) const {
   using ::tensorflow::strings::Scanner;
 
+  if (IsControlInput(func_def_input)) {
+    graph_def_inputs->push_back(func_def_input);
+    return Status::OK();
+  }
+
   // Parse input format: "node_name[:node_output][:position]"
   string node_name;
   string node_output;
@@ -150,11 +205,8 @@ Status GrapplerFunctionConnectivity::ExpandNodeInputs(
   std::vector<string> expanded_inputs;
 
   for (const string& function_def_input : function_body_node->input()) {
-    if (!IsControlInput(function_def_input))
-      TF_RETURN_IF_ERROR(
-          ExpandFunctionDefInput(function_def_input, &expanded_inputs));
-    else
-      expanded_inputs.push_back(function_def_input);
+    TF_RETURN_IF_ERROR(
+        ExpandFunctionDefInput(function_def_input, &expanded_inputs));
   }
 
   function_body_node->clear_input();
@@ -163,10 +215,66 @@ Status GrapplerFunctionConnectivity::ExpandNodeInputs(
   return Status::OK();
 }
 
-Status GrapplerFunctionItemBuilder::GetTypeAttr(const string& type_attr_name,
-                                                DataType* data_type) const {
-  auto it = func_attr_->find(type_attr_name);
-  if (it == func_attr_->end()) {
+Status GrapplerFunctionConnectivity::AsFunctionDefInput(
+    const string& graph_def_input, string* func_def_input) const {
+  using gtl::FindOrNull;
+
+  if (IsControlInput(graph_def_input)) {
+    *func_def_input = graph_def_input;
+    return Status::OK();
+  }
+
+  int position;
+  string node_name = ParseNodeName(graph_def_input, &position);
+  CHECK_GE(position, 0);
+
+  // Check if it's an input arg placeholder
+  if (position == 0) {
+    const InputArgPlaceholder* placeholder =
+        FindOrNull(input_arg_placeholders_, node_name);
+    if (placeholder != nullptr) {
+      *func_def_input =
+          strings::StrCat(placeholder->input_name, ":", placeholder->position);
+      return Status::OK();
+    }
+  }
+
+  // It must be output from one of the function body nodes
+  const tensorflow::NameRangeMap* outputs_range_map =
+      FindOrNull(function_body_outputs_, node_name);
+  if (outputs_range_map != nullptr) {
+    for (const auto& el : *outputs_range_map) {
+      const auto& output_name = el.first;
+      const auto& output_range = el.second;
+      if (position >= output_range.first && position < output_range.second) {
+        int pos = position - output_range.first;
+        *func_def_input =
+            strings::StrCat(node_name, ":", output_name, ":", pos);
+        return Status::OK();
+      }
+    }
+  }
+
+  return errors::InvalidArgument("Unknown graph def input: ", graph_def_input);
+}
+
+Status GrapplerFunctionConnectivity::AsFunctionDefNode(
+    NodeDef* function_body_node) const {
+  string func_def_input;
+
+  for (int i = 0; i < function_body_node->input_size(); ++i) {
+    TF_RETURN_IF_ERROR(
+        AsFunctionDefInput(function_body_node->input(i), &func_def_input));
+    function_body_node->set_input(i, func_def_input);
+  }
+
+  return Status::OK();
+}
+
+Status GrapplerFunctionItemInstantiation::GetTypeAttr(
+    const string& type_attr_name, DataType* data_type) const {
+  auto it = func_instantiation_attr_->find(type_attr_name);
+  if (it == func_instantiation_attr_->end()) {
     return errors::InvalidArgument("Type attribute ", type_attr_name,
                                    " is not defined");
   } else if (it->second.type() == DT_INVALID) {
@@ -178,31 +286,48 @@ Status GrapplerFunctionItemBuilder::GetTypeAttr(const string& type_attr_name,
   return Status::OK();
 }
 
-Status GrapplerFunctionItemBuilder::GetArgType(const OpDef::ArgDef& arg,
-                                               DataType* data_type) const {
+Status GrapplerFunctionItemInstantiation::GetArgType(
+    const OpDef::ArgDef& arg, DataType* data_type) const {
   if (arg.type() != DT_INVALID) {
     *data_type = arg.type();
   } else {
+    if (!arg.type_list_attr().empty() || !arg.number_attr().empty()) {
+      return errors::InvalidArgument(
+          "Arguments with sequence of tensors are not supported. Unsupported "
+          "argument name: ",
+          arg.name());
+    }
     TF_RETURN_IF_ERROR(GetTypeAttr(arg.type_attr(), data_type));
   }
   return Status::OK();
 }
 
 GrapplerFunctionItem::GrapplerFunctionItem(
-    const string& function_name,
+    const string& func_name, const AttrValueMap& func_attr,
     const std::vector<InputArgExpansion>& input_arg_expansions,
     const std::vector<OutputArgExpansion>& output_arg_expansions,
     GraphDef&& function_body)
-    : function_name_(function_name),
+    : func_attr_(func_attr),
       input_arg_expansions_(input_arg_expansions),
       output_arg_expansions_(output_arg_expansions) {
+  id = func_name;
+  // Fill the feed nodes with input placeholders
+  for (const InputArgExpansion& input_arg : input_arg_expansions_) {
+    for (const string& placeholder : input_arg.placeholders) {
+      feed.emplace_back(placeholder, Tensor());
+      input_arg_placeholders_.insert(placeholder);
+    }
+  }
+  // Fill the fetch nodes with outputs
+  for (const OutputArgExpansion& output_arg : output_arg_expansions_) {
+    for (const string& output_tensor : output_arg.output_tensors) {
+      fetch.push_back(output_tensor);
+    }
+  }
+  // Swap the graph body
   graph.Swap(&function_body);
 }
 
-const string& GrapplerFunctionItem::function_name() const {
-  return function_name_;
-}
-
 const std::vector<InputArgExpansion>& GrapplerFunctionItem::inputs() const {
   return input_arg_expansions_;
 }
@@ -215,6 +340,11 @@ const std::size_t GrapplerFunctionItem::input_size() const {
   return input_arg_expansions_.size();
 }
 
+bool GrapplerFunctionItem::IsInputPlaceholder(const string& node_name) const {
+  return input_arg_placeholders_.find(node_name) !=
+         input_arg_placeholders_.end();
+}
+
 const std::vector<OutputArgExpansion>& GrapplerFunctionItem::outputs() const {
   return output_arg_expansions_;
 }
@@ -227,10 +357,19 @@ const std::size_t GrapplerFunctionItem::output_size() const {
   return output_arg_expansions_.size();
 }
 
+const AttrValueMap& GrapplerFunctionItem::func_attr() const {
+  return func_attr_;
+}
+
 const GraphDef& GrapplerFunctionItem::function_body() const { return graph; }
 
 GraphDef& GrapplerFunctionItem::mutable_function_body() { return graph; }
 
+GrapplerFunctionItem& GrapplerFunctionItem::SwapFunctionBody(GraphDef&& other) {
+  graph.Swap(&other);
+  return *this;
+}
+
 std::vector<string> OutputTensors(const GrapplerFunctionItem& item) {
   std::vector<string> output_tensors;
   for (const OutputArgExpansion& output : item.outputs()) {
@@ -241,18 +380,27 @@ std::vector<string> OutputTensors(const GrapplerFunctionItem& item) {
   return output_tensors;
 }
 
-Status MakeGrapplerFunctionItem(
-    const FunctionDef& func,
-    const std::unordered_map<string, AttrValue>& func_attr,
-    const FunctionLibraryDefinition& func_library, GrapplerFunctionItem* item) {
+Status MakeGrapplerFunctionItem(const FunctionDef& func,
+                                const AttrValueMap& func_instantiation_attr,
+                                const FunctionLibraryDefinition& flib,
+                                GrapplerFunctionItem* item) {
   const OpDef& signature = func.signature();
 
   if (signature.name().empty()) {
     return errors::InvalidArgument("Function name must be specified");
   }
 
-  // Helper methods to lookup function attributes
-  GrapplerFunctionItemBuilder builder(&func_attr);
+  // Function types will be resolved from function instantiation attributes. All
+  // other attributes will be lost during conversion to FunctionDef.
+  for (const OpDef::AttrDef& attr : signature.attr()) {
+    if (attr.type() != "type") {
+      return errors::InvalidArgument(
+          "Function signature must have only type attributes");
+    }
+  }
+
+  // Helper methods to lookup function instantiation attributes
+  GrapplerFunctionItemInstantiation instantiation(&func_instantiation_attr);
 
   // Mapping from FunctionDef input format (name[:output][:position]) to
   // GraphDef input format (name[:position])
@@ -260,7 +408,10 @@ Status MakeGrapplerFunctionItem(
 
   std::vector<InputArgExpansion> inputs;
   std::vector<OutputArgExpansion> outputs;
+
+  // Function body shares the library with the graph that instantiated it.
   GraphDef function_body;
+  *function_body.mutable_library() = flib.ToProto();
 
   // TODO(ezhulenev): support functions with tensor sequence inputs/outputs
 
@@ -284,7 +435,7 @@ Status MakeGrapplerFunctionItem(
     }
 
     DataType input_data_type;
-    TF_RETURN_IF_ERROR(builder.GetArgType(input, &input_data_type));
+    TF_RETURN_IF_ERROR(instantiation.GetArgType(input, &input_data_type));
 
     NodeDef* placeholder = function_body.add_node();
     placeholder->set_name(input.name());
@@ -292,6 +443,7 @@ Status MakeGrapplerFunctionItem(
     (*placeholder->mutable_attr())["T"].set_type(input_data_type);
 
     InputArgExpansion input_expansion{/*input_name=*/input.name(),
+                                      /*data_type=*/input_data_type,
                                       /*placeholders=*/{input.name()}};
     connectivity.RegisterInputArgExpansion(input_expansion);
     inputs.push_back(input_expansion);
@@ -302,24 +454,12 @@ Status MakeGrapplerFunctionItem(
     NodeDef* new_node = function_body.add_node();
     *new_node = func_def_node;
 
-    // Replace the placeholder attribute values with the specified value
-    for (auto& attr : *new_node->mutable_attr()) {
-      const string& ph_name = attr.second.placeholder();
-      auto it = func_attr.find(ph_name);
-      if (it != func_attr.end()) {
-        attr.second = it->second;
-      }
-    }
-
-    // Functions use a custom format to encode connectivity. Map these custom
-    // strings to regular ones.
-    tensorflow::NameRangeMap outputs_range_map;
-    const OpRegistrationData* registration;
-    TF_RETURN_IF_ERROR(func_library.LookUp(func_def_node.op(), &registration));
-    TF_RETURN_IF_ERROR(tensorflow::NameRangesForNode(
-        func_def_node, registration->op_def, nullptr, &outputs_range_map));
-    connectivity.RegisterFunctionBodyOutputs(func_def_node.name(),
-                                             outputs_range_map);
+    // Resolve all placeholder values using function instantiation attributes.
+    TF_RETURN_IF_ERROR(ResolveFunctionBodyNodeAttrPlaceholders(
+        func_instantiation_attr, new_node));
+    // Register node output range in a function connectivity.
+    TF_RETURN_IF_ERROR(
+        RegisterFunctionBodyOutputs(flib, func_def_node, &connectivity));
   }
 
   // Rewrite inputs to use GraphDef format
@@ -331,20 +471,96 @@ Status MakeGrapplerFunctionItem(
   for (const OpDef::ArgDef& out : signature.output_arg()) {
     std::vector<string> output_tensors;
     auto ret = func.ret().find(out.name());
-    if (ret != func.ret().end()) {
-      // Expand outputs using provided output mapping
-      TF_RETURN_IF_ERROR(
-          connectivity.ExpandFunctionDefInput(ret->second, &output_tensors));
-    } else {
-      // Otherwise output must be one of the function inputs
-      TF_RETURN_IF_ERROR(
-          connectivity.ExpandFunctionDefInput(out.name(), &output_tensors));
+    TF_RETURN_IF_ERROR(
+        ret != func.ret().end()
+            // Expand outputs using provided output mapping
+            ? connectivity.ExpandFunctionDefInput(ret->second, &output_tensors)
+            // Otherwise output must be one of the function inputs
+            : connectivity.ExpandFunctionDefInput(out.name(), &output_tensors));
+
+    DataType output_data_type;
+    TF_RETURN_IF_ERROR(instantiation.GetArgType(out, &output_data_type));
+
+    OutputArgExpansion output{/*output_name=*/out.name(),
+                              /*data_type=*/output_data_type,
+                              /*output_tensors=*/output_tensors};
+    outputs.push_back(output);
+  }
+
+  *item = GrapplerFunctionItem(
+      /*func_name=*/signature.name(),
+      /*func_attr=*/AttrValueMap(func.attr().begin(), func.attr().end()),
+      inputs, outputs, std::move(function_body));
+  return Status::OK();
+}
+
+// Register GrapplerFunctionItem input arg expansion and function body outputs
+// in the GrapplerFunctionConnectivity
+Status RegisterGrapplerFunctionConnectivity(
+    const GrapplerFunctionItem& item, const FunctionLibraryDefinition& flib,
+    GrapplerFunctionConnectivity* connectivity) {
+  for (const InputArgExpansion& input : item.inputs()) {
+    connectivity->RegisterInputArgExpansion(input);
+  }
+  for (const NodeDef& func_body_node : item.function_body().node()) {
+    TF_RETURN_IF_ERROR(
+        RegisterFunctionBodyOutputs(flib, func_body_node, connectivity));
+  }
+  return Status::OK();
+}
+
+Status MakeSpecializedFunctionDef(const GrapplerFunctionItem& item,
+                                  const FunctionLibraryDefinition& flib,
+                                  FunctionDef* func) {
+  func->mutable_signature()->set_name(item.id);
+
+  // Build a GrapplerFunctionConnectivity from inputs and new function body.
+  GrapplerFunctionConnectivity connectivity;
+  TF_RETURN_IF_ERROR(
+      RegisterGrapplerFunctionConnectivity(item, flib, &connectivity));
+
+  // Add function input arguments.
+  for (const InputArgExpansion& input_arg : item.inputs()) {
+    OpDef::ArgDef arg_def;
+    arg_def.set_name(input_arg.input_name);
+    arg_def.set_type(input_arg.data_type);
+    *func->mutable_signature()->add_input_arg() = arg_def;
+  }
+
+  // Add function output arguments.
+  for (const OutputArgExpansion& output_arg : item.outputs()) {
+    OpDef::ArgDef arg_def;
+    arg_def.set_name(output_arg.output_name);
+    arg_def.set_type(output_arg.data_type);
+    *func->mutable_signature()->add_output_arg() = arg_def;
+
+    CHECK(output_arg.output_tensors.size() == 1)  // do some sanity checking
+        << "Outputs of tensor sequences are not supported";
+
+    string ret;
+    for (const string& output_tensor : output_arg.output_tensors) {
+      TF_RETURN_IF_ERROR(connectivity.AsFunctionDefInput(output_tensor, &ret));
+      (*func->mutable_ret())[output_arg.output_name] = ret;
     }
-    outputs.push_back({out.name(), output_tensors});
   }
 
-  *item = GrapplerFunctionItem(signature.name(), inputs, outputs,
-                               std::move(function_body));
+  // Copy function definition specific attributes.
+  for (const auto& attr : item.func_attr()) {
+    const auto& attr_name = attr.first;
+    const auto& attr_value = attr.second;
+    (*func->mutable_attr())[attr_name] = attr_value;
+  }
+
+  // Copy function body nodes to the FunctionDef and update input format
+  for (const NodeDef& func_body_node : item.function_body().node()) {
+    // Do not copy input placeholders
+    if (item.IsInputPlaceholder(func_body_node.name())) continue;
+
+    NodeDef* func_def_node = func->add_node_def();
+    *func_def_node = func_body_node;
+    TF_RETURN_IF_ERROR(connectivity.AsFunctionDefNode(func_def_node));
+  }
+
   return Status::OK();
 }
 
index 60ea885..2ac3917 100644 (file)
@@ -28,14 +28,19 @@ limitations under the License.
 namespace tensorflow {
 namespace grappler {
 
+using AttrValueMap = std::unordered_map<string, AttrValue>;
+
 // Depending on the function instantiation attributes, input argument to the
 // function might be a single tensor, list of tensors of the same type, or a
 // list of tensors of different types.
 //
 // InputArgExpansion keeps track of the placeholders that were added to the
-// function body in place of function inputs.
+// function body in place of function inputs and a resolved input data type.
 struct InputArgExpansion {
+  // TODO(ezhulenev): Add support for functions with tensor sequence inputs of
+  // different data types
   string input_name;                 // name of the function input argument
+  DataType data_type;                // input data type
   std::vector<string> placeholders;  // names of placeholder nodes in the
                                      // function body
 };
@@ -44,11 +49,14 @@ struct InputArgExpansion {
 // to one or more outputs of one of the function body nodes.
 //
 // OutputArgExpansion keeps mapping from a function output arg to the output
-// tensors of a function body nodes, that compute function outputs.
+// tensors of a function body nodes and a resolved output data type
 struct OutputArgExpansion {
+  // TODO(ezhulenev): Add support for functions with tensor sequence outputs of
+  // different data types
   string output_name;                  // name of the function output argument
-  std::vector<string> output_tensors;  // names of output tensors from the
-                                       // function body graph nodes
+  DataType data_type;                  // output data type
+  std::vector<string> output_tensors;  // names of output tensor from the
+                                       // function body nodes
 };
 
 // FunctionDef uses different connectivity encoding for the function body nodes,
@@ -67,26 +75,46 @@ class GrapplerFunctionConnectivity {
   Status ExpandFunctionDefInput(const string& func_def_input,
                                 std::vector<string>* graph_def_inputs) const;
 
-  // Update Node inputs from FunctionDef to GraphDef format
+  // Update Node inputs from FunctionDef to GraphDef format.
   Status ExpandNodeInputs(NodeDef* function_body_node) const;
 
-  // TODO(ezhulenev): fold GraphDef inputs back to FunctionDef format
-  // Status FoldGraphDefInputs(const std::vector<sting> graph_def_inputs,
-  //                          std::vector<string>* function_def_inputs) const;
+  // When expanding inputs in function def format, single input might be
+  // expanded into multiple tensors. When converting back to the function def
+  // format from graph def format, it's always a 1-to-1 relationship.
+  // FunctionDef built from GrapplerFunctionItem is always specialized to it's
+  // instantiation attributes and length of input args (and node def outputs) is
+  // known.
+
+  // Map from GraphDef input format to FunctionDef input format using registered
+  // input arg expansion and function body outputs.
+  Status AsFunctionDefInput(const string& graph_def_input,
+                            string* func_def_input) const;
+
+  // Update Node inputs from GraphDef to FunctionDef format.
+  Status AsFunctionDefNode(NodeDef* function_body_node) const;
 
  private:
+  // Mapping from input name to input arg expansion.
   std::unordered_map<string, InputArgExpansion> input_arg_expansions_;
+  // Mapping from function body node name to output names range map.
   std::unordered_map<string, tensorflow::NameRangeMap> function_body_outputs_;
+
+  struct InputArgPlaceholder {
+    string input_name;
+    int position;
+  };
+
+  // Mapping from input arg placeholder to the function input tensor.
+  std::unordered_map<string, InputArgPlaceholder> input_arg_placeholders_;
 };
 
-// Helper methods to build GrapplerFunctionItem from a function def and function
-// attributes.
-class GrapplerFunctionItemBuilder {
+// Get Function type attributes using attributes of a node that instantiated
+// a function.
+class GrapplerFunctionItemInstantiation {
  public:
-  using FunctionAttr = std::unordered_map<string, AttrValue>;
-
-  explicit GrapplerFunctionItemBuilder(const FunctionAttr* func_attr)
-      : func_attr_(func_attr) {}
+  explicit GrapplerFunctionItemInstantiation(
+      const AttrValueMap* func_instantiation_attr)
+      : func_instantiation_attr_(func_instantiation_attr) {}
 
   // Get DataType from attributes by name. Return error if attribute is missing,
   // or it doesn't define a valid data type.
@@ -97,20 +125,20 @@ class GrapplerFunctionItemBuilder {
   Status GetArgType(const OpDef::ArgDef& arg, DataType* data_type) const;
 
  private:
-  const FunctionAttr* func_attr_;  // do not own
+  const AttrValueMap* func_instantiation_attr_;  // do not own
 };
 
 // A special case of GrapplerItem, constructed from a TensorFlow Function.
 class GrapplerFunctionItem : public GrapplerItem {
  public:
-  GrapplerFunctionItem() {}
+  GrapplerFunctionItem() = default;
   GrapplerFunctionItem(
-      const string& function_name,
+      const string& func_name, const AttrValueMap& func_attr,
       const std::vector<InputArgExpansion>& input_arg_expansions,
       const std::vector<OutputArgExpansion>& output_arg_expansions,
       GraphDef&& function_body);
 
-  const string& function_name() const;
+  bool IsInputPlaceholder(const string& node_name) const;
 
   const std::vector<InputArgExpansion>& inputs() const;
   const InputArgExpansion& input(int i) const;
@@ -120,13 +148,20 @@ class GrapplerFunctionItem : public GrapplerItem {
   const OutputArgExpansion& output(int i) const;
   const std::size_t output_size() const;
 
+  const AttrValueMap& func_attr() const;
   const GraphDef& function_body() const;
   GraphDef& mutable_function_body();
 
+  GrapplerFunctionItem& SwapFunctionBody(GraphDef&& other);
+
  private:
-  string function_name_;
+  AttrValueMap func_attr_;  // Attributes specific to function definition that
+                            // produced this item (FuncDef.attr field).
+
   std::vector<InputArgExpansion> input_arg_expansions_;
   std::vector<OutputArgExpansion> output_arg_expansions_;
+
+  std::set<string> input_arg_placeholders_;
 };
 
 // Return all output tensors referenced by item output args.
@@ -136,8 +171,21 @@ std::vector<string> OutputTensors(const GrapplerFunctionItem& item);
 // Return error if the given function def cannot be converted.
 Status MakeGrapplerFunctionItem(
     const FunctionDef& func,
-    const std::unordered_map<string, AttrValue>& func_attr,
-    const FunctionLibraryDefinition& func_library, GrapplerFunctionItem* item);
+    const std::unordered_map<string, AttrValue>& func_instantiation_attr,
+    const FunctionLibraryDefinition& flib, GrapplerFunctionItem* item);
+
+// Register GrapplerFunctionItem input arg expansion and function body outputs
+// in the GrapplerFunctionConnectivity.  Use function library definition to
+// lookup function body nodes output names and ranges.
+Status RegisterGrapplerFunctionConnectivity(
+    const GrapplerFunctionItem& item, const FunctionLibraryDefinition& flib,
+    GrapplerFunctionConnectivity* connectivity);
+
+// Make a specialized FunctionDef from the GrapplerFunctionItem. Use function
+// library definition to lookup function body nodes output names and ranges.
+Status MakeSpecializedFunctionDef(const GrapplerFunctionItem& item,
+                                  const FunctionLibraryDefinition& flib,
+                                  FunctionDef* func);
 
 }  // end namespace grappler
 }  // end namespace tensorflow
index 1eb3298..a9a708b 100644 (file)
@@ -20,6 +20,7 @@ limitations under the License.
 #include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/protobuf/meta_graph.pb.h"
 
@@ -32,8 +33,9 @@ class FunctionsTest : public ::testing::Test {};
 TEST_F(FunctionsTest, GrapplerFunctionConnectivity_ExpandFunctionDefInput) {
   GrapplerFunctionConnectivity connectivity;
 
-  connectivity.RegisterInputArgExpansion({"inputA", {"inputA"}});
-  connectivity.RegisterInputArgExpansion({"inputB", {"inputB_0", "inputB_1"}});
+  connectivity.RegisterInputArgExpansion({"inputA", DT_FLOAT, {"inputA"}});
+  connectivity.RegisterInputArgExpansion(
+      {"inputB", DT_FLOAT, {"inputB_0", "inputB_1"}});
 
   connectivity.RegisterFunctionBodyOutputs("Add", {{"z", {0, 1}}});
   connectivity.RegisterFunctionBodyOutputs("Func",
@@ -93,11 +95,50 @@ TEST_F(FunctionsTest, GrapplerFunctionConnectivity_ExpandFunctionDefInput) {
   EXPECT_EQ("Func:3", inputs[0]);
 }
 
+TEST_F(FunctionsTest, GrapplerFunctionConnectivity_AsFunctionDefInput) {
+  GrapplerFunctionConnectivity connectivity;
+
+  connectivity.RegisterInputArgExpansion({"inputA", DT_FLOAT, {"inputA"}});
+  connectivity.RegisterInputArgExpansion(
+      {"inputB", DT_FLOAT, {"inputB_0", "inputB_1"}});
+
+  connectivity.RegisterFunctionBodyOutputs("Add", {{"z", {0, 1}}});
+  connectivity.RegisterFunctionBodyOutputs("Func",
+                                           {{"o1", {0, 2}}, {"o2", {2, 4}}});
+
+  string input;
+
+  TF_EXPECT_OK(connectivity.AsFunctionDefInput("inputA", &input));
+  EXPECT_EQ("inputA:0", input);
+
+  TF_EXPECT_OK(connectivity.AsFunctionDefInput("inputB_0", &input));
+  EXPECT_EQ("inputB:0", input);
+
+  TF_EXPECT_OK(connectivity.AsFunctionDefInput("inputB_1", &input));
+  EXPECT_EQ("inputB:1", input);
+
+  TF_EXPECT_OK(connectivity.AsFunctionDefInput("Add", &input));
+  EXPECT_EQ("Add:z:0", input);
+
+  TF_EXPECT_OK(connectivity.AsFunctionDefInput("Func", &input));
+  EXPECT_EQ("Func:o1:0", input);
+
+  TF_EXPECT_OK(connectivity.AsFunctionDefInput("Func:1", &input));
+  EXPECT_EQ("Func:o1:1", input);
+
+  TF_EXPECT_OK(connectivity.AsFunctionDefInput("Func:2", &input));
+  EXPECT_EQ("Func:o2:0", input);
+
+  TF_EXPECT_OK(connectivity.AsFunctionDefInput("Func:3", &input));
+  EXPECT_EQ("Func:o2:1", input);
+}
+
 TEST_F(FunctionsTest, GrapplerFunctionConnectivity_ExpandNodeInputs) {
   GrapplerFunctionConnectivity connectivity;
 
-  connectivity.RegisterInputArgExpansion({"inputA", {"inputA"}});
-  connectivity.RegisterInputArgExpansion({"inputB", {"inputB_0", "inputB_1"}});
+  connectivity.RegisterInputArgExpansion({"inputA", DT_FLOAT, {"inputA"}});
+  connectivity.RegisterInputArgExpansion(
+      {"inputB", DT_FLOAT, {"inputB_0", "inputB_1"}});
 
   NodeDef node;
   node.add_input("inputA:0");
@@ -131,12 +172,12 @@ TEST_F(FunctionsTest, FromSimpleFunctionDef) {
 
   std::unordered_map<string, AttrValue> func_attr;
   func_attr["T"].set_type(DT_FLOAT);
-  FunctionLibraryDefinition library(OpRegistry::Global(), FunctionDefLibrary());
+  FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
 
   GrapplerFunctionItem item;
-  TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, library, &item));
+  TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
 
-  EXPECT_EQ("XTimesTwo", item.function_name());
+  EXPECT_EQ("XTimesTwo", item.id);
   EXPECT_EQ(4, item.function_body().node_size());
 
   EXPECT_EQ(1, item.input_size());
@@ -206,12 +247,12 @@ TEST_F(FunctionsTest, FromFunctionDefWithMultiOutputNodes) {
 
   std::unordered_map<string, AttrValue> func_attr;
   func_attr["T"].set_type(DT_FLOAT);
-  FunctionLibraryDefinition library(OpRegistry::Global(), FunctionDefLibrary());
+  FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
 
   GrapplerFunctionItem item;
-  TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, library, &item));
+  TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
 
-  EXPECT_EQ("SubGrad", item.function_name());
+  EXPECT_EQ("SubGrad", item.id);
   EXPECT_EQ(12, item.function_body().node_size());
 
   ASSERT_EQ(3, item.input_size());
@@ -251,8 +292,8 @@ TEST_F(FunctionsTest, FromFunctionDefWithMultiOutputNodes) {
 }
 
 TEST_F(FunctionsTest, FromFunctionDefWithNestedFuncs) {
-  FunctionLibraryDefinition library(OpRegistry::Global(), FunctionDefLibrary());
-  TF_ASSERT_OK(library.AddFunctionDef(FunctionDefHelper::Define(
+  FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
+  TF_ASSERT_OK(flib.AddFunctionDef(FunctionDefHelper::Define(
       // Name
       "Swap",
       // Args
@@ -290,7 +331,7 @@ TEST_F(FunctionsTest, FromFunctionDefWithNestedFuncs) {
   func_attr["T"].set_type(DT_FLOAT);
 
   GrapplerFunctionItem item;
-  TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, library, &item));
+  TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
 
   int count = 0;
   for (const NodeDef &node : item.function_body().node()) {
@@ -348,10 +389,10 @@ TEST_F(FunctionsTest, FromFunctionDefWithOutputMappings) {
       {{"out", "Exp:y:0"}});
 
   std::unordered_map<string, AttrValue> func_attr;
-  FunctionLibraryDefinition library(OpRegistry::Global(), FunctionDefLibrary());
+  FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
 
   GrapplerFunctionItem item;
-  TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, library, &item));
+  TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
 
   EXPECT_EQ(1, item.output_size());
   EXPECT_EQ("Exp", item.output(0).output_tensors[0]);
@@ -391,12 +432,12 @@ TEST_F(FunctionsTest, FromFunctionDefWithInputForwarding) {
       {{"out0", "in0"}});
 
   std::unordered_map<string, AttrValue> func_attr;
-  FunctionLibraryDefinition library(OpRegistry::Global(), FunctionDefLibrary());
+  FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
 
   GrapplerFunctionItem item;
-  TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, library, &item));
+  TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
 
-  EXPECT_EQ("ForwardInputs", item.function_name());
+  EXPECT_EQ("ForwardInputs", item.id);
   EXPECT_EQ(5, item.function_body().node_size());
 
   EXPECT_EQ(3, item.output_size());
@@ -437,10 +478,10 @@ TEST_F(FunctionsTest, FromFunctionDefWithoutInput) {
 
   std::unordered_map<string, AttrValue> func_attr;
   func_attr["T"].set_type(DT_FLOAT);
-  FunctionLibraryDefinition library(OpRegistry::Global(), FunctionDefLibrary());
+  FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
 
   GrapplerFunctionItem item;
-  TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, library, &item));
+  TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
 
   EXPECT_EQ(0, item.input_size());
   EXPECT_EQ(1, item.output_size());
@@ -456,6 +497,104 @@ TEST_F(FunctionsTest, FromFunctionDefWithoutInput) {
   EXPECT_EQ("two", cast.input(0));
 }
 
+TEST_F(FunctionsTest, MakeSpecializedFunctionDef) {
+  const Tensor kTwo = test::AsScalar<int64>(2);
+  FunctionDef func = FunctionDefHelper::Define(
+      // Name
+      "XTimesTwo",
+      // Args
+      {"x: T"},
+      // Return values
+      {"y: T"},
+      // Attr def
+      {"T: {float, double, int32, int64}"},
+      // Nodes
+      {
+          {{"two"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_INT64}}},
+          {{"scale"}, "Cast", {"two"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
+          {{"y"}, "Mul", {"x", "scale"}, {{"T", "$T"}}},
+      });
+
+  std::unordered_map<string, AttrValue> func_attr;
+  func_attr["T"].set_type(DT_FLOAT);
+  FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
+
+  GrapplerFunctionItem item;
+  TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
+
+  FunctionDef specialized;
+  TF_EXPECT_OK(MakeSpecializedFunctionDef(item, flib, &specialized));
+
+  // Input and output types are resolved based on instantiation attributes.
+  EXPECT_EQ("x", specialized.signature().input_arg(0).name());
+  EXPECT_EQ(DT_FLOAT, specialized.signature().input_arg(0).type());
+  EXPECT_EQ("y", specialized.signature().output_arg(0).name());
+  EXPECT_EQ(DT_FLOAT, specialized.signature().output_arg(0).type());
+
+  // Function body specialized for instantiation types
+  int count = 0;
+  for (const NodeDef &node : specialized.node_def()) {
+    if (node.name() == "scale" && count++) {
+      EXPECT_EQ(DT_FLOAT, node.attr().at("DstT").type());
+    } else if (node.name() == "y" && count++) {
+      EXPECT_EQ("Mul", node.op());
+      EXPECT_EQ("x:0", node.input(0));
+      EXPECT_EQ("scale:y:0", node.input(1));
+      EXPECT_EQ(DT_FLOAT, node.attr().at("T").type());
+    }
+  }
+  EXPECT_EQ(2, count);
+}
+
+TEST_F(FunctionsTest, SwapFunctionBodyAndMakeSpecializedFunctionDef) {
+  using test::function::NDef;
+
+  FunctionDef mul_func = FunctionDefHelper::Create(
+      "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, double}"},
+      {{{"output"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
+      /* Mapping between function returns and function node outputs. */
+      {{"z", "output:z:0"}});
+
+  FunctionDef func = FunctionDefHelper::Create(
+      "MySquare", {"x:T"}, {"z:T"}, {"T: {float, double}"},
+      {{{"output"}, "MyMul", {"x", "x"}, {{"T", "$T"}}}},
+      /* Mapping between function returns and function node outputs. */
+      {{"z", "output:z:0"}});
+
+  GraphDef id_func_body = test::function::GDef(
+      {/* pass input to output through identity */
+       NDef("output", "Identity", {"x"}, {{"T", "float"}})});
+
+  std::unordered_map<string, AttrValue> func_attr;
+  func_attr["T"].set_type(DT_FLOAT);
+
+  FunctionDefLibrary lib_def;
+  *lib_def.add_function() = func;
+  *lib_def.add_function() = mul_func;
+  FunctionLibraryDefinition flib(OpRegistry::Global(), lib_def);
+
+  GrapplerFunctionItem item;
+  TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
+
+  // Replace function body with identity function
+  item.SwapFunctionBody(std::move(id_func_body));
+  FunctionDef specialized;
+  TF_EXPECT_OK(MakeSpecializedFunctionDef(item, flib, &specialized));
+
+  // Check that graph body was updated.
+  int count = 0;
+  for (const NodeDef &node : specialized.node_def()) {
+    if (node.name() == "output" && count++) {
+      EXPECT_EQ("Identity", node.op());
+      EXPECT_EQ("x:0", node.input(0));
+    }
+  }
+  EXPECT_EQ(1, count);
+
+  // And return tensor mapping was updated with a new output name (z->output).
+  EXPECT_EQ("output:output:0", (*specialized.mutable_ret())["z"]);
+}
+
 }  // namespace
 }  // namespace grappler
 }  // namespace tensorflow