From: A. Unique TensorFlower Date: Thu, 17 May 2018 01:23:20 +0000 (-0700) Subject: [TF:XLA] Make noinline function work with control flow. X-Git-Tag: upstream/v1.9.0_rc1~106^2^2~9 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=1ac9c483dc9c59101da21629184ba7ff40e27145;p=platform%2Fupstream%2Ftensorflow.git [TF:XLA] Make noinline function work with control flow. 1) Make the local function library for control flow self-contained. The control flow function could refer to a noinline function not defined in the local library. Copy the missing FunctionDefs from the glocal library to the local one. 2) Fix the index used to get the output shapes for functional nodes. PiperOrigin-RevId: 196922649 --- diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 8d1f268..42585ad 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -282,7 +282,58 @@ Status BuildLoopBody(const Graph& graph, Frame* frame, return Status::OK(); } -Status FunctionalizeLoop(Graph* graph, Frame* frame, +// Copy the FunctionDef of given function from lookup_library to library, if +// it can be found in lookup_library but is missing from library. +Status AddMissingFunctionByName(const string& function_name, + const FunctionLibraryDefinition* lookup_library, + FunctionLibraryDefinition* library) { + if (!library->Find(function_name) && lookup_library->Find(function_name)) { + return library->AddFunctionDef(*lookup_library->Find(function_name)); + } + return Status::OK(); +} + +// Iterate over all functions that the given fdef refers to. Copy the missing +// FunctionDefs from lookup_library to library. +Status AddMissingFunctionDef(const FunctionDef& fdef, + const FunctionLibraryDefinition* lookup_library, + FunctionLibraryDefinition* library) { + TF_RET_CHECK(lookup_library); + for (const NodeDef& node : fdef.node_def()) { + if (library->Find(node.op())) { + continue; + } + // The function refered by 'SymbolicGradient' node is specified in its + // attribute 'f'. + if (node.op() == FunctionLibraryDefinition::kGradientOp) { + const AttrValue* attr = + AttrSlice(&node.attr()).Find(FunctionLibraryDefinition::kFuncAttr); + if (!attr) { + return errors::InvalidArgument("SymbolicGradient is missing attr: f"); + } + const string& func_name = attr->func().name(); + TF_RETURN_IF_ERROR( + AddMissingFunctionByName(func_name, lookup_library, library)); + // Copy the user-defined gradient function if it exists. + const string grad_name = lookup_library->FindGradient(func_name); + if (!grad_name.empty() && library->FindGradient(func_name).empty()) { + TF_RETURN_IF_ERROR( + AddMissingFunctionByName(grad_name, lookup_library, library)); + GradientDef grad_def; + grad_def.set_function_name(func_name); + grad_def.set_gradient_func(grad_name); + TF_RETURN_IF_ERROR(library->AddGradientDef(grad_def)); + } + } else if (lookup_library->Find(node.op())) { + TF_RETURN_IF_ERROR( + library->AddFunctionDef(*lookup_library->Find(node.op()))); + } + } + return Status::OK(); +} + +Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, + Graph* graph, Frame* frame, FunctionLibraryDefinition* library) { VLOG(2) << "Frame " << frame->name << " before: " << dump_graph::DumpGraphToFile("functionalize_before", *graph, @@ -489,6 +540,14 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame, TF_RETURN_IF_ERROR(library->AddFunctionDef(cond_fdef)); TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef)); + if (lookup_library) { + // Copy missing FunctionDefs from lookup_library to library to make library + // self-contained. + TF_RETURN_IF_ERROR( + AddMissingFunctionDef(cond_fdef, lookup_library, library)); + TF_RETURN_IF_ERROR( + AddMissingFunctionDef(body_fdef, lookup_library, library)); + } // Builds a While operator. NodeDef while_def; @@ -1365,6 +1424,12 @@ Status FunctionalizeCond::Functionalize(Graph* graph, // functional equivalents. Status FunctionalizeControlFlow(Graph* graph, FunctionLibraryDefinition* library) { + return FunctionalizeControlFlow(/*lookup_library=*/nullptr, graph, library); +} + +Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, + Graph* graph, + FunctionLibraryDefinition* library) { VLOG(2) << "FunctionalizeControlFlow (initial): " << dump_graph::DumpGraphToFile("functionalize_initial", *graph, library); @@ -1434,7 +1499,8 @@ Status FunctionalizeControlFlow(Graph* graph, continue; } - TF_RETURN_IF_ERROR(FunctionalizeLoop(graph, frame, library)); + TF_RETURN_IF_ERROR( + FunctionalizeLoop(lookup_library, graph, frame, library)); // If the parent has no remaining children, add it to the worklist. --frame->parent->num_children; diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.h b/tensorflow/compiler/tf2xla/functionalize_control_flow.h index 4d4ee30..d941041 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.h +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.h @@ -22,9 +22,13 @@ limitations under the License. namespace tensorflow { // Transformation that converts tf.while_loop() loops into functional While -// operators, suitable for XLA compilation. +// operators, suitable for XLA compilation. If lookup_library is provided, use +// it to make the library for control flow self-contained. Status FunctionalizeControlFlow(Graph* graph, FunctionLibraryDefinition* library); +Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, + Graph* graph, + FunctionLibraryDefinition* library); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index e494f42..14977a9 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -299,6 +299,131 @@ TEST(FunctionalizeControlFlow, OneLoopVar) { } } +// @function.Defun(noinline=True) +// def increment_fn(x): +// return [x + 1] +// Define the above function, and add it to the given graph. It's used as the +// while loop body in NoinlineLoopBody test. +Status AddNoinlineFunctionToGraph(const string& node_name, Graph* graph) { + FunctionDef fdef = FunctionDefHelper::Create( + "increment_fn", {"x:int32"}, {"add:int32"}, {}, + { + {{"add/y"}, "Const", {}, {{"dtype", DT_INT32}}}, + {{"add_0"}, "Add", {"x", "add/y:output:0"}, {{"T", DT_INT32}}}, + }, + {{"add", "add_0:z:0"}}); + (*fdef.mutable_attr())["_noinline"].set_b(true); + FunctionDefLibrary fdef_lib; + *(fdef_lib.add_function()) = fdef; + TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(fdef_lib)); + NodeDef increment_fn; + increment_fn.set_name(node_name); + increment_fn.set_op("increment_fn"); + *increment_fn.add_input() = "while/Identity"; + *increment_fn.add_input() = "^while/Identity"; + Status status; + graph->AddNode(increment_fn, &status); + return status; +} + +// Graph: +// x = array_ops.placeholder(dtypes.int32) +// y = control_flow_ops.while_loop(lambda i: i < 10, increment_fn, [x]) +TEST(FunctionalizeControlFlow, NoinlineLoopBody) { + const string& noinline_node_name = "while/increment_fn"; + Graph graph(OpRegistry::Global()); + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32); + auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); + auto enter = ops::internal::Enter(scope.WithOpName("while/Enter"), source, + "while/while_context"); + auto merge = ops::Merge(scope.WithOpName("while/Merge"), + std::initializer_list{enter, dummy}); + auto ten = ops::Const( + scope.WithOpName("while/Less/y").WithControlDependencies(merge.output), + 10); + auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten); + auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less); + auto switch_ = + ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond); + auto exit = ops::internal::Exit(scope.WithOpName("while/Exit"), + switch_.output_false); + auto identity = + ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true); + + TF_ASSERT_OK(AddNoinlineFunctionToGraph(noinline_node_name, scope.graph())); + + NodeDef next_iter; + next_iter.set_name("while/NextIteration"); + next_iter.set_op("NextIteration"); + *next_iter.add_input() = noinline_node_name; + (*next_iter.mutable_attr())["T"].set_type(DT_INT32); + + Status status; + Node* n = scope.graph()->AddNode(next_iter, &status); + TF_ASSERT_OK(status); + + // Remove the dummy node and add the loop backedge. + scope.graph()->RemoveNode(dummy.node()); + scope.graph()->AddEdge(n, 0, merge.output.node(), 1); + TF_ASSERT_OK(scope.ToGraph(&graph)); + } + + FunctionLibraryDefinition lookup_lib(graph.flib_def()); + FunctionLibraryDefinition library(OpRegistry::Global(), {}); + // Function increment_fn will be copied from lookup_lib to library. + TF_ASSERT_OK(FunctionalizeControlFlow(&lookup_lib, &graph, &library)); + + GraphDef graph_def; + graph.ToGraphDef(&graph_def); + + NameAttrList cond_fn, body_fn; + TF_ASSERT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); + + // Outer graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); + auto while_op = + ops::XlaWhile(scope.WithOpName("while/LoopCond"), + std::initializer_list{source}, cond_fn, body_fn); + GraphDef expected; + TF_ASSERT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } + + // Body graph. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + TF_ASSERT_OK(AddNoinlineFunctionToGraph(noinline_node_name, scope.graph())); + auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg); + NodeDef retval; + retval.set_name("_retval0_RetVal"); + retval.set_op(FunctionLibraryDefinition::kRetOp); + *retval.add_input() = noinline_node_name; + (*retval.mutable_attr())["T"].set_type(DT_INT32); + (*retval.mutable_attr())["index"].set_i(0); + Status status; + scope.graph()->AddNode(retval, &status); + TF_ASSERT_OK(status); + + GraphDef expected; + TF_ASSERT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + // Verify that increment_fn has been copied to library. + TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); + EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); + // Ignore the function library when comparing the graphs. + expected.clear_library(); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } +} + // Tests functionalizing OneLoopVar where the loop value is not used post the // loop. // Graph: diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index b1cb76a..212f6f3 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -230,11 +230,14 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, auto output_handle = b->Call(*result.computation, handles); // The output handle of `Call` computation is a tuple type. Unzip it so // that it can fit into future computations. + int computation_output = 0; for (int64 i = 0; i < n->num_outputs(); ++i) { if (result.outputs[i].is_constant) { xla_op_context.SetConstantOutput(i, result.outputs[i].constant_value); } else { - xla_op_context.SetOutput(i, b->GetTupleElement(output_handle, i)); + xla_op_context.SetOutput( + i, b->GetTupleElement(output_handle, computation_output)); + ++computation_output; } } return b->first_error(); diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 5a6db77..f709891 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -678,7 +678,8 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, // Converts Tensorflow's graph control-flow constructs into functional // control-flow that can be compiled into XLA code. TF_RETURN_IF_ERROR( - FunctionalizeControlFlow(graph.get(), local_flib_def_.get())); + FunctionalizeControlFlow(flib_runtime_->GetFunctionLibraryDefinition(), + graph.get(), local_flib_def_.get())); xla::XlaBuilder builder(name); XlaContext* context = new XlaContext( diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 5670545..55772ca 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" @@ -361,6 +362,76 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { } } +TEST_F(XlaCompilerTest, ConstantOutputsOfFunctionalNode) { + // Define a function with one compile-time constant output and one + // data-dependent output. + // @function.Defun(noinline=True) + // foo(a) {b=7; return b, a; } + const Tensor seven = test::AsScalar(7); + FunctionDef fdef = FunctionDefHelper::Create( + "foo", {"a_0:int32"}, {"const:int32", "a:int32"}, {}, + { + {{"Const"}, "Const", {}, {{"dtype", DT_INT32}, {"value", seven}}}, + }, + {{"a", "a_0"}, {"const", "Const:output:0"}}); + (*fdef.mutable_attr())["_noinline"].set_b(true); + FunctionDefLibrary fdef_lib; + *(fdef_lib.add_function()) = fdef; + std::unique_ptr graph(new Graph(OpRegistry::Global())); + { + Scope scope = Scope::NewRootScope().ExitOnError(); + TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(fdef_lib)); + auto arg = ops::_Arg(scope.WithOpName("input_arg"), DT_INT32, 0); + NodeDef foo; + foo.set_name("foo"); + foo.set_op("foo"); + *foo.add_input() = "input_arg"; + Status status; + scope.graph()->AddNode(foo, &status); + TF_ASSERT_OK(status); + NodeDef retval_1; + retval_1.set_name("retval_0"); + retval_1.set_op(FunctionLibraryDefinition::kRetOp); + *retval_1.add_input() = "foo"; + (*retval_1.mutable_attr())["T"].set_type(DT_INT32); + (*retval_1.mutable_attr())["index"].set_i(0); + scope.graph()->AddNode(retval_1, &status); + TF_ASSERT_OK(status); + NodeDef retval_2; + retval_2.set_name("retval_1"); + retval_2.set_op(FunctionLibraryDefinition::kRetOp); + *retval_2.add_input() = "foo:1"; + (*retval_2.mutable_attr())["T"].set_type(DT_INT32); + (*retval_2.mutable_attr())["index"].set_i(1); + scope.graph()->AddNode(retval_2, &status); + TF_ASSERT_OK(status); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + } + + // Builds a description of the arguments. + std::vector args(1); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({1}); + + XlaCompiler::Options options = DefaultOptions(); + FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib); + options.flib_def = &flib_def; + XlaCompiler compiler(options); + + XlaCompiler::CompileOptions compile_options; + compile_options.resolve_compile_time_constants = true; + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(compile_options, "constants", + std::move(graph), args, &result)); + + ASSERT_EQ(2, result.outputs.size()); + EXPECT_TRUE(result.outputs[0].is_constant); + test::ExpectTensorEqual(result.outputs[0].constant_value, + test::AsScalar(7)); + EXPECT_FALSE(result.outputs[1].is_constant); +} + // Tests compilation and execution of a graph that adds two tensors. TEST_F(XlaCompilerTest, ResourceManager) { // Builds a graph that calls the dummy resource Op.