Replace GrapplerFunctionItem input with a constant.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 27 Apr 2018 19:46:46 +0000 (12:46 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 27 Apr 2018 19:49:07 +0000 (12:49 -0700)
PiperOrigin-RevId: 194579253

tensorflow/core/grappler/utils/functions.cc
tensorflow/core/grappler/utils/functions.h
tensorflow/core/grappler/utils/functions_test.cc

index 790809b..79b823f 100644 (file)
@@ -566,6 +566,60 @@ Status RegisterGrapplerFunctionConnectivity(
   return Status::OK();
 }
 
+Status ReplaceInputWithConst(const NodeDef& input_const, int input_position,
+                             GrapplerFunctionItem* item) {
+  if (!IsConstant(input_const)) {
+    return errors::InvalidArgument("Input node ", input_const.name(),
+                                   " is not a constant");
+  }
+
+  auto& inputs = item->input_arg_expansions_;
+
+  // Find input arg expansion and input placeholder position in it for the
+  // given function input position.
+  InputArgExpansion* input_arg_expansion = nullptr;
+  int placeholder_idx = input_position;
+
+  for (InputArgExpansion& input : inputs) {
+    if (placeholder_idx < input.placeholders.size()) {
+      input_arg_expansion = &input;
+      break;
+    }
+    placeholder_idx -= input.placeholders.size();
+  }
+
+  if (input_arg_expansion == nullptr) {
+    return errors::InvalidArgument(
+        "Input placeholder not found: input_position=", input_position,
+        " function=", item->id);
+  }
+
+  // Delete placeholder from input expansion.
+  string placeholder_name = input_arg_expansion->placeholders[placeholder_idx];
+  item->input_arg_placeholders_.erase(placeholder_name);
+  input_arg_expansion->placeholders.erase(
+      input_arg_expansion->placeholders.begin() + placeholder_idx);
+
+  // Delete empty input expansions.
+  inputs.erase(std::remove_if(inputs.begin(), inputs.end(),
+                              [](const InputArgExpansion& input) {
+                                return input.placeholders.empty();
+                              }),
+               inputs.end());
+
+  // Replace placeholder node in the function body with a const node.
+  for (NodeDef& node : *item->graph.mutable_node()) {
+    if (node.name() == placeholder_name) {
+      node = input_const;
+      node.set_name(placeholder_name);
+      node.clear_input();   // remove potential control inputs
+      node.clear_device();  // device placement is defined by instantiating node
+    }
+  }
+
+  return Status::OK();
+}
+
 Status MakeFunctionDef(const GrapplerFunctionItem& item,
                        const FunctionLibraryDefinition& flib,
                        FunctionDef* func) {
@@ -579,6 +633,9 @@ Status MakeFunctionDef(const GrapplerFunctionItem& item,
 
   // Add function input arguments.
   for (const InputArgExpansion& input_arg : item.inputs()) {
+    CHECK(input_arg.placeholders.size() == 1)  // do some sanity checking
+        << "Inputs of tensor sequences are not supported";
+
     OpDef::ArgDef arg_def;
     arg_def.set_name(input_arg.input_name);
     arg_def.set_type(input_arg.data_type);
@@ -588,15 +645,15 @@ Status MakeFunctionDef(const GrapplerFunctionItem& item,
 
   // Add function output arguments.
   for (const OutputArgExpansion& output_arg : item.outputs()) {
+    CHECK(output_arg.output_tensors.size() == 1)  // do some sanity checking
+        << "Outputs of tensor sequences are not supported";
+
     OpDef::ArgDef arg_def;
     arg_def.set_name(output_arg.output_name);
     arg_def.set_type(output_arg.data_type);
     arg_def.set_is_ref(output_arg.is_ref);
     *func->mutable_signature()->add_output_arg() = arg_def;
 
-    CHECK(output_arg.output_tensors.size() == 1)  // do some sanity checking
-        << "Outputs of tensor sequences are not supported";
-
     string ret;
     for (const string& output_tensor : output_arg.output_tensors) {
       TF_RETURN_IF_ERROR(connectivity.AsFunctionDefInput(output_tensor, &ret));
index 692333f..d9d71b8 100644 (file)
@@ -162,6 +162,9 @@ class GrapplerFunctionItem : public GrapplerItem {
   GrapplerFunctionItem& SwapFunctionBody(GraphDef&& other);
 
  private:
+  friend Status ReplaceInputWithConst(const NodeDef&, int,
+                                      GrapplerFunctionItem*);
+
   AttrValueMap func_attr_;  // Attributes specific to function definition that
                             // produced this item (FuncDef.attr field).
 
@@ -189,12 +192,16 @@ bool HasParametrizedBody(const FunctionDef& func);
 bool IsParametrized(const FunctionDef& func);
 
 // Register GrapplerFunctionItem input arg expansion and function body outputs
-// in the GrapplerFunctionConnectivity.  Use function library definition to
+// in the GrapplerFunctionConnectivity. Use function library definition to
 // lookup function body nodes output names and ranges.
 Status RegisterGrapplerFunctionConnectivity(
     const GrapplerFunctionItem& item, const FunctionLibraryDefinition& flib,
     GrapplerFunctionConnectivity* connectivity);
 
+// Replace one of the function inputs with a constant.
+Status ReplaceInputWithConst(const NodeDef& input_const, int input_position,
+                             GrapplerFunctionItem* item);
+
 // Make a GrapplerFunctionItem from the function definition and function
 // instantiation attributes (caller node attributes). Returns error if the given
 // function def cannot be converted (e.g. not all attributes are defined).
index 6dfd49b..fa6fec7 100644 (file)
@@ -573,6 +573,81 @@ TEST_F(FunctionsTest, MakeFunctionDef) {
   EXPECT_EQ(2, count);
 }
 
+TEST_F(FunctionsTest, ReplaceInputWithConst) {
+  FunctionDef func = FunctionDefHelper::Create(
+      "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, double}"},
+      {{{"output"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
+      /* Mapping between function returns and function node outputs. */
+      {{"z", "output:z:0"}});
+
+  std::unordered_map<string, AttrValue> func_attr;
+  func_attr["T"].set_type(DT_FLOAT);
+  FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
+
+  GrapplerFunctionItem item;
+  TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
+
+  EXPECT_EQ(2, item.input_size());
+  EXPECT_EQ(1, item.output_size());
+
+  ASSERT_EQ(3, item.function_body().node_size());
+
+  const NodeDef &input_x = item.function_body().node(0);
+  const NodeDef &input_y = item.function_body().node(1);
+
+  // Initially inputs added to the graph as placeholders.
+  EXPECT_EQ("Placeholder", input_x.op());
+  EXPECT_EQ("Placeholder", input_y.op());
+
+  // Replace inputs x and y with constants.
+  NodeDef const_input_x;
+  const_input_x.set_op("Const");
+  AddNodeAttr("Tag", "const_input_x", &const_input_x);
+
+  NodeDef const_input_y;
+  const_input_y.set_op("Const");
+  AddNodeAttr("Tag", "const_input_y", &const_input_y);
+
+  // Replace input x.
+  TF_EXPECT_OK(ReplaceInputWithConst(const_input_x, 0, &item));
+
+  EXPECT_EQ(1, item.input_size());
+  EXPECT_EQ("Const", input_x.op());
+  EXPECT_EQ("const_input_x", input_x.attr().at("Tag").s());
+
+  // Replace input y.
+  TF_EXPECT_OK(ReplaceInputWithConst(const_input_y, 0, &item));
+
+  EXPECT_EQ(0, item.input_size());
+  EXPECT_EQ("Const", input_y.op());
+  EXPECT_EQ("const_input_y", input_y.attr().at("Tag").s());
+
+  // Make a function from const-specialized function item.
+  FunctionDef specialized;
+  TF_EXPECT_OK(MakeFunctionDef(item, flib, &specialized));
+
+  EXPECT_EQ(0, specialized.signature().input_arg_size());
+  EXPECT_EQ(1, specialized.signature().output_arg_size());
+  EXPECT_EQ(3, specialized.node_def_size());
+
+  // Check that graph has const nodes pushed into function body.
+  int count = 0;
+  for (const NodeDef &node : specialized.node_def()) {
+    if (node.name() == "x" && count++) {
+      EXPECT_EQ("Const", node.op());
+      EXPECT_EQ("const_input_x", node.attr().at("Tag").s());
+    } else if (node.name() == "y" && count++) {
+      EXPECT_EQ("Const", node.op());
+      EXPECT_EQ("const_input_y", node.attr().at("Tag").s());
+    } else if (node.name() == "output" && count++) {
+      EXPECT_EQ("Mul", node.op());
+      EXPECT_EQ("x:output:0", node.input(0));
+      EXPECT_EQ("y:output:0", node.input(1));
+    }
+  }
+  EXPECT_EQ(3, count);
+}
+
 TEST_F(FunctionsTest, SwapFunctionBodyAndMakeFunctionDef) {
   using test::function::NDef;