From a9a3b98a76f1d4a8fb7a02e451fb71147a842f31 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 11 Apr 2018 09:43:32 -0700 Subject: [PATCH] Import FunctionDef as GrapplerFunctionItem Explicitly track function input arg expansion into Placeholders, and keep metadata to map between FunctionDef and GraphDef connectivity formats. PiperOrigin-RevId: 192462592 --- tensorflow/core/grappler/grappler_item.h | 3 +- .../core/grappler/optimizers/function_optimizer.cc | 29 +- .../grappler/optimizers/function_optimizer_test.cc | 16 +- tensorflow/core/grappler/utils/BUILD | 2 + tensorflow/core/grappler/utils/functions.cc | 385 ++++++++++++++++----- tensorflow/core/grappler/utils/functions.h | 116 ++++++- tensorflow/core/grappler/utils/functions_test.cc | 277 ++++++++++----- 7 files changed, 627 insertions(+), 201 deletions(-) diff --git a/tensorflow/core/grappler/grappler_item.h b/tensorflow/core/grappler/grappler_item.h index 06bba54..45eed47 100644 --- a/tensorflow/core/grappler/grappler_item.h +++ b/tensorflow/core/grappler/grappler_item.h @@ -35,8 +35,9 @@ namespace grappler { // nodes, and potentially a set of nodes to feed. // TODO(volunteer_needed): turn this struct into a class. struct GrapplerItem { - GrapplerItem() {} + GrapplerItem() = default; GrapplerItem(const GrapplerItem& other, GraphDef&& graphDef); + virtual ~GrapplerItem() = default; string id; // A unique id for this item diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.cc b/tensorflow/core/grappler/optimizers/function_optimizer.cc index 343c89a..6d67ead 100644 --- a/tensorflow/core/grappler/optimizers/function_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/function_optimizer.cc @@ -38,11 +38,14 @@ class FunctionInliningContext { public: explicit FunctionInliningContext(const GrapplerItem& item, RewriterConfig::Toggle opt_level) - : library_(&item.graph.library()), - opt_level_(opt_level), - functions_(InliningCandidates(item)) {} + : opt_level_(opt_level), + functions_(InliningCandidates(item)), + function_library_(FunctionLibraryDefinition(OpRegistry::Global(), + item.graph.library())) {} - const FunctionDefLibrary& Library() const { return *library_; } + const FunctionLibraryDefinition& FunctionLibrary() const { + return function_library_; + } bool HasInlinedFunctions() const { return !functions_.empty(); } @@ -78,9 +81,9 @@ class FunctionInliningContext { return functions; } - const FunctionDefLibrary* library_; RewriterConfig::Toggle opt_level_; std::unordered_map functions_; + FunctionLibraryDefinition function_library_; TF_DISALLOW_COPY_AND_ASSIGN(FunctionInliningContext); }; @@ -150,11 +153,14 @@ Status InlineFunction(const NodeDef& func_node, const FunctionDef& func, const std::unordered_map func_attr( func_node.attr().begin(), func_node.attr().end()); - std::unique_ptr item = - GrapplerItemFromFunctionDef(func, func_attr, ctx.Library()); - if (!item) { + GrapplerFunctionItem item; + Status item_status = + MakeGrapplerFunctionItem(func, func_attr, ctx.FunctionLibrary(), &item); + + if (!item_status.ok()) { return errors::InvalidArgument("Failed to inline function ", func_node.op(), - " instantiated by ", func_node.name()); + " instantiated by ", func_node.name(), + ". Error: ", item_status.error_message()); } std::unordered_map input_nodes; @@ -168,7 +174,7 @@ Status InlineFunction(const NodeDef& func_node, const FunctionDef& func, TF_RETURN_IF_ERROR( HookInlinedFunctionInputs(func_node, func, func_attr, func_inputs)); - for (NodeDef& func_body_node : *item->graph.mutable_node()) { + for (NodeDef& func_body_node : *item.mutable_function_body().mutable_node()) { if (input_nodes.find(func_body_node.name()) != input_nodes.end()) { CHECK_EQ(0, func_body_node.input_size()); // Turn input placeholders into identity nodes @@ -217,8 +223,9 @@ Status InlineFunction(const NodeDef& func_node, const FunctionDef& func, // Hook inlined function outputs to IdentityN node NodeDef* func_outputs = optimized_graph->add_node(); + std::vector fetch = OutputTensors(item); TF_RETURN_IF_ERROR(HookInlinedFunctionOutputs(func_node, func, func_attr, - item->fetch, func_outputs)); + fetch, func_outputs)); return Status::OK(); } diff --git a/tensorflow/core/grappler/optimizers/function_optimizer_test.cc b/tensorflow/core/grappler/optimizers/function_optimizer_test.cc index fe26a56..099fe7c 100644 --- a/tensorflow/core/grappler/optimizers/function_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/function_optimizer_test.cc @@ -92,13 +92,13 @@ TEST_F(FunctionOptimizerTest, SimpleFunction) { EXPECT_EQ(device, node.device()); EXPECT_EQ(2, node.input_size()); EXPECT_EQ("y/x", node.input(0)); - EXPECT_EQ("y/scale:0", node.input(1)); + EXPECT_EQ("y/scale", node.input(1)); } else if (node.name() == "y") { count++; EXPECT_EQ("IdentityN", node.op()); EXPECT_EQ(device, node.device()); EXPECT_EQ(1, node.input_size()); - EXPECT_EQ("y/y:0", node.input(0)); + EXPECT_EQ("y/y", node.input(0)); } else if (node.name() == "z") { count++; EXPECT_EQ("Identity", node.op()); @@ -180,13 +180,13 @@ TEST_F(FunctionOptimizerTest, FixedTypeFunction) { EXPECT_EQ(device, node.device()); EXPECT_EQ(2, node.input_size()); EXPECT_EQ("y/x", node.input(0)); - EXPECT_EQ("y/two:0", node.input(1)); + EXPECT_EQ("y/two", node.input(1)); } else if (node.name() == "y") { count++; EXPECT_EQ("IdentityN", node.op()); EXPECT_EQ(device, node.device()); EXPECT_EQ(1, node.input_size()); - EXPECT_EQ("y/y:0", node.input(0)); + EXPECT_EQ("y/y", node.input(0)); } else if (node.name() == "z") { count++; EXPECT_EQ("Identity", node.op()); @@ -264,13 +264,13 @@ TEST_F(FunctionOptimizerTest, FunctionWithOutputMapping) { EXPECT_EQ("Exp", node.op()); EXPECT_EQ(device, node.device()); EXPECT_EQ(1, node.input_size()); - EXPECT_EQ("y/Linear_func:0", node.input(0)); + EXPECT_EQ("y/Linear_func", node.input(0)); } else if (node.name() == "y") { count++; EXPECT_EQ("IdentityN", node.op()); EXPECT_EQ(device, node.device()); EXPECT_EQ(1, node.input_size()); - EXPECT_EQ("y/Exp:0", node.input(0)); + EXPECT_EQ("y/Exp", node.input(0)); } else if (node.name() == "z") { count++; EXPECT_EQ("Identity", node.op()); @@ -453,12 +453,12 @@ TEST_F(FunctionOptimizerTest, InlineFunctionWithNestedFunctionCall) { EXPECT_EQ("IdentityN", node.op()); EXPECT_EQ(kDevice, node.device()); EXPECT_EQ(1, node.input_size()); - EXPECT_EQ("square/output/output:0", node.input(0)); + EXPECT_EQ("square/output/output", node.input(0)); } else if (node.name() == "square" && count++) { EXPECT_EQ("IdentityN", node.op()); EXPECT_EQ(kDevice, node.device()); EXPECT_EQ(1, node.input_size()); - EXPECT_EQ("square/output:0", node.input(0)); + EXPECT_EQ("square/output", node.input(0)); } else if (node.name() == "outputs" && count++) { EXPECT_EQ("Identity", node.op()); EXPECT_EQ(kDevice, node.device()); diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD index 7419c26..05d9cba 100644 --- a/tensorflow/core/grappler/utils/BUILD +++ b/tensorflow/core/grappler/utils/BUILD @@ -161,6 +161,8 @@ cc_library( deps = [ "//tensorflow/core:framework", "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:utils", diff --git a/tensorflow/core/grappler/utils/functions.cc b/tensorflow/core/grappler/utils/functions.cc index 4f286ce..dd0d918 100644 --- a/tensorflow/core/grappler/utils/functions.cc +++ b/tensorflow/core/grappler/utils/functions.cc @@ -24,50 +24,285 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/strings/scanner.h" namespace tensorflow { namespace grappler { -std::unique_ptr GrapplerItemFromFunctionDef( - const FunctionDef& func, - const std::unordered_map& func_attr, - const FunctionDefLibrary& library) { - if (func.signature().name().empty()) { - LOG(ERROR) << "function name must be specified."; - return nullptr; +void GrapplerFunctionConnectivity::RegisterInputArgExpansion( + const InputArgExpansion& input_arg_expansion) { + input_arg_expansions_.insert( + {input_arg_expansion.input_name, input_arg_expansion}); +} + +void GrapplerFunctionConnectivity::RegisterFunctionBodyOutputs( + const string& node_name, const tensorflow::NameRangeMap& outputs) { + function_body_outputs_.insert({node_name, outputs}); +} + +Status GrapplerFunctionConnectivity::ExpandFunctionDefInput( + const string& func_def_input, std::vector* graph_def_inputs) const { + using ::tensorflow::strings::Scanner; + + // Parse input format: "node_name[:node_output][:position]" + string node_name; + string node_output; + int position = -1; + + StringPiece capture; + StringPiece remaining; + + // Parse "node_name" + if (Scanner(func_def_input) + .One(strings::Scanner::LETTER_DIGIT_DOT_UNDERSCORE) + .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE) + .GetResult(&remaining, &capture)) { + node_name = string(capture.data(), capture.size()); } - std::unique_ptr new_item(new GrapplerItem()); - new_item->id = func.signature().name(); - - std::unordered_map port_map; - - // Add the function inputs as placeholder - for (const auto& inp : func.signature().input_arg()) { - NodeDef* ph = new_item->graph.add_node(); - ph->set_name(inp.name()); - ph->set_op("Placeholder"); - if (inp.type() != DT_INVALID) { - (*ph->mutable_attr())["T"].set_type(inp.type()); - } else { - auto it = func_attr.find(inp.type_attr()); - if (it == func_attr.end()) { - LOG(ERROR) << "Unknown type attribute " << inp.type_attr() - << " for function input " << inp.name(); - return nullptr; + + // Parse "node_output" if it exists + if (Scanner(remaining) + .OneLiteral(":") + .RestartCapture() + .One(strings::Scanner::LOWERLETTER) + .Any(strings::Scanner::LETTER_DIGIT_UNDERSCORE) + .GetResult(&remaining, &capture)) { + node_output = string(capture.data(), capture.size()); + } + + // Parse "position" if it exists + if (Scanner(remaining) + .OneLiteral(":") + .RestartCapture() + .Many(strings::Scanner::DIGIT) + .GetResult(nullptr, &capture)) { + CHECK(strings::safe_strto32(capture, &position)); + } + + // If "node_output" is not empty, it must be an output of a function body node + bool is_function_body_output = !node_output.empty(); + + // Function input argument: "node_name[:position]" + if (!is_function_body_output) { + auto input_arg = input_arg_expansions_.find(node_name); + if (input_arg != input_arg_expansions_.end()) { + const InputArgExpansion& input_arg_expansion = input_arg->second; + const auto& placeholders = input_arg_expansion.placeholders; + + if (position == -1) { + // If position is not defined use all placeholders + graph_def_inputs->reserve(placeholders.size()); + for (const string& placeholder : placeholders) { + graph_def_inputs->push_back(placeholder); + } } else { - (*ph->mutable_attr())["T"] = it->second; + if (position > input_arg_expansion.placeholders.size() - 1) { + return errors::InvalidArgument("Invalid input ", node_name, + "position: ", position, + " (out of range)"); + } + graph_def_inputs->push_back(input_arg_expansion.placeholders[position]); + } + + return Status::OK(); + } + } + + // Function body output: "node_name:node_output[:position]" + if (is_function_body_output) { + auto function_body_outputs = function_body_outputs_.find(node_name); + if (function_body_outputs != function_body_outputs_.end()) { + const tensorflow::NameRangeMap& outputs = function_body_outputs->second; + auto output = outputs.find(node_output); + if (output != outputs.end()) { + const auto& output_range = output->second; + + if (position == -1) { + // If position is not defined expand node output range + for (int i = output_range.first; i < output_range.second; ++i) { + i == 0 ? graph_def_inputs->push_back(node_name) + : graph_def_inputs->push_back( + strings::StrCat(node_name, ":", i)); + } + } else { + if (position > (output_range.second - output_range.first)) { + return errors::InvalidArgument( + "Invalid node ", node_name, " output ", node_output, + " position: ", position, " (out of range)"); + } + int pos = output_range.first + position; + pos == 0 ? graph_def_inputs->push_back(node_name) + : graph_def_inputs->push_back( + strings::StrCat(node_name, ":", pos)); + } + + return Status::OK(); } } - port_map[inp.name()] = inp.name(); } - // Add the function body to the graph. - FunctionLibraryDefinition func_def(OpRegistry::Global(), library); + return errors::InvalidArgument("Failed to expand a function def input: ", + func_def_input); +} + +Status GrapplerFunctionConnectivity::ExpandNodeInputs( + NodeDef* function_body_node) const { + std::vector expanded_inputs; + + for (const string& function_def_input : function_body_node->input()) { + if (!IsControlInput(function_def_input)) + TF_RETURN_IF_ERROR( + ExpandFunctionDefInput(function_def_input, &expanded_inputs)); + else + expanded_inputs.push_back(function_def_input); + } + + function_body_node->clear_input(); + for (const string& expanded_input : expanded_inputs) + function_body_node->add_input(expanded_input); + return Status::OK(); +} + +Status GrapplerFunctionItemBuilder::GetTypeAttr(const string& type_attr_name, + DataType* data_type) const { + auto it = func_attr_->find(type_attr_name); + if (it == func_attr_->end()) { + return errors::InvalidArgument("Type attribute ", type_attr_name, + " is not defined"); + } else if (it->second.type() == DT_INVALID) { + return errors::InvalidArgument("Type attribute ", type_attr_name, + " is not defined with a valid type"); + } else { + *data_type = it->second.type(); + } + return Status::OK(); +} + +Status GrapplerFunctionItemBuilder::GetArgType(const OpDef::ArgDef& arg, + DataType* data_type) const { + if (arg.type() != DT_INVALID) { + *data_type = arg.type(); + } else { + TF_RETURN_IF_ERROR(GetTypeAttr(arg.type_attr(), data_type)); + } + return Status::OK(); +} + +GrapplerFunctionItem::GrapplerFunctionItem( + const string& function_name, + const std::vector& input_arg_expansions, + const std::vector& output_arg_expansions, + GraphDef&& function_body) + : function_name_(function_name), + input_arg_expansions_(input_arg_expansions), + output_arg_expansions_(output_arg_expansions) { + graph.Swap(&function_body); +} + +const string& GrapplerFunctionItem::function_name() const { + return function_name_; +} + +const std::vector& GrapplerFunctionItem::inputs() const { + return input_arg_expansions_; +} + +const InputArgExpansion& GrapplerFunctionItem::input(int i) const { + return input_arg_expansions_[i]; +} + +const std::size_t GrapplerFunctionItem::input_size() const { + return input_arg_expansions_.size(); +} + +const std::vector& GrapplerFunctionItem::outputs() const { + return output_arg_expansions_; +} + +const OutputArgExpansion& GrapplerFunctionItem::output(int i) const { + return output_arg_expansions_[i]; +} + +const std::size_t GrapplerFunctionItem::output_size() const { + return output_arg_expansions_.size(); +} + +const GraphDef& GrapplerFunctionItem::function_body() const { return graph; } + +GraphDef& GrapplerFunctionItem::mutable_function_body() { return graph; } + +std::vector OutputTensors(const GrapplerFunctionItem& item) { + std::vector output_tensors; + for (const OutputArgExpansion& output : item.outputs()) { + for (const string& tensor : output.output_tensors) { + output_tensors.push_back(tensor); + } + } + return output_tensors; +} + +Status MakeGrapplerFunctionItem( + const FunctionDef& func, + const std::unordered_map& func_attr, + const FunctionLibraryDefinition& func_library, GrapplerFunctionItem* item) { + const OpDef& signature = func.signature(); + + if (signature.name().empty()) { + return errors::InvalidArgument("Function name must be specified"); + } + + // Helper methods to lookup function attributes + GrapplerFunctionItemBuilder builder(&func_attr); + + // Mapping from FunctionDef input format (name[:output][:position]) to + // GraphDef input format (name[:position]) + GrapplerFunctionConnectivity connectivity; + + std::vector inputs; + std::vector outputs; + GraphDef function_body; + + // TODO(ezhulenev): support functions with tensor sequence inputs/outputs + + // Make sure that there is no tensor sequences in outputs + for (const OpDef::ArgDef& output : signature.output_arg()) { + if (!output.type_list_attr().empty() || !output.number_attr().empty()) { + return errors::InvalidArgument( + "Outputs with sequence of tensors are not supported. Unsupported " + "output: ", + output.name()); + } + } + + // For each input argument create a placeholder in function body. + for (const OpDef::ArgDef& input : signature.input_arg()) { + if (!input.type_list_attr().empty() || !input.number_attr().empty()) { + return errors::InvalidArgument( + "Inputs with sequence of tensors are not supported. Unsupported " + "input: ", + input.name()); + } + + DataType input_data_type; + TF_RETURN_IF_ERROR(builder.GetArgType(input, &input_data_type)); + + NodeDef* placeholder = function_body.add_node(); + placeholder->set_name(input.name()); + placeholder->set_op("Placeholder"); + (*placeholder->mutable_attr())["T"].set_type(input_data_type); + + InputArgExpansion input_expansion{/*input_name=*/input.name(), + /*placeholders=*/{input.name()}}; + connectivity.RegisterInputArgExpansion(input_expansion); + inputs.push_back(input_expansion); + } + + // Add all function nodes to the function body + for (const NodeDef& func_def_node : func.node_def()) { + NodeDef* new_node = function_body.add_node(); + *new_node = func_def_node; - for (const NodeDef& node : func.node_def()) { - NodeDef* new_node = new_item->graph.add_node(); - *new_node = node; - // Replace the placeholder attribute values with the specified value. + // Replace the placeholder attribute values with the specified value for (auto& attr : *new_node->mutable_attr()) { const string& ph_name = attr.second.placeholder(); auto it = func_attr.find(ph_name); @@ -78,75 +313,39 @@ std::unique_ptr GrapplerItemFromFunctionDef( // Functions use a custom format to encode connectivity. Map these custom // strings to regular ones. + tensorflow::NameRangeMap outputs_range_map; const OpRegistrationData* registration; - Status status = func_def.LookUp(node.op(), ®istration); - if (!status.ok()) { - LOG(ERROR) << "Op " << node.op() << " not registered: " << status; - return nullptr; - } - - tensorflow::NameRangeMap inputs; - tensorflow::NameRangeMap outputs; - status = tensorflow::NameRangesForNode(node, registration->op_def, &inputs, - &outputs); - if (!status.ok()) { - LOG(ERROR) << "Op " << node.op() << " invalid: " << status; - return nullptr; - } - for (const auto& name_range : outputs) { - string port_prefix = - strings::StrCat(node.name(), ":", name_range.first, ":"); - int index_start = name_range.second.first; - int index_end = name_range.second.second; - for (int i = index_start; i < index_end; ++i) { - string port_id = strings::StrCat(port_prefix, i - index_start); - string port_name = strings::StrCat(node.name(), ":", i); - port_map[port_id] = port_name; - } - } + TF_RETURN_IF_ERROR(func_library.LookUp(func_def_node.op(), ®istration)); + TF_RETURN_IF_ERROR(tensorflow::NameRangesForNode( + func_def_node, registration->op_def, nullptr, &outputs_range_map)); + connectivity.RegisterFunctionBodyOutputs(func_def_node.name(), + outputs_range_map); } - for (auto& node : *new_item->graph.mutable_node()) { - // Rewrite the inputs to use the normal naming convention. - for (int i = 0; i < node.input_size(); ++i) { - const string& input = node.input(i); - if (IsControlInput(input)) { - // No need to remap control dependencies. - continue; - } else { - auto it = port_map.find(input); - if (it == port_map.end()) { - LOG(ERROR) << "Unknown input: " << input; - return nullptr; - } - node.set_input(i, it->second); - } - } + // Rewrite inputs to use GraphDef format + for (NodeDef& node : *function_body.mutable_node()) { + TF_RETURN_IF_ERROR(connectivity.ExpandNodeInputs(&node)); } - // Add the function outputs to the list of fetch nodes, taking into account - // the output mapping if any. - for (const auto& out : func.signature().output_arg()) { - auto it = func.ret().find(out.name()); - if (it != func.ret().end()) { - auto it2 = port_map.find(it->second); - if (it2 == port_map.end()) { - LOG(ERROR) << "Unknown output mapping: " << it->first << " to " - << it->second; - return nullptr; - } else { - new_item->fetch.emplace_back(it2->second); - } + // Add function outputs + for (const OpDef::ArgDef& out : signature.output_arg()) { + std::vector output_tensors; + auto ret = func.ret().find(out.name()); + if (ret != func.ret().end()) { + // Expand outputs using provided output mapping + TF_RETURN_IF_ERROR( + connectivity.ExpandFunctionDefInput(ret->second, &output_tensors)); } else { - new_item->fetch.emplace_back(out.name()); + // Otherwise output must be one of the function inputs + TF_RETURN_IF_ERROR( + connectivity.ExpandFunctionDefInput(out.name(), &output_tensors)); } - } - // Add the function inputs to the list of feeds. - for (const auto& inp : func.signature().input_arg()) { - new_item->feed.emplace_back(inp.name(), Tensor()); + outputs.push_back({out.name(), output_tensors}); } - return new_item; + *item = GrapplerFunctionItem(signature.name(), inputs, outputs, + std::move(function_body)); + return Status::OK(); } } // end namespace grappler diff --git a/tensorflow/core/grappler/utils/functions.h b/tensorflow/core/grappler/utils/functions.h index 8f9b7d8..60ea885 100644 --- a/tensorflow/core/grappler/utils/functions.h +++ b/tensorflow/core/grappler/utils/functions.h @@ -19,19 +19,125 @@ limitations under the License. #include #include #include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/grappler/grappler_item.h" namespace tensorflow { - namespace grappler { -// Factory method for creating a GrapplerItem from a FunctionDef. -// Returns nullptr if the given function def cannot be converted. -std::unique_ptr GrapplerItemFromFunctionDef( +// Depending on the function instantiation attributes, input argument to the +// function might be a single tensor, list of tensors of the same type, or a +// list of tensors of different types. +// +// InputArgExpansion keeps track of the placeholders that were added to the +// function body in place of function inputs. +struct InputArgExpansion { + string input_name; // name of the function input argument + std::vector placeholders; // names of placeholder nodes in the + // function body +}; + +// Depending on the function instantiation attributes, output argument is mapped +// to one or more outputs of one of the function body nodes. +// +// OutputArgExpansion keeps mapping from a function output arg to the output +// tensors of a function body nodes, that compute function outputs. +struct OutputArgExpansion { + string output_name; // name of the function output argument + std::vector output_tensors; // names of output tensors from the + // function body graph nodes +}; + +// FunctionDef uses different connectivity encoding for the function body nodes, +// then a GraphDef (see function.proto for details). Input name in FunctionDef +// can potentially represent a sequence of tensors (instead just one tensor in +// GraphDef), we need to expand it when converting from FunctionDef to GraphDef, +// and fold it back when doing backward conversion. +class GrapplerFunctionConnectivity { + public: + void RegisterInputArgExpansion(const InputArgExpansion& input_arg_expansion); + void RegisterFunctionBodyOutputs(const string& node_name, + const tensorflow::NameRangeMap& outputs); + + // Expand input encoded in FunctionDef format (name[:output][:position]) into + // multiple inputs in GraphDef format (name[:position]). + Status ExpandFunctionDefInput(const string& func_def_input, + std::vector* graph_def_inputs) const; + + // Update Node inputs from FunctionDef to GraphDef format + Status ExpandNodeInputs(NodeDef* function_body_node) const; + + // TODO(ezhulenev): fold GraphDef inputs back to FunctionDef format + // Status FoldGraphDefInputs(const std::vector graph_def_inputs, + // std::vector* function_def_inputs) const; + + private: + std::unordered_map input_arg_expansions_; + std::unordered_map function_body_outputs_; +}; + +// Helper methods to build GrapplerFunctionItem from a function def and function +// attributes. +class GrapplerFunctionItemBuilder { + public: + using FunctionAttr = std::unordered_map; + + explicit GrapplerFunctionItemBuilder(const FunctionAttr* func_attr) + : func_attr_(func_attr) {} + + // Get DataType from attributes by name. Return error if attribute is missing, + // or it doesn't define a valid data type. + Status GetTypeAttr(const string& type_attr_name, DataType* data_type) const; + + // Get argument data type. If data type is not explicitly defined, uses + // provided attribute name to look it up in function attributes. + Status GetArgType(const OpDef::ArgDef& arg, DataType* data_type) const; + + private: + const FunctionAttr* func_attr_; // do not own +}; + +// A special case of GrapplerItem, constructed from a TensorFlow Function. +class GrapplerFunctionItem : public GrapplerItem { + public: + GrapplerFunctionItem() {} + GrapplerFunctionItem( + const string& function_name, + const std::vector& input_arg_expansions, + const std::vector& output_arg_expansions, + GraphDef&& function_body); + + const string& function_name() const; + + const std::vector& inputs() const; + const InputArgExpansion& input(int i) const; + const std::size_t input_size() const; + + const std::vector& outputs() const; + const OutputArgExpansion& output(int i) const; + const std::size_t output_size() const; + + const GraphDef& function_body() const; + GraphDef& mutable_function_body(); + + private: + string function_name_; + std::vector input_arg_expansions_; + std::vector output_arg_expansions_; +}; + +// Return all output tensors referenced by item output args. +std::vector OutputTensors(const GrapplerFunctionItem& item); + +// Make a GrapplerFunctionItem from the function definition and attributes. +// Return error if the given function def cannot be converted. +Status MakeGrapplerFunctionItem( const FunctionDef& func, const std::unordered_map& func_attr, - const FunctionDefLibrary& library); + const FunctionLibraryDefinition& func_library, GrapplerFunctionItem* item); } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/utils/functions_test.cc b/tensorflow/core/grappler/utils/functions_test.cc index 6a7d766..1eb3298 100644 --- a/tensorflow/core/grappler/utils/functions_test.cc +++ b/tensorflow/core/grappler/utils/functions_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" @@ -28,6 +29,88 @@ namespace { class FunctionsTest : public ::testing::Test {}; +TEST_F(FunctionsTest, GrapplerFunctionConnectivity_ExpandFunctionDefInput) { + GrapplerFunctionConnectivity connectivity; + + connectivity.RegisterInputArgExpansion({"inputA", {"inputA"}}); + connectivity.RegisterInputArgExpansion({"inputB", {"inputB_0", "inputB_1"}}); + + connectivity.RegisterFunctionBodyOutputs("Add", {{"z", {0, 1}}}); + connectivity.RegisterFunctionBodyOutputs("Func", + {{"o1", {0, 2}}, {"o2", {2, 4}}}); + + std::vector inputs; + TF_EXPECT_OK(connectivity.ExpandFunctionDefInput("inputA", &inputs)); + ASSERT_EQ(1, inputs.size()); + EXPECT_EQ("inputA", inputs[0]); + + inputs.clear(); + TF_EXPECT_OK(connectivity.ExpandFunctionDefInput("inputB", &inputs)); + ASSERT_EQ(2, inputs.size()); + EXPECT_EQ("inputB_0", inputs[0]); + EXPECT_EQ("inputB_1", inputs[1]); + + inputs.clear(); + TF_EXPECT_OK(connectivity.ExpandFunctionDefInput("inputB:1", &inputs)); + ASSERT_EQ(1, inputs.size()); + EXPECT_EQ("inputB_1", inputs[0]); + + inputs.clear(); + TF_EXPECT_OK(connectivity.ExpandFunctionDefInput("Add:z", &inputs)); + ASSERT_EQ(1, inputs.size()); + EXPECT_EQ("Add", inputs[0]); + + inputs.clear(); + TF_EXPECT_OK(connectivity.ExpandFunctionDefInput("Func:o1", &inputs)); + ASSERT_EQ(2, inputs.size()); + EXPECT_EQ("Func", inputs[0]); + EXPECT_EQ("Func:1", inputs[1]); + + inputs.clear(); + TF_EXPECT_OK(connectivity.ExpandFunctionDefInput("Func:o2", &inputs)); + ASSERT_EQ(2, inputs.size()); + EXPECT_EQ("Func:2", inputs[0]); + EXPECT_EQ("Func:3", inputs[1]); + + inputs.clear(); + TF_EXPECT_OK(connectivity.ExpandFunctionDefInput("Func:o1:0", &inputs)); + ASSERT_EQ(1, inputs.size()); + EXPECT_EQ("Func", inputs[0]); + + inputs.clear(); + TF_EXPECT_OK(connectivity.ExpandFunctionDefInput("Func:o1:1", &inputs)); + ASSERT_EQ(1, inputs.size()); + EXPECT_EQ("Func:1", inputs[0]); + + inputs.clear(); + TF_EXPECT_OK(connectivity.ExpandFunctionDefInput("Func:o2:0", &inputs)); + ASSERT_EQ(1, inputs.size()); + EXPECT_EQ("Func:2", inputs[0]); + + inputs.clear(); + TF_EXPECT_OK(connectivity.ExpandFunctionDefInput("Func:o2:1", &inputs)); + ASSERT_EQ(1, inputs.size()); + EXPECT_EQ("Func:3", inputs[0]); +} + +TEST_F(FunctionsTest, GrapplerFunctionConnectivity_ExpandNodeInputs) { + GrapplerFunctionConnectivity connectivity; + + connectivity.RegisterInputArgExpansion({"inputA", {"inputA"}}); + connectivity.RegisterInputArgExpansion({"inputB", {"inputB_0", "inputB_1"}}); + + NodeDef node; + node.add_input("inputA:0"); + node.add_input("inputB"); + + TF_EXPECT_OK(connectivity.ExpandNodeInputs(&node)); + + EXPECT_EQ(3, node.input_size()); + EXPECT_EQ("inputA", node.input(0)); + EXPECT_EQ("inputB_0", node.input(1)); + EXPECT_EQ("inputB_1", node.input(2)); +} + TEST_F(FunctionsTest, FromSimpleFunctionDef) { const Tensor kTwo = test::AsScalar(2); FunctionDef func = FunctionDefHelper::Define( @@ -48,37 +131,45 @@ TEST_F(FunctionsTest, FromSimpleFunctionDef) { std::unordered_map func_attr; func_attr["T"].set_type(DT_FLOAT); - FunctionDefLibrary library; - std::unique_ptr item = - GrapplerItemFromFunctionDef(func, func_attr, library); - CHECK(item); - EXPECT_EQ("XTimesTwo", item->id); - EXPECT_EQ(4, item->graph.node_size()); - EXPECT_EQ(std::vector({"y:0"}), item->fetch); - EXPECT_EQ(1, item->feed.size()); - EXPECT_EQ("x", item->feed[0].first); - - for (const NodeDef &node : item->graph.node()) { - if (node.name() == "x") { + FunctionLibraryDefinition library(OpRegistry::Global(), FunctionDefLibrary()); + + GrapplerFunctionItem item; + TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, library, &item)); + + EXPECT_EQ("XTimesTwo", item.function_name()); + EXPECT_EQ(4, item.function_body().node_size()); + + EXPECT_EQ(1, item.input_size()); + EXPECT_EQ("x", item.input(0).input_name); + EXPECT_EQ(std::vector{"x"}, item.input(0).placeholders); + + EXPECT_EQ(1, item.output_size()); + EXPECT_EQ("y", item.output(0).output_name); + EXPECT_EQ("y", item.output(0).output_tensors[0]); + + int count = 0; + for (const NodeDef &node : item.function_body().node()) { + if (node.name() == "x" && count++) { EXPECT_EQ("Placeholder", node.op()); EXPECT_EQ(DT_FLOAT, node.attr().at("T").type()); EXPECT_EQ(0, node.input_size()); - } else if (node.name() == "two") { + } else if (node.name() == "two" && count++) { EXPECT_EQ("Const", node.op()); EXPECT_EQ(0, node.input_size()); - } else if (node.name() == "scale") { + } else if (node.name() == "scale" && count++) { EXPECT_EQ("Cast", node.op()); EXPECT_EQ(DT_FLOAT, node.attr().at("DstT").type()); EXPECT_EQ(1, node.input_size()); - EXPECT_EQ("two:0", node.input(0)); - } else if (node.name() == "y") { + EXPECT_EQ("two", node.input(0)); + } else if (node.name() == "y" && count++) { EXPECT_EQ("Mul", node.op()); EXPECT_EQ(DT_FLOAT, node.attr().at("T").type()); EXPECT_EQ(2, node.input_size()); EXPECT_EQ("x", node.input(0)); - EXPECT_EQ("scale:0", node.input(1)); + EXPECT_EQ("scale", node.input(1)); } } + EXPECT_EQ(4, count); } TEST_F(FunctionsTest, FromFunctionDefWithMultiOutputNodes) { @@ -115,45 +206,53 @@ TEST_F(FunctionsTest, FromFunctionDefWithMultiOutputNodes) { std::unordered_map func_attr; func_attr["T"].set_type(DT_FLOAT); - FunctionDefLibrary library; - std::unique_ptr item = - GrapplerItemFromFunctionDef(func, func_attr, library); - CHECK(item); - EXPECT_EQ("SubGrad", item->id); - EXPECT_EQ(12, item->graph.node_size()); - EXPECT_EQ(std::vector({"dx:0", "dy:0"}), item->fetch); - EXPECT_EQ(3, item->feed.size()); - EXPECT_EQ("x", item->feed[0].first); - EXPECT_EQ("y", item->feed[1].first); - EXPECT_EQ("dz", item->feed[2].first); - - for (const NodeDef &node : item->graph.node()) { + FunctionLibraryDefinition library(OpRegistry::Global(), FunctionDefLibrary()); + + GrapplerFunctionItem item; + TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, library, &item)); + + EXPECT_EQ("SubGrad", item.function_name()); + EXPECT_EQ(12, item.function_body().node_size()); + + ASSERT_EQ(3, item.input_size()); + EXPECT_EQ("x", item.input(0).input_name); + EXPECT_EQ("y", item.input(1).input_name); + EXPECT_EQ("dz", item.input(2).input_name); + + ASSERT_EQ(2, item.output_size()); + EXPECT_EQ("dx", item.output(0).output_tensors[0]); + EXPECT_EQ("dy", item.output(1).output_tensors[0]); + + int count = 0; + for (const NodeDef &node : item.function_body().node()) { if (node.name() == "x" || node.name() == "y" || node.name() == "dz") { + count++; EXPECT_EQ("Placeholder", node.op()); EXPECT_EQ(DT_FLOAT, node.attr().at("T").type()); EXPECT_EQ(0, node.input_size()); - } else if (node.name() == "rx") { + } else if (node.name() == "rx" && count++) { EXPECT_EQ("BroadcastGradientArgs", node.op()); EXPECT_EQ(2, node.input_size()); - EXPECT_EQ("sx:0", node.input(0)); - EXPECT_EQ("sy:0", node.input(1)); - } else if (node.name() == "sum_gx") { + EXPECT_EQ("sx", node.input(0)); + EXPECT_EQ("sy", node.input(1)); + } else if (node.name() == "sum_gx" && count++) { EXPECT_EQ("Sum", node.op()); EXPECT_EQ(2, node.input_size()); - EXPECT_EQ("gx:0", node.input(0)); - EXPECT_EQ("rx:0", node.input(1)); - } else if (node.name() == "sum_gy") { + EXPECT_EQ("gx", node.input(0)); + EXPECT_EQ("rx", node.input(1)); + } else if (node.name() == "sum_gy" && count++) { EXPECT_EQ("Sum", node.op()); EXPECT_EQ(2, node.input_size()); - EXPECT_EQ("gy:0", node.input(0)); + EXPECT_EQ("gy", node.input(0)); EXPECT_EQ("rx:1", node.input(1)); } } + EXPECT_EQ(6, count); } TEST_F(FunctionsTest, FromFunctionDefWithNestedFuncs) { - FunctionDefLibrary library; - *library.add_function() = FunctionDefHelper::Define( + FunctionLibraryDefinition library(OpRegistry::Global(), FunctionDefLibrary()); + TF_ASSERT_OK(library.AddFunctionDef(FunctionDefHelper::Define( // Name "Swap", // Args @@ -164,7 +263,7 @@ TEST_F(FunctionsTest, FromFunctionDefWithNestedFuncs) { {"T: {float, double}"}, // Nodes {{{"o0"}, "Identity", {"i1"}, {{"T", "$T"}}}, - {{"o1"}, "Identity", {"i0"}, {{"T", "$T"}}}}); + {{"o1"}, "Identity", {"i0"}, {{"T", "$T"}}}}))); FunctionDef func = FunctionDefHelper::Create( // Name @@ -189,43 +288,47 @@ TEST_F(FunctionsTest, FromFunctionDefWithNestedFuncs) { std::unordered_map func_attr; func_attr["T"].set_type(DT_FLOAT); - std::unique_ptr item = - GrapplerItemFromFunctionDef(func, func_attr, library); - for (const NodeDef &node : item->graph.node()) { + GrapplerFunctionItem item; + TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, library, &item)); + + int count = 0; + for (const NodeDef &node : item.function_body().node()) { if (node.name() == "x" || node.name() == "y") { + count++; EXPECT_EQ("Placeholder", node.op()); EXPECT_EQ(DT_FLOAT, node.attr().at("T").type()); EXPECT_EQ(0, node.input_size()); - } else if (node.name() == "a0") { + } else if (node.name() == "a0" && count++) { EXPECT_EQ("Swap", node.op()); EXPECT_EQ(3, node.input_size()); EXPECT_EQ("x", node.input(0)); EXPECT_EQ("y", node.input(1)); EXPECT_EQ("^x2", node.input(2)); - } else if (node.name() == "a1") { + } else if (node.name() == "a1" && count++) { EXPECT_EQ("Swap", node.op()); EXPECT_EQ(2, node.input_size()); - EXPECT_EQ("a0:0", node.input(0)); + EXPECT_EQ("a0", node.input(0)); EXPECT_EQ("a0:1", node.input(1)); - } else if (node.name() == "x2") { + } else if (node.name() == "x2" && count++) { EXPECT_EQ("Mul", node.op()); EXPECT_EQ(2, node.input_size()); EXPECT_EQ("x", node.input(0)); EXPECT_EQ("x", node.input(1)); - } else if (node.name() == "y2") { + } else if (node.name() == "y2" && count++) { EXPECT_EQ("Mul", node.op()); EXPECT_EQ(3, node.input_size()); EXPECT_EQ("y", node.input(0)); EXPECT_EQ("y", node.input(1)); EXPECT_EQ("^a1", node.input(2)); - } else if (node.name() == "o") { + } else if (node.name() == "o" && count++) { EXPECT_EQ("Add", node.op()); EXPECT_EQ(2, node.input_size()); - EXPECT_EQ("x2:0", node.input(0)); - EXPECT_EQ("y2:0", node.input(1)); + EXPECT_EQ("x2", node.input(0)); + EXPECT_EQ("y2", node.input(1)); } } + EXPECT_EQ(7, count); } TEST_F(FunctionsTest, FromFunctionDefWithOutputMappings) { @@ -245,28 +348,31 @@ TEST_F(FunctionsTest, FromFunctionDefWithOutputMappings) { {{"out", "Exp:y:0"}}); std::unordered_map func_attr; - FunctionDefLibrary library; - std::unique_ptr item = - GrapplerItemFromFunctionDef(func, func_attr, library); + FunctionLibraryDefinition library(OpRegistry::Global(), FunctionDefLibrary()); + + GrapplerFunctionItem item; + TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, library, &item)); - EXPECT_EQ(1, item->fetch.size()); - EXPECT_EQ("Exp:0", item->fetch[0]); + EXPECT_EQ(1, item.output_size()); + EXPECT_EQ("Exp", item.output(0).output_tensors[0]); - for (const NodeDef &node : item->graph.node()) { - if (node.name() == "in") { + int count = 0; + for (const NodeDef &node : item.function_body().node()) { + if (node.name() == "in" && count++) { EXPECT_EQ("Placeholder", node.op()); EXPECT_EQ(DT_FLOAT, node.attr().at("T").type()); EXPECT_EQ(0, node.input_size()); - } else if (node.name() == "Linear_func") { + } else if (node.name() == "Linear_func" && count++) { EXPECT_EQ("Identity", node.op()); EXPECT_EQ(1, node.input_size()); EXPECT_EQ("in", node.input(0)); - } else if (node.name() == "Exp") { + } else if (node.name() == "Exp" && count++) { EXPECT_EQ("Exp", node.op()); EXPECT_EQ(1, node.input_size()); - EXPECT_EQ("Linear_func:0", node.input(0)); + EXPECT_EQ("Linear_func", node.input(0)); } } + EXPECT_EQ(3, count); } TEST_F(FunctionsTest, FromFunctionDefWithInputForwarding) { @@ -285,20 +391,25 @@ TEST_F(FunctionsTest, FromFunctionDefWithInputForwarding) { {{"out0", "in0"}}); std::unordered_map func_attr; - FunctionDefLibrary library; - std::unique_ptr item = - GrapplerItemFromFunctionDef(func, func_attr, library); + FunctionLibraryDefinition library(OpRegistry::Global(), FunctionDefLibrary()); - EXPECT_EQ(3, item->fetch.size()); - EXPECT_EQ("in0", item->fetch[0]); - EXPECT_EQ("arg2", item->fetch[1]); - EXPECT_EQ("arg3", item->fetch[2]); + GrapplerFunctionItem item; + TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, library, &item)); - EXPECT_EQ(5, item->graph.node_size()); - for (const NodeDef &node : item->graph.node()) { + EXPECT_EQ("ForwardInputs", item.function_name()); + EXPECT_EQ(5, item.function_body().node_size()); + + EXPECT_EQ(3, item.output_size()); + EXPECT_EQ("in0", item.output(0).output_tensors[0]); + EXPECT_EQ("arg2", item.output(1).output_tensors[0]); + EXPECT_EQ("arg3", item.output(2).output_tensors[0]); + + int count = 0; + for (const NodeDef &node : item.function_body().node()) { EXPECT_TRUE(node.name() == "in0" || node.name() == "in1" || node.name() == "arg2" || node.name() == "arg3" || node.name() == "arg4"); + count++; EXPECT_EQ("Placeholder", node.op()); if (node.name() == "arg3") { EXPECT_EQ(DT_INT32, node.attr().at("T").type()); @@ -306,6 +417,7 @@ TEST_F(FunctionsTest, FromFunctionDefWithInputForwarding) { EXPECT_EQ(DT_FLOAT, node.attr().at("T").type()); } } + EXPECT_EQ(5, count); } TEST_F(FunctionsTest, FromFunctionDefWithoutInput) { @@ -325,24 +437,23 @@ TEST_F(FunctionsTest, FromFunctionDefWithoutInput) { std::unordered_map func_attr; func_attr["T"].set_type(DT_FLOAT); - FunctionDefLibrary library; - std::unique_ptr item = - GrapplerItemFromFunctionDef(func, func_attr, library); + FunctionLibraryDefinition library(OpRegistry::Global(), FunctionDefLibrary()); - EXPECT_EQ(0, item->feed.size()); - EXPECT_EQ(1, item->fetch.size()); - EXPECT_EQ("o:0", item->fetch[0]); + GrapplerFunctionItem item; + TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, library, &item)); - EXPECT_EQ(2, item->graph.node_size()); - const NodeDef &two = item->graph.node(0); + EXPECT_EQ(0, item.input_size()); + EXPECT_EQ(1, item.output_size()); + EXPECT_EQ("o", item.output(0).output_tensors[0]); + + EXPECT_EQ(2, item.function_body().node_size()); + const NodeDef &two = item.function_body().node(0); EXPECT_EQ("two", two.name()); EXPECT_EQ(0, two.input_size()); - const NodeDef &cast = item->graph.node(1); + const NodeDef &cast = item.function_body().node(1); EXPECT_EQ("o", cast.name()); EXPECT_EQ(1, cast.input_size()); - EXPECT_EQ("two:0", cast.input(0)); - - std::cout << item->graph.DebugString() << std::endl; + EXPECT_EQ("two", cast.input(0)); } } // namespace -- 2.7.4