Run C++ control flow validation on FunctionDefs before running.
authorSkye Wanderman-Milne <skyewm@google.com>
Wed, 3 Jan 2018 23:40:03 +0000 (15:40 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 3 Jan 2018 23:43:22 +0000 (15:43 -0800)
Clients should ideally prevent such functions from being created in
the first place, but we still want the runtime to be robust to
malformed functions. Trying to run functions with invalid control flow
constructs can result in crashes or hangs, so we want to catch it
before running.

PiperOrigin-RevId: 180727589

tensorflow/core/common_runtime/function.cc
tensorflow/core/common_runtime/function_test.cc
tensorflow/core/framework/function_testlib.cc
tensorflow/core/framework/function_testlib.h
tensorflow/core/graph/control_flow.h

index b921cbcafca3569909e493cc608f4ecdb2a50d75..51d7f98f727fdf04aedc1b29f3e2eeed93300493 100644 (file)
@@ -30,6 +30,7 @@ limitations under the License.
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/versions.pb.h"
 #include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/control_flow.h"
 #include "tensorflow/core/graph/gradients.h"
 #include "tensorflow/core/graph/graph_constructor.h"
 #include "tensorflow/core/graph/optimizer_cse.h"
@@ -1509,17 +1510,23 @@ Status FunctionDefToBodyHelper(
   InstantiationResult result;
   TF_RETURN_IF_ERROR(InstantiateFunction(fdef, attrs, get_func_sig, &result));
 
-  Graph* graph = new Graph(lib_def);
+  std::unique_ptr<Graph> graph(new Graph(lib_def));
   GraphConstructorOptions opts;
   opts.allow_internal_ops = true;
   opts.expect_device_spec = false;
-  Status s = ConvertNodeDefsToGraph(opts, result.nodes, graph);
-  if (!s.ok()) {
-    delete graph;
-  } else {
-    *fbody = new FunctionBody(fdef, result.arg_types, result.ret_types, graph);
-  }
-  return s;
+  TF_RETURN_IF_ERROR(ConvertNodeDefsToGraph(opts, result.nodes, graph.get()));
+
+  // Call BuildControlFlowInfo to validate that this function body has
+  // well-formed control flow.
+  // NOTE(skyewm): this is usually done in Partition(), but we don't partition
+  // function bodies. This should be removed if function bodies ever go through
+  // the Partition() path.
+  std::vector<ControlFlowInfo> dummy;
+  TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph.get(), &dummy));
+
+  *fbody = new FunctionBody(fdef, result.arg_types, result.ret_types,
+                            graph.release());
+  return Status::OK();
 }
 
 }  // end namespace tensorflow
index 7b553c2dcde43b1f4442bb85d2d8ad98aae144ec..d4181ff48c32ab744cbc4d0241b7e294eec115bb 100644 (file)
@@ -783,6 +783,16 @@ TEST_F(FunctionLibraryRuntimeTest, Error_InstantiaionError) {
            "type attr not found");
 }
 
+TEST_F(FunctionLibraryRuntimeTest, Error_BadControlFlow) {
+  Init({test::function::InvalidControlFlow()});
+  auto x = test::AsTensor<int32>({0});
+  DCHECK_EQ(x.dtype(), DT_INT32);
+  Tensor y;
+  HasError(InstantiateAndRun(flr0_, "InvalidControlFlow", {}, {x}, {&y}),
+           "The node 'add' has inputs from different frames. The input 'enter' "
+           "is in frame 'while'. The input 'i' is in frame ''.");
+}
+
 TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) {
   Init({test::function::XTimesTwo(), test::function::XTimesFour(),
         test::function::XTimes16()});
index f8b456051b76241104febd29d55fe82a9146a239..afd5ad4f2f8afabf4bbe58ba2904096046fe72b7 100644 (file)
@@ -193,6 +193,23 @@ FunctionDef Swap() {
        {{"o1"}, "Identity", {"i0"}, {{"T", "$T"}}}});
 }
 
+FunctionDef InvalidControlFlow() {
+  return FDH::Create(
+      // Name
+      "InvalidControlFlow",
+      // Args
+      {"i: int32"},
+      // Return values
+      {"o: int32"},
+      // Attr def
+      {},
+      // Nodes
+      {{{"enter"}, "Enter", {"i"}, {{"T", DT_INT32}, {"frame_name", "while"}}},
+       {{"add"}, "Add", {"enter:output", "i"}, {{"T", DT_INT32}}}},
+      // Output mapping
+      {{"o", "add:z"}});
+}
+
 void FunctionTestSchedClosure(std::function<void()> fn) {
   static thread::ThreadPool* w =
       new thread::ThreadPool(Env::Default(), "Test", 8);
index fbf273fa015c9326e01f45d1c603d22ab239fe25..b67c5cb1ab94f9e203f99b2a5982e282c76f942c 100644 (file)
@@ -81,6 +81,9 @@ FunctionDef NonZero();
 // x:T, y:T -> y:T, x:T
 FunctionDef Swap();
 
+// Contains malformed control flow which can't be run by the executor.
+FunctionDef InvalidControlFlow();
+
 void FunctionTestSchedClosure(std::function<void()> fn);
 
 }  // end namespace function
index 22dbb47010729d61547b33db3a6c8b0ad4fefdb4..372044f538f9428e1979ba80bbb18a9742fc014e 100644 (file)
@@ -33,6 +33,7 @@ struct ControlFlowInfo {
 // Assign to each node the name of the frame and the level it belongs to.
 // We check the well-formedness of the graph: All inputs to a node must
 // come from the same frame and have the same "static" iteration level.
+// `info` is cleared and populated by this function.
 // NOTE(yuanbyu): For now, we require all sends/recvs have iteration level
 // 0. This essentially means there can't be multiple serial Nexts in
 // an iteration, which all sane front-ends should satisfy.