Made sure all the nodes in the body of an inlined function run in the same frame
authorBenoit Steiner <bsteiner@google.com>
Wed, 7 Mar 2018 04:12:55 +0000 (20:12 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 7 Mar 2018 04:16:34 +0000 (20:16 -0800)
PiperOrigin-RevId: 188121852

tensorflow/core/grappler/optimizers/function_optimizer.cc
tensorflow/core/grappler/optimizers/function_optimizer_test.cc

index 4b830bc..d8a237c 100644 (file)
@@ -78,10 +78,16 @@ Status InlineFunction(const NodeDef& node, const FunctionDef& func,
       func_body_node.add_input(
           strings::StrCat(func_inputs->name(), ":", input_id));
     } else {
-      // Update the input names.
+      // Update the input names if any.
       for (string& input : *func_body_node.mutable_input()) {
         input = AddPrefixToNodeName(input, node.name());
       }
+      // If the node has no input, make hook it up to the func_inputs node to
+      // ensure it runs in the same frame as the other nodes of the function
+      // body.
+      if (func_body_node.input_size() == 0) {
+        *func_body_node.add_input() = AsControlDependency(func_inputs->name());
+      }
     }
 
     // Add the node name as a prefix to avoid collisions after inlining
index 8db9b7f..bafcdf4 100644 (file)
@@ -63,6 +63,8 @@ TEST_F(FunctionOptimizerTest, SimpleFunction) {
       count++;
       EXPECT_EQ("Const", node.op());
       EXPECT_EQ(device, node.device());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("^y/inlined_inputs", node.input(0));
     } else if (node.name() == "y/scale") {
       count++;
       EXPECT_EQ("Cast", node.op());
@@ -153,6 +155,8 @@ TEST_F(FunctionOptimizerTest, FixedTypeFunction) {
     } else if (node.name() == "y/two") {
       count++;
       EXPECT_EQ("Const", node.op());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("^y/inlined_inputs", node.input(0));
       EXPECT_EQ(device, node.device());
     } else if (node.name() == "y/y") {
       count++;