Clients should ideally prevent such functions from being created in
the first place, but we still want the runtime to be robust to
malformed functions. Trying to run functions with invalid control flow
constructs can result in crashes or hangs, so we want to catch it
before running.
PiperOrigin-RevId:
180727589
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/gradients.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/optimizer_cse.h"
InstantiationResult result;
TF_RETURN_IF_ERROR(InstantiateFunction(fdef, attrs, get_func_sig, &result));
- Graph* graph = new Graph(lib_def);
+ std::unique_ptr<Graph> graph(new Graph(lib_def));
GraphConstructorOptions opts;
opts.allow_internal_ops = true;
opts.expect_device_spec = false;
- Status s = ConvertNodeDefsToGraph(opts, result.nodes, graph);
- if (!s.ok()) {
- delete graph;
- } else {
- *fbody = new FunctionBody(fdef, result.arg_types, result.ret_types, graph);
- }
- return s;
+ TF_RETURN_IF_ERROR(ConvertNodeDefsToGraph(opts, result.nodes, graph.get()));
+
+ // Call BuildControlFlowInfo to validate that this function body has
+ // well-formed control flow.
+ // NOTE(skyewm): this is usually done in Partition(), but we don't partition
+ // function bodies. This should be removed if function bodies ever go through
+ // the Partition() path.
+ std::vector<ControlFlowInfo> dummy;
+ TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph.get(), &dummy));
+
+ *fbody = new FunctionBody(fdef, result.arg_types, result.ret_types,
+ graph.release());
+ return Status::OK();
}
} // end namespace tensorflow
"type attr not found");
}
+TEST_F(FunctionLibraryRuntimeTest, Error_BadControlFlow) {
+ Init({test::function::InvalidControlFlow()});
+ auto x = test::AsTensor<int32>({0});
+ DCHECK_EQ(x.dtype(), DT_INT32);
+ Tensor y;
+ HasError(InstantiateAndRun(flr0_, "InvalidControlFlow", {}, {x}, {&y}),
+ "The node 'add' has inputs from different frames. The input 'enter' "
+ "is in frame 'while'. The input 'i' is in frame ''.");
+}
+
TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) {
Init({test::function::XTimesTwo(), test::function::XTimesFour(),
test::function::XTimes16()});
{{"o1"}, "Identity", {"i0"}, {{"T", "$T"}}}});
}
+FunctionDef InvalidControlFlow() {
+ return FDH::Create(
+ // Name
+ "InvalidControlFlow",
+ // Args
+ {"i: int32"},
+ // Return values
+ {"o: int32"},
+ // Attr def
+ {},
+ // Nodes
+ {{{"enter"}, "Enter", {"i"}, {{"T", DT_INT32}, {"frame_name", "while"}}},
+ {{"add"}, "Add", {"enter:output", "i"}, {{"T", DT_INT32}}}},
+ // Output mapping
+ {{"o", "add:z"}});
+}
+
void FunctionTestSchedClosure(std::function<void()> fn) {
static thread::ThreadPool* w =
new thread::ThreadPool(Env::Default(), "Test", 8);
// x:T, y:T -> y:T, x:T
FunctionDef Swap();
+// Contains malformed control flow which can't be run by the executor.
+FunctionDef InvalidControlFlow();
+
void FunctionTestSchedClosure(std::function<void()> fn);
} // end namespace function
// Assign to each node the name of the frame and the level it belongs to.
// We check the well-formedness of the graph: All inputs to a node must
// come from the same frame and have the same "static" iteration level.
+// `info` is cleared and populated by this function.
// NOTE(yuanbyu): For now, we require all sends/recvs have iteration level
// 0. This essentially means there can't be multiple serial Nexts in
// an iteration, which all sane front-ends should satisfy.