return Status::OK();
}
-Status FunctionalizeLoop(Graph* graph, Frame* frame,
+// Copy the FunctionDef of given function from lookup_library to library, if
+// it can be found in lookup_library but is missing from library.
+Status AddMissingFunctionByName(const string& function_name,
+ const FunctionLibraryDefinition* lookup_library,
+ FunctionLibraryDefinition* library) {
+ // Copy only when the function is absent from `library` but present in
+ // `lookup_library`; in every other case this is a no-op returning OK.
+ if (!library->Find(function_name) && lookup_library->Find(function_name)) {
+ return library->AddFunctionDef(*lookup_library->Find(function_name));
+ }
+ return Status::OK();
+}
+
+// Iterate over all functions that the given fdef refers to. Copy the missing
+// FunctionDefs from lookup_library to library.
+Status AddMissingFunctionDef(const FunctionDef& fdef,
+ const FunctionLibraryDefinition* lookup_library,
+ FunctionLibraryDefinition* library) {
+ TF_RET_CHECK(lookup_library);
+ for (const NodeDef& node : fdef.node_def()) {
+ // Skip ops that `library` already resolves (including primitive ops).
+ if (library->Find(node.op())) {
+ continue;
+ }
+ // The function referred to by a 'SymbolicGradient' node is specified in
+ // its attribute 'f'.
+ if (node.op() == FunctionLibraryDefinition::kGradientOp) {
+ const AttrValue* attr =
+ AttrSlice(&node.attr()).Find(FunctionLibraryDefinition::kFuncAttr);
+ if (!attr) {
+ return errors::InvalidArgument("SymbolicGradient is missing attr: f");
+ }
+ const string& func_name = attr->func().name();
+ TF_RETURN_IF_ERROR(
+ AddMissingFunctionByName(func_name, lookup_library, library));
+ // Copy the user-defined gradient function if it exists.
+ const string grad_name = lookup_library->FindGradient(func_name);
+ if (!grad_name.empty() && library->FindGradient(func_name).empty()) {
+ TF_RETURN_IF_ERROR(
+ AddMissingFunctionByName(grad_name, lookup_library, library));
+ // Also register the func_name -> grad_name mapping so gradient lookups
+ // against `library` behave like they did against `lookup_library`.
+ GradientDef grad_def;
+ grad_def.set_function_name(func_name);
+ grad_def.set_gradient_func(grad_name);
+ TF_RETURN_IF_ERROR(library->AddGradientDef(grad_def));
+ }
+ } else if (lookup_library->Find(node.op())) {
+ TF_RETURN_IF_ERROR(
+ library->AddFunctionDef(*lookup_library->Find(node.op())));
+ }
+ }
+ return Status::OK();
+}
+
+Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
+ Graph* graph, Frame* frame,
FunctionLibraryDefinition* library) {
VLOG(2) << "Frame " << frame->name << " before: "
<< dump_graph::DumpGraphToFile("functionalize_before", *graph,
TF_RETURN_IF_ERROR(library->AddFunctionDef(cond_fdef));
TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef));
+ if (lookup_library) {
+ // Copy missing FunctionDefs from lookup_library to library to make library
+ // self-contained.
+ TF_RETURN_IF_ERROR(
+ AddMissingFunctionDef(cond_fdef, lookup_library, library));
+ TF_RETURN_IF_ERROR(
+ AddMissingFunctionDef(body_fdef, lookup_library, library));
+ }
// Builds a While operator.
NodeDef while_def;
// functional equivalents.
Status FunctionalizeControlFlow(Graph* graph,
FunctionLibraryDefinition* library) {
+ // Convenience overload: delegate to the lookup_library-aware overload with
+ // no lookup library, so no missing FunctionDefs are copied.
+ return FunctionalizeControlFlow(/*lookup_library=*/nullptr, graph, library);
+}
+
+Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library,
+ Graph* graph,
+ FunctionLibraryDefinition* library) {
VLOG(2) << "FunctionalizeControlFlow (initial): "
<< dump_graph::DumpGraphToFile("functionalize_initial", *graph,
library);
continue;
}
- TF_RETURN_IF_ERROR(FunctionalizeLoop(graph, frame, library));
+ TF_RETURN_IF_ERROR(
+ FunctionalizeLoop(lookup_library, graph, frame, library));
// If the parent has no remaining children, add it to the worklist.
--frame->parent->num_children;
namespace tensorflow {
// Transformation that converts tf.while_loop() loops into functional While
-// operators, suitable for XLA compilation.
+// operators, suitable for XLA compilation. If lookup_library is provided, use
+// it to make the library for control flow self-contained.
Status FunctionalizeControlFlow(Graph* graph,
FunctionLibraryDefinition* library);
+Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library,
+ Graph* graph,
+ FunctionLibraryDefinition* library);
} // namespace tensorflow
}
}
+// @function.Defun(noinline=True)
+// def increment_fn(x):
+// return [x + 1]
+// Define the above function, and add it to the given graph. It's used as the
+// while loop body in NoinlineLoopBody test.
+Status AddNoinlineFunctionToGraph(const string& node_name, Graph* graph) {
+ FunctionDef fdef = FunctionDefHelper::Create(
+ "increment_fn", {"x:int32"}, {"add:int32"}, {},
+ {
+ {{"add/y"}, "Const", {}, {{"dtype", DT_INT32}}},
+ {{"add_0"}, "Add", {"x", "add/y:output:0"}, {{"T", DT_INT32}}},
+ },
+ {{"add", "add_0:z:0"}});
+ // _noinline keeps the function as a call node in the graph instead of
+ // letting the runtime inline its body.
+ (*fdef.mutable_attr())["_noinline"].set_b(true);
+ FunctionDefLibrary fdef_lib;
+ *(fdef_lib.add_function()) = fdef;
+ TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(fdef_lib));
+ // Add a node invoking increment_fn on the loop-carried value; the second
+ // input ("^while/Identity") is a control dependency, not a data input.
+ NodeDef increment_fn;
+ increment_fn.set_name(node_name);
+ increment_fn.set_op("increment_fn");
+ *increment_fn.add_input() = "while/Identity";
+ *increment_fn.add_input() = "^while/Identity";
+ Status status;
+ graph->AddNode(increment_fn, &status);
+ return status;
+}
+
+// Graph:
+// x = array_ops.placeholder(dtypes.int32)
+// y = control_flow_ops.while_loop(lambda i: i < 10, increment_fn, [x])
+TEST(FunctionalizeControlFlow, NoinlineLoopBody) {
+ const string& noinline_node_name = "while/increment_fn";
+ Graph graph(OpRegistry::Global());
+ // Build the raw (un-functionalized) while-loop graph by hand:
+ // Enter -> Merge -> Less/LoopCond -> Switch -> {Exit, Identity}.
+ {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);
+ auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
+ auto enter = ops::internal::Enter(scope.WithOpName("while/Enter"), source,
+ "while/while_context");
+ auto merge = ops::Merge(scope.WithOpName("while/Merge"),
+ std::initializer_list<Input>{enter, dummy});
+ auto ten = ops::Const<int32>(
+ scope.WithOpName("while/Less/y").WithControlDependencies(merge.output),
+ 10);
+ auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten);
+ auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less);
+ auto switch_ =
+ ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond);
+ auto exit = ops::internal::Exit(scope.WithOpName("while/Exit"),
+ switch_.output_false);
+ auto identity =
+ ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true);
+
+ TF_ASSERT_OK(AddNoinlineFunctionToGraph(noinline_node_name, scope.graph()));
+
+ // NextIteration consumes the noinline function call's output to form the
+ // loop backedge into Merge.
+ NodeDef next_iter;
+ next_iter.set_name("while/NextIteration");
+ next_iter.set_op("NextIteration");
+ *next_iter.add_input() = noinline_node_name;
+ (*next_iter.mutable_attr())["T"].set_type(DT_INT32);
+
+ Status status;
+ Node* n = scope.graph()->AddNode(next_iter, &status);
+ TF_ASSERT_OK(status);
+
+ // Remove the dummy node and add the loop backedge.
+ scope.graph()->RemoveNode(dummy.node());
+ scope.graph()->AddEdge(n, 0, merge.output.node(), 1);
+ TF_ASSERT_OK(scope.ToGraph(&graph));
+ }
+
+ // lookup_lib holds increment_fn (registered on the graph above); the
+ // destination library starts empty.
+ FunctionLibraryDefinition lookup_lib(graph.flib_def());
+ FunctionLibraryDefinition library(OpRegistry::Global(), {});
+ // Function increment_fn will be copied from lookup_lib to library.
+ TF_ASSERT_OK(FunctionalizeControlFlow(&lookup_lib, &graph, &library));
+
+ GraphDef graph_def;
+ graph.ToGraphDef(&graph_def);
+
+ NameAttrList cond_fn, body_fn;
+ TF_ASSERT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn));
+
+ // Outer graph
+ {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
+ auto while_op =
+ ops::XlaWhile(scope.WithOpName("while/LoopCond"),
+ std::initializer_list<Input>{source}, cond_fn, body_fn);
+ GraphDef expected;
+ TF_ASSERT_OK(scope.ToGraphDef(&expected));
+ TF_EXPECT_GRAPH_EQ(expected, graph_def);
+ }
+
+ // Body graph.
+ {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
+ TF_ASSERT_OK(AddNoinlineFunctionToGraph(noinline_node_name, scope.graph()));
+ auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg);
+ NodeDef retval;
+ retval.set_name("_retval0_RetVal");
+ retval.set_op(FunctionLibraryDefinition::kRetOp);
+ *retval.add_input() = noinline_node_name;
+ (*retval.mutable_attr())["T"].set_type(DT_INT32);
+ (*retval.mutable_attr())["index"].set_i(0);
+ Status status;
+ scope.graph()->AddNode(retval, &status);
+ TF_ASSERT_OK(status);
+
+ GraphDef expected;
+ TF_ASSERT_OK(scope.ToGraphDef(&expected));
+
+ InstantiationResultForTest result;
+ // Verify that increment_fn has been copied to library.
+ TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result));
+
+ EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
+ EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
+ // Ignore the function library when comparing the graphs.
+ expected.clear_library();
+ TF_EXPECT_GRAPH_EQ(expected, result.gdef);
+ }
+}
+
// Tests functionalizing OneLoopVar where the loop value is not used post the
// loop.
// Graph:
auto output_handle = b->Call(*result.computation, handles);
// The output handle of `Call` computation is a tuple type. Unzip it so
// that it can fit into future computations.
+ int computation_output = 0;
for (int64 i = 0; i < n->num_outputs(); ++i) {
if (result.outputs[i].is_constant) {
xla_op_context.SetConstantOutput(i, result.outputs[i].constant_value);
} else {
- xla_op_context.SetOutput(i, b->GetTupleElement(output_handle, i));
+ xla_op_context.SetOutput(
+ i, b->GetTupleElement(output_handle, computation_output));
+ ++computation_output;
}
}
return b->first_error();
// Converts Tensorflow's graph control-flow constructs into functional
// control-flow that can be compiled into XLA code.
TF_RETURN_IF_ERROR(
- FunctionalizeControlFlow(graph.get(), local_flib_def_.get()));
+ FunctionalizeControlFlow(flib_runtime_->GetFunctionLibraryDefinition(),
+ graph.get(), local_flib_def_.get()));
xla::XlaBuilder builder(name);
XlaContext* context = new XlaContext(
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
}
}
+TEST_F(XlaCompilerTest, ConstantOutputsOfFunctionalNode) {
+ // Define a function with one compile-time constant output and one
+ // data-dependent output.
+ // @function.Defun(noinline=True)
+ // foo(a) {b=7; return b, a; }
+ const Tensor seven = test::AsScalar<int>(7);
+ FunctionDef fdef = FunctionDefHelper::Create(
+ "foo", {"a_0:int32"}, {"const:int32", "a:int32"}, {},
+ {
+ {{"Const"}, "Const", {}, {{"dtype", DT_INT32}, {"value", seven}}},
+ },
+ {{"a", "a_0"}, {"const", "Const:output:0"}});
+ // _noinline keeps foo as a call node so it is compiled as a functional node.
+ (*fdef.mutable_attr())["_noinline"].set_b(true);
+ FunctionDefLibrary fdef_lib;
+ *(fdef_lib.add_function()) = fdef;
+ // Build: input_arg -> foo -> (retval_0, retval_1).
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(fdef_lib));
+ auto arg = ops::_Arg(scope.WithOpName("input_arg"), DT_INT32, 0);
+ NodeDef foo;
+ foo.set_name("foo");
+ foo.set_op("foo");
+ *foo.add_input() = "input_arg";
+ Status status;
+ scope.graph()->AddNode(foo, &status);
+ TF_ASSERT_OK(status);
+ NodeDef retval_1;
+ retval_1.set_name("retval_0");
+ retval_1.set_op(FunctionLibraryDefinition::kRetOp);
+ *retval_1.add_input() = "foo";
+ (*retval_1.mutable_attr())["T"].set_type(DT_INT32);
+ (*retval_1.mutable_attr())["index"].set_i(0);
+ scope.graph()->AddNode(retval_1, &status);
+ TF_ASSERT_OK(status);
+ NodeDef retval_2;
+ retval_2.set_name("retval_1");
+ retval_2.set_op(FunctionLibraryDefinition::kRetOp);
+ *retval_2.add_input() = "foo:1";
+ (*retval_2.mutable_attr())["T"].set_type(DT_INT32);
+ (*retval_2.mutable_attr())["index"].set_i(1);
+ scope.graph()->AddNode(retval_2, &status);
+ TF_ASSERT_OK(status);
+ TF_ASSERT_OK(scope.ToGraph(graph.get()));
+ }
+
+ // Builds a description of the arguments.
+ std::vector<XlaCompiler::Argument> args(1);
+ args[0].kind = XlaCompiler::Argument::kParameter;
+ args[0].type = DT_INT32;
+ args[0].shape = TensorShape({1});
+
+ XlaCompiler::Options options = DefaultOptions();
+ FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);
+ options.flib_def = &flib_def;
+ XlaCompiler compiler(options);
+
+ XlaCompiler::CompileOptions compile_options;
+ compile_options.resolve_compile_time_constants = true;
+ XlaCompiler::CompilationResult result;
+ TF_ASSERT_OK(compiler.CompileGraph(compile_options, "constants",
+ std::move(graph), args, &result));
+
+ // The first output must resolve to the compile-time constant 7; the second
+ // depends on the argument and must remain non-constant.
+ ASSERT_EQ(2, result.outputs.size());
+ EXPECT_TRUE(result.outputs[0].is_constant);
+ test::ExpectTensorEqual<int32>(result.outputs[0].constant_value,
+ test::AsScalar(7));
+ EXPECT_FALSE(result.outputs[1].is_constant);
+}
+
// Tests compilation and execution of a graph that adds two tensors.
TEST_F(XlaCompilerTest, ResourceManager) {
// Builds a graph that calls the dummy resource Op.