Properly handle the case of functions with no inputs
authorBenoit Steiner <bsteiner@google.com>
Sat, 3 Mar 2018 01:18:00 +0000 (17:18 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sat, 3 Mar 2018 01:22:01 +0000 (17:22 -0800)
PiperOrigin-RevId: 187691555

tensorflow/core/grappler/optimizers/function_optimizer.cc
tensorflow/core/grappler/optimizers/function_optimizer_test.cc
tensorflow/core/grappler/utils/functions_test.cc

index 167e5a153a809450d8992834d5254e21017ee469..4b830bcc6e7891c9affacdf788280f3e1543afaa 100644 (file)
@@ -126,9 +126,17 @@ Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                                    GraphDef* optimized_graph) {
   std::unordered_map<string, const FunctionDef*> functions;
   for (const FunctionDef& func : item.graph.library().function()) {
-    if (func.attr().count("_noinline") == 0) {
-      functions[func.signature().name()] = &func;
+    // Don't inline functions marked as noinline
+    if (func.attr().count("_noinline") != 0) {
+      continue;
     }
+    // Can't create IdentityN nodes with no input or output: skip these
+    // functions for now.
+    if (func.signature().input_arg_size() == 0 ||
+        func.signature().output_arg_size() == 0) {
+      continue;
+    }
+    functions[func.signature().name()] = &func;
   }
 
   // Nothing to do.
index 5072abaac74e1b8dac98cd160e40dc094618514e..8db9b7f77adadc6a6404d34fbd63b9fa840c5006 100644 (file)
@@ -339,6 +339,40 @@ TEST_F(FunctionOptimizerTest, FunctionWithInputForwarding) {
   test::ExpectTensorEqual<int>(tensors_expected[2], tensors[2]);
 }
 
+TEST_F(FunctionOptimizerTest, FunctionWithoutInput) {
+  const Tensor kTwo = test::AsScalar<int64>(2);
+  FunctionDef func = FunctionDefHelper::Define(
+      // Name
+      "GenerateTwo",
+      // Args
+      {},
+      // Return value
+      {"o: T"},
+      // Attr def
+      {"T: {float, double}"},
+      // Nodes
+      {{{"two"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_INT64}}},
+       {{"o"}, "Cast", {"two"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}}});
+
+  GrapplerItem item;
+  constexpr char device[] = "/device:CPU:0";
+  item.graph = test::function::GDef(
+      {test::function::NDef("y", "GenerateTwo", {}, {}, device),
+       test::function::NDef("z", "Identity", {"y"}, {{"T", DT_FLOAT}}, device)},
+      // FunctionLib
+      {
+          func,
+      });
+
+  FunctionOptimizer optimizer;
+  GraphDef output;
+  Status status = optimizer.Optimize(nullptr, item, &output);
+  TF_EXPECT_OK(status);
+
+  // For now we won't inline the function.
+  EXPECT_EQ(item.graph.DebugString(), output.DebugString());
+}
+
 }  // namespace
 }  // namespace grappler
 }  // namespace tensorflow
index 25ec50d4784fbb061d0b85eb84d47025de6ea975..6a7d766b1c6b49f8fc13b3b0294f3e3f8a74eb35 100644 (file)
@@ -308,6 +308,43 @@ TEST_F(FunctionsTest, FromFunctionDefWithInputForwarding) {
   }
 }
 
+TEST_F(FunctionsTest, FromFunctionDefWithoutInput) {
+  const Tensor kTwo = test::AsScalar<int64>(2);
+  FunctionDef func = FunctionDefHelper::Define(
+      // Name
+      "GenerateTwo",
+      // Args
+      {},
+      // Return value
+      {"o: T"},
+      // Attr def
+      {"T: {float, double}"},
+      // Nodes
+      {{{"two"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_INT64}}},
+       {{"o"}, "Cast", {"two"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}}});
+
+  std::unordered_map<string, AttrValue> func_attr;
+  func_attr["T"].set_type(DT_FLOAT);
+  FunctionDefLibrary library;
+  std::unique_ptr<GrapplerItem> item =
+      GrapplerItemFromFunctionDef(func, func_attr, library);
+
+  EXPECT_EQ(0, item->feed.size());
+  EXPECT_EQ(1, item->fetch.size());
+  EXPECT_EQ("o:0", item->fetch[0]);
+
+  EXPECT_EQ(2, item->graph.node_size());
+  const NodeDef &two = item->graph.node(0);
+  EXPECT_EQ("two", two.name());
+  EXPECT_EQ(0, two.input_size());
+  const NodeDef &cast = item->graph.node(1);
+  EXPECT_EQ("o", cast.name());
+  EXPECT_EQ(1, cast.input_size());
+  EXPECT_EQ("two:0", cast.input(0));
+
+  std::cout << item->graph.DebugString() << std::endl;
+}
+
 }  // namespace
 }  // namespace grappler
 }  // namespace tensorflow