From 4b83dea191761967e5d4c705caac6078f8360a1e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 2 Feb 2018 11:06:09 -0800 Subject: [PATCH] Add shape inference for outside_compilation graph rewrite. Pull out enough of the graph to enable inference of the shape of a SendFromHost Op once the shape of corresponding RecvAtHost Ops are known. END_PUBLIC Fixed open source build breaks. BEGIN_PUBLIC Automated g4 rollback of changelist 184169668 PiperOrigin-RevId: 184306845 --- .../compiler/jit/encapsulate_subgraphs_pass.cc | 649 +++++++++++++++++--- .../jit/encapsulate_subgraphs_pass_test.cc | 682 +++++++++++++++++---- tensorflow/core/framework/function.cc | 22 +- tensorflow/core/framework/function.h | 15 +- 4 files changed, 1127 insertions(+), 241 deletions(-) diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 0de163d..9c372a0 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -30,12 +30,14 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -141,8 +143,7 @@ struct NodeSlot { // everything to use it. 
static const char* const kArgOp = "_Arg"; static const char* const kRetValOp = "_Retval"; -static const char* const kSendToHostOp = "_XlaSendToHost"; -static const char* const kRecvFromHostOp = "_XlaRecvFromHost"; +static const char* const kHostComputeOp = "_XlaHostCompute"; static const char* const kSendFromHostOp = "_XlaSendFromHost"; static const char* const kRecvAtHostOp = "_XlaRecvAtHost"; @@ -171,7 +172,8 @@ class Encapsulator { // Write a copy of the input graph to 'graph_out', where the subgraphs are // replaced with calls to the new functions. - Status BuildOutputGraph(bool parallel_checking, Graph* graph_out); + Status BuildOutputGraph(bool parallel_checking, Graph* graph_out, + FunctionLibraryDefinition* library); private: // A subgraph of the input, all marked with a common 'group_attribute' @@ -201,21 +203,29 @@ class Encapsulator { // .. . // RAH --> C --> SFH // - // The compiled cluster is as follows. STH is a SendToHost node which is the - // source of a channel to the RAH node above. RFH is a RecvFromHost node which - // is the destination of a channel from the SFH node above. There is a control - // edge that ensures RFH follows STH, which is used in shape inference to - // ensure that the shapes on the STH host channel are known before the RFH - // channel is compiled. + // The compiled cluster is as follows. HC is a HostCompute node which is the + // source of a channel to the RAH node above and the destination of a channel + // from the SFH node above. // - // Arg --> B --> STH ..> RFH --> D --> Retval + // Arg --> B --> HC --> D --> Retval // - // The channels STH/RAH and SFH/RFH each transmit a tuple, so there is at most - // one RAH and SFH in each compiled cluster. 
This design is preferred over - // adding separate Arg/Retval nodes for each transmitted value because it - // simplifies the host code that would like to limit communication between - // host and device and, e.g., raise only one interrupt per channel rather than - // one per transmitted value. + // The channels HC/RAH and SFH/HC each transmit multiple tensors, so there is + // at most one RAH and SFH in each outside_compilation cluster. This design is + // preferred over adding separate Arg/Retval nodes for each transmitted value + // because it allows optimizations to the host code that would like to limit + // communication between host and device and, e.g., raise only one interrupt + // per channel rather than one per transmitted value. + // + // The shapes of the outputs from the HC node in general cannot be determined + // until the shapes of its inputs are known at compile time, since e.g., + // above, the shape of C's outputs aren't known until the shape of its inputs + // are known. If the shapes of the HC's outputs can be determined during the + // rewrite, they are stored in the node's 'shapes' attr. Otherwise a minimal + // graph is stored in the shape_inference_graph attr. This graph can be used + // when compiling the HC Op to determine the shape of the SFH inputs given + // the shapes of any ancestor RAH outputs. If it can be determined that the + // shape of the SFH inputs will not be inferrable even once the shapes of the + // RAH outputs are known, an error is returned by the rewriter. class Subgraph { public: // Creates a graph to build the subgraph in, if it doesn't already exist, @@ -246,6 +256,10 @@ class Encapsulator { const std::unordered_map& node_images, Graph* graph_out); + // Returns the names of all the outside_compilation subgraphs in this + // Subgraph. + void GetOutsideCompilationSubgraphNames(std::vector* names) const; + + // Returns the Node that inputs to the function should be wired up to. 
Node* GetCallNodeForInputs() const; @@ -305,15 +319,9 @@ class Encapsulator { void RecordOutsideCompilationOutputOrControl( const string& outside_compilation_id, const Edge* edge); - // Adds the SendToHost nodes for each outside_compilation subgraph once the - // edges have all been recorded via RecordOutsideCompilationInputOrControl. - Status AddSendsToOutsideCompilation( - const std::unordered_map& node_images); - - // Adds the RecvFromHost nodes for each outside_compilation subgraph once - // the edges have all been recorded via - // RecordOutsideCompilationOutputOrControl. - Status AddRecvsFromOutsideCompilation( + // Adds the HostCompute nodes for each outside_compilation subgraph. + Status AddHostComputes( + const string& subgraph_name, const std::unordered_map& node_images); // Creates the sequencer node if it doesn't exist, adding it to graph_out. @@ -323,10 +331,16 @@ class Encapsulator { // all the downstream nodes of call_node_outputs. void ConnectSequencerToOutputs(Graph* graph_out); + Status AddShapeInferenceInfo( + const string& outside_compilation_subgraph_name, + const std::vector& shapes, GraphDef* inference_graph); + + Status ReplaceFunctionDef(FunctionLibraryDefinition* library); + private: struct OutsideCompilationSubgraph { // Map from source (producer node/slot) tensors in the original graph to - // input index (slot number in the SendToHost/RecvAtHost nodes that will + // input index (slot number in the HostCompute/RecvAtHost nodes that will // be created) for the outside_compilation subgraph. std::unordered_map inputs; @@ -335,14 +349,14 @@ class Encapsulator { // outside_compilation subgraph. These are recorded by // RecordOutsideCompilationInputOrControl while walking all the subgraph // edges, and lifted control edges within the subgraph are added by - // AddSendsToOutsideCompilation once the _SendToHost node has been + // AddHostComputes once the _HostCompute node has been // created. 
The matching control edge from _RecvAtHost to the // destination is added by CopyEdgeToOutputGraph. std::unordered_set control_inputs; // Maps from source (producer node/slot) and destination (consumer // node/slot) tensors in the original graph to output index (slot number - // in the SendFromHost/RecvFromHost nodes that will be created) for the + // in the SendFromHost/HostCompute nodes that will be created) for the // outside_compilation subgraph. std::unordered_map outputs_by_src; std::unordered_map outputs_by_dst; @@ -352,13 +366,13 @@ class Encapsulator { // containing compiled subgraph. These are recorded by // RecordOutsideCompilationOutputOrControl while walking all the subgraph // edges, and lifted control edges within the subgraph are added by - // AddRecvsFromToOutsideCompilation once the _RecvFromHost node has been + // AddHostComputes once the _HostCompute node has been // created. The matching control edge from the source to _SendFromHost to // the destination is added by CopyEdgeToOutputGraph. std::unordered_set control_outputs; - // _SendToHost node in the subgraph. Not owned. - Node* send_to_host = nullptr; + // Name of the _HostCompute node in the subgraph. + string host_compute_name; // _RecvAtHost node in the output graph. Not owned. Node* recv_at_host = nullptr; @@ -516,6 +530,59 @@ class Encapsulator { const std::unordered_map& node_images, bool parallel_checking, Graph* graph_out); + // Constructs a minimal shape inference graph that can be used to determine + // the shape of send_node at the time that the subgraph is compiled. + // recv_at_host_nodes contains the names of all the recv_at_host nodes that + // send_node might depend on. These recv_at_host nodes have shapes that are + // not known during the rewrite pass, but will be known at compile time. 
+ // + // If the shapes of all the inputs to send_node can be determined during the + // rewrite pass, on exit graphdef_out is empty and the shapes are returned in + // static_shape_out. Otherwise graphdef_out contains a graph that can be used + // for shape inference at compile time, where all the source nodes of the + // graph are either constants with known shapes, or nodes named in + // recv_at_host_nodes. + // + // A non-OK status is returned if neither of the above conditions can be + // satisfied, e.g., because send_node depends on a node that doesn't have a + // registered shape inference function. + Status DoStaticShapeInferenceForOutsideCompilationSend( + const Graph& graph_in, const ShapeRefiner& shape_refiner, + const std::unordered_set& recv_at_host_nodes, Node* send_node, + FunctionLibraryDefinition* library, + std::vector* static_shape_out, + std::unique_ptr* graphdef_out); + + // Makes a copy of graph containing only nodes that are ancestors of at least + // one node in send_from_host_nodes and store it in pruned_graph. On exit + // node_images contains a mapping from nodes in graph to nodes in + // pruned_graph. All functions in the copied graph are inlined. + Status MakePrunedGraphCopyAndInline( + const Graph& graph, const std::vector& sink_nodes, + std::unique_ptr* pruned_graph, + std::unordered_map* node_images, + FunctionLibraryDefinition* library); + + // Makes a copy of graph containing only nodes that are ancestors of a + // send_from_host node in an outside_compilation subgraph, and store it in + // pruned_graph. Also perform shape inference on the pruned graph, using + // shape_refiner. On exit node_images contains a mapping from nodes in graph + // to nodes in pruned_graph. 
+ Status MakeGraphForOutsideCompilationSends( + const Graph& graph, std::unique_ptr* pruned_graph, + ShapeRefiner* shape_refiner, + std::unordered_map* node_images, + FunctionLibraryDefinition* library); + + // Performs static shape inference, as far as possible, for the send_from_host + // nodes in each outside_compilation subgraph. Where it is not possible to + // determine the shape statically, stores a serialized GraphDef in the + // HostCompute 'shape_inference_graph' attr, to be used at compile time for + // final inference. If the shapes are known statically they are stored in the + // HostCompute 'shapes' attr. + Status GetShapeInfoForOutsideCompilationSends( + Graph* graph_out, FunctionLibraryDefinition* library); + const string group_attribute_; const string outside_compilation_attribute_; const Graph* graph_in_; @@ -682,16 +749,20 @@ void Encapsulator::Subgraph::RecordOutsideCompilationOutputOrControl( } } -Status Encapsulator::Subgraph::AddSendsToOutsideCompilation( +Status Encapsulator::Subgraph::AddHostComputes( + const string& subgraph_name, const std::unordered_map& node_images) { for (auto& oc_subgraph_iter : outside_compilation_subgraphs_) { const string& oc_subgraph_name = oc_subgraph_iter.first; OutsideCompilationSubgraph& oc_subgraph = oc_subgraph_iter.second; - if (!oc_subgraph.inputs.empty() || !oc_subgraph.control_inputs.empty()) { - // Build a _SendToHost node sending all the args of the appropriate - // types. - std::vector dtypes(oc_subgraph.inputs.size(), DT_INVALID); + if (!oc_subgraph.inputs.empty() || !oc_subgraph.control_inputs.empty() || + !oc_subgraph.outputs_by_src.empty() || + !oc_subgraph.control_outputs.empty()) { + // Build a _HostCompute node. 
std::vector inputs(oc_subgraph.inputs.size()); + std::vector input_dtypes(oc_subgraph.inputs.size(), DT_INVALID); + std::vector output_dtypes(oc_subgraph.outputs_by_src.size(), + DT_INVALID); for (const auto& input_src : oc_subgraph.inputs) { const Node* src_node = input_src.first.node; @@ -700,94 +771,64 @@ Status Encapsulator::Subgraph::AddSendsToOutsideCompilation( int input_index = input_src.second; DataType dtype = src_node->output_type(src_slot); - dtypes[input_index] = dtype; inputs[input_index].Reset(src_image->name(), src_slot, dtype); + input_dtypes[input_index] = dtype; } - NodeDef send_def; - NodeDefBuilder builder( - strings::StrCat("outside_compilation_", oc_subgraph_name, "_send"), - kSendToHostOp); - builder.Attr("dtypes", dtypes); + for (const auto& output : oc_subgraph.outputs_by_src) { + DataType dtype = output.first.dtype; + int output_index = output.second; + output_dtypes[output_index] = dtype; + } + + NodeDef host_compute_def; + NodeDefBuilder builder(strings::StrCat("outside_compilation_", + oc_subgraph_name, "_host_compute"), + kHostComputeOp); builder.Input(inputs); - Status s = builder.Finalize(&send_def); + builder.Attr("Tinputs", input_dtypes); + builder.Attr("Toutputs", output_dtypes); + builder.Attr("key", + strings::StrCat("host_compute_channel_", subgraph_name, "_", + oc_subgraph_name)); + Status s = builder.Finalize(&host_compute_def); if (!s.ok()) return s; - oc_subgraph.send_to_host = graph_->AddNode(send_def, &s); + Node* host_compute = graph_->AddNode(host_compute_def, &s); if (!s.ok()) return s; + oc_subgraph.host_compute_name = host_compute->name(); - // Connect the _SendToHost node to its producers in the subgraph. + // Connect the _HostCompute node to its producers in the subgraph. 
for (auto& input_src : oc_subgraph.inputs) { const Node* src_node = input_src.first.node; Node* src_image = node_images.at(src_node); int src_slot = input_src.first.slot; int input_index = input_src.second; - graph_->AddEdge(src_image, src_slot, oc_subgraph.send_to_host, - input_index); + graph_->AddEdge(src_image, src_slot, host_compute, input_index); } - // Connect the _SendToHost node to its control edge producers in the + // Connect the _HostCompute node to its control edge producers in the // subgraph. for (const auto& src_node : oc_subgraph.control_inputs) { Node* src_image = node_images.at(src_node); - graph_->AddControlEdge(src_image, oc_subgraph.send_to_host); - } - } - } - - return Status::OK(); -} - -Status Encapsulator::Subgraph::AddRecvsFromOutsideCompilation( - const std::unordered_map& node_images) { - for (auto& oc_subgraph_iter : outside_compilation_subgraphs_) { - const string& oc_subgraph_name = oc_subgraph_iter.first; - OutsideCompilationSubgraph& oc_subgraph = oc_subgraph_iter.second; - if (!oc_subgraph.outputs_by_src.empty() || - !oc_subgraph.control_outputs.empty()) { - // Build a _RecvFromHost node producing all the outputs of the appropriate - // types. - std::vector dtypes(oc_subgraph.outputs_by_src.size(), - DT_INVALID); - - for (const auto& output : oc_subgraph.outputs_by_src) { - DataType dtype = output.first.dtype; - int output_index = output.second; - dtypes[output_index] = dtype; + graph_->AddControlEdge(src_image, host_compute); } - NodeDef recv_def; - NodeDefBuilder builder( - strings::StrCat("outside_compilation_", oc_subgraph_name, "_recv"), - kRecvFromHostOp); - builder.Attr("dtypes", dtypes); - Status s = builder.Finalize(&recv_def); - if (!s.ok()) return s; - - Node* recv = graph_->AddNode(recv_def, &s); - if (!s.ok()) return s; - - // Connect the consumers in the subgraph to the _RecvFromHost node. + // Connect the consumers in the subgraph to the _HostCompute node. 
for (const auto& output : oc_subgraph.outputs_by_dst) { const Node* dst_node = output.first.node; Node* dst_image = node_images.at(dst_node); int dst_slot = output.first.slot; int output_index = output.second; - graph_->AddEdge(recv, output_index, dst_image, dst_slot); + graph_->AddEdge(host_compute, output_index, dst_image, dst_slot); } - // Connect the control edge consumers in the subgraph to the _RecvFromHost + // Connect the control edge consumers in the subgraph to the _HostCompute // node. for (const auto& dst_node : oc_subgraph.control_outputs) { Node* dst_image = node_images.at(dst_node); - graph_->AddControlEdge(recv, dst_image); - } - - // Add a control edge in the subgraph so that the _SendToHost node, if - // any, is compiled before the _RecvFromHost node. - if (oc_subgraph.send_to_host != nullptr) { - graph_->AddControlEdge(oc_subgraph.send_to_host, recv); + graph_->AddControlEdge(host_compute, dst_image); } } } @@ -882,6 +923,63 @@ Status Encapsulator::Subgraph::BuildFunctionDef( return Status::OK(); } +Status Encapsulator::Subgraph::AddShapeInferenceInfo( + const string& outside_compilation_subgraph_name, + const std::vector& shapes, GraphDef* inference_graph) { + OutsideCompilationSubgraph& oc_subgraph = + outside_compilation_subgraphs_.at(outside_compilation_subgraph_name); + + Node* host_compute = nullptr; + for (Node* n : graph_->nodes()) { + if (n->name() == oc_subgraph.host_compute_name) { + host_compute = n; + break; + } + } + if (host_compute == nullptr) { + return errors::InvalidArgument( + "After rewriting subgraph ", outside_compilation_subgraph_name, + " there is no HostCompute Op for outside compilation subgraph ", + oc_subgraph.host_compute_name); + } + + if (inference_graph == nullptr) { + host_compute->AddAttr("shape_inference_graph", ""); + host_compute->AddAttr("shapes", shapes); + } else { + string serialized_graph; + if (!inference_graph->SerializeToString(&serialized_graph)) { + return errors::Internal( + "Failed to serialize 
graph for outside compilation subgraph ", + oc_subgraph.host_compute_name); + } + host_compute->AddAttr("shape_inference_graph", serialized_graph); + host_compute->AddAttr("shapes", std::vector()); + } + return Status::OK(); +} + +Status Encapsulator::Subgraph::ReplaceFunctionDef( + FunctionLibraryDefinition* library) { + const string& name = call_node_def_.name(); + + FunctionDef fdef; + TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, &fdef)); + + if (VLOG_IS_ON(1)) { + VLOG(2) << "Replace function def " << name; + dump_graph::DumpGraphToFile( + strings::StrCat("replace_encapsulate_fdef_graph_", name), *graph_, + library); + dump_graph::DumpFunctionDefToFile( + strings::StrCat("replace_encapsulate_fdef_", name), fdef); + } + + TF_RETURN_IF_ERROR(library->RemoveFunction(name)); + TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef)); + return Status::OK(); +} + Status Encapsulator::Subgraph::BuildParallelCheckOp( const std::unordered_map& node_images, Graph* graph_out) { @@ -980,7 +1078,9 @@ Status Encapsulator::Subgraph::AddRecvAtHostNode( NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name, "_", oc_subgraph_name, "_recv"), kRecvAtHostOp); - builder.Attr("dtypes", dtypes); + builder.Attr("Toutputs", dtypes); + builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name, + "_", oc_subgraph_name)); Status s = builder.Finalize(&recv_def); if (!s.ok()) return s; @@ -1020,7 +1120,9 @@ Status Encapsulator::Subgraph::AddSendFromHostNode( NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name, "_", oc_subgraph_name, "_send"), kSendFromHostOp); - builder.Attr("dtypes", dtypes); + builder.Attr("Tinputs", dtypes); + builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name, + "_", oc_subgraph_name)); builder.Input(inputs); Status s = builder.Finalize(&send_def); if (!s.ok()) return s; @@ -1062,6 +1164,13 @@ Status Encapsulator::Subgraph::AddOutsideCompilationHostIONodes( return Status::OK(); } 
+void Encapsulator::Subgraph::GetOutsideCompilationSubgraphNames( + std::vector* names) const { + for (auto& entry : outside_compilation_subgraphs_) { + names->push_back(entry.first); + } +} + Status Encapsulator::GetFunctionNameAttr( Node const* node, string* attr, string* outside_compilation_attr) const { Status s = GetNodeAttr(node->attrs(), group_attribute_, attr); @@ -1220,8 +1329,7 @@ Status Encapsulator::SplitIntoSubgraphs() { // single input and output node for it. for (auto& entry : subgraphs_) { Subgraph& subgraph = entry.second; - TF_RETURN_IF_ERROR(subgraph.AddSendsToOutsideCompilation(node_images)); - TF_RETURN_IF_ERROR(subgraph.AddRecvsFromOutsideCompilation(node_images)); + TF_RETURN_IF_ERROR(subgraph.AddHostComputes(entry.first, node_images)); } MarkGuaranteedConstants(*graph_in_, src_arg_pairs); @@ -1509,8 +1617,346 @@ Status Encapsulator::AddEdgesToOutputGraph( return Status::OK(); } -Status Encapsulator::BuildOutputGraph(bool parallel_checking, - Graph* graph_out) { +namespace { + +// Adds a dummy Const node to graph_out. The "constant" has the type of +// data_type and the shape indicated in 'shape'. The dummy node is not a valid +// Const node because it does not have any value defined, but this doesn't +// matter because it will only be used subsequently for shape inference. (It +// would be possible to add a switch statement over data_type to create a value +// for the constant, but that would entail maintaining the logic as new types +// are added, and is not necessary.) +Node* AddDummyShapedNode(DataType data_type, const TensorShapeProto& shape, + Graph* graph_out) { + TensorProto dummy_proto; + dummy_proto.set_dtype(data_type); + *dummy_proto.mutable_tensor_shape() = shape; + // Don't set any value field in the proto, since it is only going to be used + // for shape inference. 
+ + GraphDefBuilder::Options options(graph_out, /*status=*/nullptr); + NodeBuilder node_builder(options.GetNameForOp("KnownShape"), "Const", + options.op_registry()); + node_builder.Attr("dtype", data_type).Attr("value", dummy_proto); + return options.FinalizeBuilder(&node_builder); +} + +// Adds a copy of node_in to graph_out and adds the mapping to +// copied_node_images. +Status CopyShapeInferenceNodeToGraph( + Node* node_in, const Node* send_node, + const std::unordered_map& dummy_node_images, + FunctionLibraryDefinition* library, + std::unordered_map* copied_node_images, Graph* graph_out) { + // Once all the ancestor nodes have been added to graph_out, add this node + // and connect it to its ancestors. + Node* node_out = graph_out->CopyNode(node_in); + (*copied_node_images)[node_in] = node_out; + // Don't bother to build the shape inference graph if there's a node with no + // shape inference function, since it would just result in an error later at + // compile time. + const OpRegistrationData* op_reg_data; + TF_RETURN_IF_ERROR(library->LookUp(node_in->type_string(), &op_reg_data)); + if (op_reg_data->shape_inference_fn == nullptr) { + return errors::InvalidArgument( + "Shape inference is not possible for outside_compilation " + "SendFromHost node ", + send_node->name(), " because it depends on node ", node_in->name(), + " which does not have a shape inference function registered."); + } + // Add all the edges to the newly copied node. + for (const Edge* in_edge : node_in->in_edges()) { + if (!in_edge->IsControlEdge()) { + Node* src = in_edge->src(); + const auto iter = dummy_node_images.find(src); + if (iter == dummy_node_images.end()) { + // The src is a copied node so use the original output port. + graph_out->AddEdge((*copied_node_images)[in_edge->src()], + in_edge->src_output(), node_out, + in_edge->dst_input()); + } else { + // The src is a dummy node so use output port 0. 
+ graph_out->AddEdge(iter->second, 0, node_out, in_edge->dst_input()); + } + } + } + return Status::OK(); +} + +} // namespace + +Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( + const Graph& graph_in, const ShapeRefiner& shape_refiner, + const std::unordered_set& recv_at_host_nodes, Node* send_node, + FunctionLibraryDefinition* library, + std::vector* static_shape_out, + std::unique_ptr* graphdef_out) { + // Maps from nodes in graph_in to nodes in graph_out. + // + // When an edge has fully defined shape the source node in graph_in is + // replaced in graph_out by a dummy constant node. The mapping from nodes + // in graph_in to dummy nodes is stored in dummy_node_images. + // + // When a node in graph_in has at least one ancestor that doesn't have fully + // defined shape, it is copied into graph_out. The mapping from nodes in + // graph_in to copied nodes is stored in copied_node_images. + // + // The two types of node are treated differently because, when adding edges to + // graph_out, an output from a dummy node always uses port 0, whereas an + // output from a copied node uses the same port that was used in graph_in. + std::unordered_map dummy_node_images; + std::unordered_map copied_node_images; + + std::unique_ptr graph_out(new Graph(graph_in.op_registry())); + graph_out->set_versions(graph_in.versions()); + static_shape_out->resize(send_node->num_inputs()); + + // We don't use the standard ReverseDFS because we want to cut off traversal + // whenever we find an output with fully defined shape. + // TODO(misard) make this work properly in the presence of control flow. + struct Work { + Node* node; + bool leave; // Are we entering or leaving node? 
+ }; + std::vector stack({{send_node, false}}); + std::vector visited(graph_in.num_node_ids(), false); + while (!stack.empty()) { + Work w = stack.back(); + stack.pop_back(); + Node* n = w.node; + + if (w.leave) { + TF_RETURN_IF_ERROR(CopyShapeInferenceNodeToGraph( + n, send_node, dummy_node_images, library, &copied_node_images, + graph_out.get())); + } else { + if (visited[n->id()]) continue; + visited[n->id()] = true; + + // Arrange to revisit when all done with all inputs. + stack.push_back(Work{n, true}); + + bool has_parent_with_unknown_shape = false; + for (const Edge* in_edge : n->in_edges()) { + if (!in_edge->IsControlEdge()) { + Node* src_node = in_edge->src(); + int src_port = in_edge->src_output(); + shape_inference::InferenceContext* context = + shape_refiner.GetContext(src_node); + shape_inference::ShapeHandle shape = context->output(src_port); + if (context->FullyDefined(shape)) { + // This ancestor has known shape, so instead of adding it to the + // stack, add a dummy node with that shape to graph_out and + // continue. + TensorShapeProto proto; + context->ShapeHandleToProto(shape, &proto); + dummy_node_images[src_node] = AddDummyShapedNode( + src_node->output_type(src_port), proto, graph_out.get()); + if (n == send_node) { + (*static_shape_out)[in_edge->dst_input()] = proto; + } + } else { + if (!visited[src_node->id()]) { + has_parent_with_unknown_shape = true; + stack.push_back({src_node, false}); + } + } + } + } + if (!has_parent_with_unknown_shape) { + if (n == send_node) { + // The shapes of all the inputs to send_node are statically known. We + // won't have to do any inference at compile time so return now: the + // shapes were stored in static_shape_out above. + graphdef_out->reset(); + return Status::OK(); + } else { + // Any shape that is being processed is either the original send node + // or has at least one output with statically-unknown shape. 
If the + // latter and it doesn't have any inputs with statically-unknown + // shape, then check that it is one of the recv nodes that we can fill in + // the shape of at run-time later. If it isn't one of those, then we + // won't have any additional knowledge at compile time, so we already + // know we won't be able to do shape inference and we can return an + // error now. + if (recv_at_host_nodes.find(n->name()) == recv_at_host_nodes.end()) { + return errors::InvalidArgument( + "Shape inference is not possible for outside_compilation " + "SendFromHost node ", + send_node->name(), " because shape of node ", n->name(), + " will not be known at compilation time."); + } + } + } + } + } + + graphdef_out->reset(new GraphDef()); + graph_out->ToGraphDef(graphdef_out->get()); + + return Status::OK(); +} + +Status Encapsulator::MakePrunedGraphCopyAndInline( + const Graph& graph, const std::vector& sink_nodes, + std::unique_ptr* pruned_graph, + std::unordered_map* node_images, + FunctionLibraryDefinition* library) { + // First copy all ancestor nodes of sink_nodes into a new graph. + pruned_graph->reset(new Graph(library)); + (*pruned_graph)->set_versions(graph.versions()); + ReverseDFSFrom(graph, sink_nodes, + /*enter=*/nullptr, + /*leave=*/[&](Node* n) { + if (!n->IsSource()) { + Node* copied = (*pruned_graph)->CopyNode(n); + node_images->emplace(n, copied); + } + }); + + // Add all the edges between copied nodes. + for (auto entry : *node_images) { + const Node* orig = entry.first; + Node* image = entry.second; + for (const Edge* out_edge : orig->out_edges()) { + auto iter = node_images->find(out_edge->dst()); + if (iter != node_images->end()) { + // The source and destination are both in the copied graph. + (*pruned_graph) + ->AddEdge(image, out_edge->src_output(), iter->second, + out_edge->dst_input()); + } + } + } + + // Find all the function call nodes, and inline them. 
+ std::vector function_nodes; + for (auto node : (*pruned_graph)->nodes()) { + const OpRegistrationData* op_reg_data; + TF_RETURN_IF_ERROR(library->LookUp(node->type_string(), &op_reg_data)); + if (op_reg_data->is_function_op) { + function_nodes.push_back(node); + } + } + for (auto node : function_nodes) { + VLOG(2) << "Inlining function " << node->name(); + const FunctionDef* fdef = library->Find(node->type_string()); + if (fdef == nullptr) { + return errors::Internal("Failed to find function ", node->type_string(), + " in function library."); + } + FunctionBody* fbody = nullptr; + TF_RETURN_IF_ERROR( + FunctionDefToBodyHelper(*fdef, node->attrs(), library, + [library](const string& op, const OpDef** sig) { + return library->LookUpOpDef(op, sig); + }, + &fbody)); + InlineFunctionBody(*library, pruned_graph->get(), node, fbody); + delete fbody; + } + + return Status::OK(); +} + +Status Encapsulator::MakeGraphForOutsideCompilationSends( + const Graph& graph, std::unique_ptr* pruned_graph, + ShapeRefiner* shape_refiner, + std::unordered_map* node_images, + FunctionLibraryDefinition* library) { + // Find all the send_from_host nodes in all subgraphs, to use as roots for the + // pruning. + std::vector send_from_host_nodes; + for (auto& subgraph_entry : subgraphs_) { + Subgraph& subgraph = subgraph_entry.second; + std::vector outside_compilation_names; + subgraph.GetOutsideCompilationSubgraphNames(&outside_compilation_names); + for (const auto& name : outside_compilation_names) { + Node* send_node = subgraph.GetSendFromHostNode(name); + if (send_node != nullptr) { + send_from_host_nodes.push_back(send_node); + } + } + } + + // Make a copy of all the graph nodes needed to evaluate the send_from_host + // nodes, inlining any functions as needed. + TF_RETURN_IF_ERROR(MakePrunedGraphCopyAndInline( + graph, send_from_host_nodes, pruned_graph, node_images, library)); + + // Perform shape inference on the pruned graph. 
+ shape_refiner->set_require_shape_inference_fns(false); + FixupSourceAndSinkEdges(pruned_graph->get()); + std::vector post_order; + GetReversePostOrder(*(*pruned_graph), &post_order); + for (auto node : post_order) { + // Ignore the status returned by the shape_refiner. At this point we want + // the best effort shapes, even if no shape function is registered for a + // node. + Status status = shape_refiner->AddNode(node); + if (!status.ok()) { + VLOG(1) << "Shape inference failed for node: " << status; + } + } + + return Status::OK(); +} + +Status Encapsulator::GetShapeInfoForOutsideCompilationSends( + Graph* graph_out, FunctionLibraryDefinition* library) { + std::unique_ptr pruned_graph; + ShapeRefiner shape_refiner(graph_out->versions(), graph_out->op_registry()); + std::unordered_map node_images; + TF_RETURN_IF_ERROR(MakeGraphForOutsideCompilationSends( + *graph_out, &pruned_graph, &shape_refiner, &node_images, library)); + + for (auto& subgraph_entry : subgraphs_) { + Subgraph& subgraph = subgraph_entry.second; + // Find all the recv_at_host nodes in this subgraph. + std::vector outside_compilation_names; + subgraph.GetOutsideCompilationSubgraphNames(&outside_compilation_names); + std::unordered_set recv_at_host_names; + for (const auto& name : outside_compilation_names) { + Node* recv_node = subgraph.GetRecvAtHostNode(name); + if (recv_node != nullptr) { + recv_at_host_names.insert(recv_node->name()); + } + } + // For each send_from_host node, do as much shape inference as possible + // without knowing the shape of the recv_at_host nodes, and store the + // result, along with enough information to complete the job at compile time + // once the recv_at_host shapes are known. 
+ for (const auto& name : outside_compilation_names) { + Node* send_node = subgraph.GetSendFromHostNode(name); + std::vector static_shape; + std::unique_ptr graphdef; + if (send_node != nullptr) { + TF_RETURN_IF_ERROR(DoStaticShapeInferenceForOutsideCompilationSend( + *pruned_graph, shape_refiner, recv_at_host_names, + node_images[send_node], library, &static_shape, &graphdef)); + if (graphdef == nullptr) { + VLOG(2) << "Send node " << send_node->name() << " shapes"; + for (int i = 0; i < static_shape.size(); ++i) { + VLOG(2) << static_shape[i].DebugString(); + } + } else { + VLOG(2) << "Send node " << send_node->name() << " graph\n" + << graphdef->DebugString(); + } + } + TF_RETURN_IF_ERROR( + subgraph.AddShapeInferenceInfo(name, static_shape, graphdef.get())); + } + if (!outside_compilation_names.empty()) { + TF_RETURN_IF_ERROR(subgraph.ReplaceFunctionDef(library)); + } + } + + return Status::OK(); +} + +Status Encapsulator::BuildOutputGraph(bool parallel_checking, Graph* graph_out, + FunctionLibraryDefinition* library) { // Map from nodes in the input graph to nodes in the output graph. 
std::unordered_map node_images; @@ -1522,6 +1968,9 @@ Status Encapsulator::BuildOutputGraph(bool parallel_checking, TF_RETURN_IF_ERROR( AddEdgesToOutputGraph(node_images, parallel_checking, graph_out)); + TF_RETURN_IF_ERROR( + GetShapeInfoForOutsideCompilationSends(graph_out, library)); + return Status::OK(); } @@ -1545,7 +1994,7 @@ Status EncapsulateSubgraphsInFunctions( std::unique_ptr out(new Graph(library)); out->set_versions(graph_in.versions()); TF_RETURN_IF_ERROR( - encapsulator.BuildOutputGraph(parallel_checking, out.get())); + encapsulator.BuildOutputGraph(parallel_checking, out.get(), library)); *graph_out = std::move(out); return Status::OK(); diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index b100861..aed9cae 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -29,17 +29,181 @@ limitations under the License. 
namespace tensorflow { namespace { +template +bool EqualProtoMap(const ::tensorflow::protobuf::Map& a, + const ::tensorflow::protobuf::Map& b, + const std::function& key_to_string, + const std::function& value_to_string, + const std::function& compare, + const string& map_name, string* diff) { + for (const auto& elt_a : a) { + const auto iter = b.find(elt_a.first); + if (iter == b.end()) { + if (diff) { + *diff = strings::StrCat( + map_name, " expected: contains element with key '", + key_to_string(elt_a.first), "' got: map has no such element"); + } + return false; + } + if (!compare(elt_a.first, elt_a.second, iter->second)) { + if (diff) { + *diff = strings::StrCat(map_name, " expected: element with key '", + key_to_string(elt_a.first), " has value '", + value_to_string(elt_a.second), "' got: '", + value_to_string(iter->second), "'"); + } + return false; + } + } + for (const auto& elt_b : b) { + const auto iter = a.find(elt_b.first); + if (iter == a.end()) { + if (diff) { + *diff = strings::StrCat(map_name, " got: contains element with key '", + key_to_string(elt_b.first), + "' expected: map has no such element"); + } + return false; + } + } + return true; +} + +bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b, + const string& diff_preamble, string* diff) { + if (a.op() != b.op()) { + if (diff) { + *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), + ", expected op '", a.op(), "' got '", b.op()); + } + return false; + } + if (a.device() != b.device()) { + if (diff) { + *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), + ", expected device '", a.device(), "' got '", + b.device()); + } + return false; + } + if (a.input_size() != b.input_size()) { + if (diff) { + *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), + ", expected ", a.input_size(), " inputs got ", + b.input_size(), " expected:\n", a.DebugString(), + "\ngot:\n", b.DebugString()); + } + return false; + } + for (int i = 0; i < 
a.input_size(); ++i) { + if (a.input(i) != b.input(i)) { + if (diff) { + *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), + " input ", i, ", expected ", a.input(i), + " got ", b.input(i), " expected:\n", + a.DebugString(), "\ngot:\n", b.DebugString()); + } + return false; + } + } + return EqualProtoMap( + a.attr(), b.attr(), [](const string& s) { return s; }, + [](const AttrValue& v) { return v.DebugString(); }, + [](const string& key, const AttrValue& av, const AttrValue& bv) { + if (key == "shape_inference_graph") { + // Default serialization of GraphDef is unstable because maps don't + // serialize deterministically. Rather than go through the hoops to + // turn on deterministic serialization of this attr just for this + // test, add logic here to compare deterministically. + GraphDef ga; + if (!ga.ParseFromString(av.s())) { + return false; + } + GraphDef gb; + if (!gb.ParseFromString(bv.s())) { + return false; + } + return EqualGraphDef(ga, gb, nullptr); + } else { + return av.DebugString() == bv.DebugString(); + } + }, + strings::StrCat(diff_preamble, " attr mismatch for node ", a.name()), + diff); +} + bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, string* diff) { - // TODO(phawkins) use a more sophisticated equality test. 
- if (a.DebugString() != b.DebugString()) { + if (a.signature().DebugString() != b.signature().DebugString()) { if (diff) { - *diff = strings::StrCat("Definition mismatch for function ", + *diff = strings::StrCat("Signature mismatch for function ", a.signature().name(), ", expected:\n", - a.DebugString(), "\ngot:\n", b.DebugString()); + a.signature().DebugString(), "\ngot:\n", + b.signature().DebugString()); } return false; } + if (!EqualProtoMap( + a.attr(), b.attr(), [](const string& s) { return s; }, + [](const AttrValue& v) { return v.DebugString(); }, + [](const string& key, const AttrValue& av, const AttrValue& bv) { + return av.DebugString() == bv.DebugString(); + }, + strings::StrCat("attr mismatch for function ", a.signature().name()), + diff)) { + return false; + } + if (!EqualProtoMap( + a.ret(), b.ret(), [](const string& s) { return s; }, + [](const string& s) { return s; }, + [](const string& key, const string& av, const string& bv) { + return av == bv; + }, + strings::StrCat("ret mismatch for function ", a.signature().name()), + diff)) { + return false; + } + for (int i = 0; i < a.node_def_size(); ++i) { + bool found = false; + for (int j = 0; j < b.node_def_size(); ++j) { + if (a.node_def(i).name() == b.node_def(j).name()) { + if (!EqualFunctionNodeDef( + a.node_def(i), b.node_def(j), + strings::StrCat("Function ", a.signature().name()), diff)) { + return false; + } + found = true; + break; + } + } + if (!found) { + if (diff) { + *diff = strings::StrCat("Function ", a.signature().name(), + ", expected: has node '", a.node_def(i).name(), + "' got: no node of that name"); + } + return false; + } + } + for (int i = 0; i < b.node_def_size(); ++i) { + bool found = false; + for (int j = 0; j < a.node_def_size(); ++j) { + if (b.node_def(i).name() == a.node_def(j).name()) { + found = true; + break; + } + } + if (!found) { + if (diff) { + *diff = strings::StrCat("Function ", a.signature().name(), + ", got: has node '", b.node_def(i).name(), + "' expected: no 
node of that name"); + } + return false; + } + } return true; } @@ -84,29 +248,64 @@ bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected, // TODO(misard): remove these fake registrations once there are real Ops to be // compiled. -REGISTER_OP("_XlaSendToHost") - .Input("input: dtypes") - .Attr("dtypes: list(type) >= 0"); - -REGISTER_OP("_XlaRecvFromHost") - .Output("output: dtypes") - .Attr("dtypes: list(type) >= 0"); +REGISTER_OP("_XlaHostCompute") + .Input("inputs: Tinputs") + .Output("outputs: Toutputs") + .Attr("Tinputs: list(type) >= 0") + .Attr("Toutputs: list(type) >= 0") + .Attr("key: string") + .SetShapeFn(::tensorflow::shape_inference::UnknownShape); REGISTER_OP("_XlaSendFromHost") - .Input("input: dtypes") - .Attr("dtypes: list(type) >= 0"); + .Input("input: Tinputs") + .Attr("Tinputs: list(type) >= 0") + .Attr("key: string") + .SetShapeFn(::tensorflow::shape_inference::UnknownShape); REGISTER_OP("_XlaRecvAtHost") - .Output("output: dtypes") - .Attr("dtypes: list(type) >= 0"); - -REGISTER_OP("InputTest").Output("o: float"); - -REGISTER_OP("UnaryTest").Input("a: float").Output("o: float"); + .Output("output: Toutputs") + .Attr("Toutputs: list(type) >= 0") + .Attr("key: string") + .SetShapeFn(::tensorflow::shape_inference::UnknownShape); + +REGISTER_OP("InputTest") + .Output("o: float") + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + c->set_output(0, c->UnknownShape()); + return Status::OK(); + }); + +REGISTER_OP("InputTestShaped") + .Output("o: float") + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + c->set_output(0, c->Vector(2)); + return Status::OK(); + }); + +REGISTER_OP("UnaryTest") + .Input("a: float") + .Output("o: float") + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + ::tensorflow::shape_inference::ShapeHandle o; + TF_RETURN_IF_ERROR(c->Merge(c->UnknownShape(), c->input(0), &o)); + c->set_output(0, o); + return Status::OK(); + }); REGISTER_OP("BinaryTest") 
.Input("a: float") .Input("b: float") - .Output("o: float"); + .Output("o: float") + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + ::tensorflow::shape_inference::ShapeHandle o; + TF_RETURN_IF_ERROR(c->Merge(c->UnknownShape(), c->input(0), &o)); + c->set_output(0, o); + return Status::OK(); + }); +REGISTER_OP("BinaryTest2") + .Input("a: float") + .Input("b: float") + .Output("o: float") + .SetShapeFn(::tensorflow::shape_inference::UnknownShape); REGISTER_OP("AddNLikeTest") .Input("inputs: N * T") @@ -124,22 +323,48 @@ Node* Input(const GraphDefBuilder::Options& opts) { return ops::SourceOp("InputTest", opts); } -Node* RecvAtHost(const gtl::ArraySlice& dtypes, +Node* InputShaped(const GraphDefBuilder::Options& opts) { + return ops::SourceOp("InputTestShaped", opts); +} + +Node* KnownShape(const gtl::ArraySlice& shape, + const GraphDefBuilder::Options& opts) { + if (opts.HaveError()) return nullptr; + NodeBuilder node_builder(opts.GetNameForOp("Const"), "Const", + opts.op_registry()); + TensorProto value; + value.set_dtype(DT_FLOAT); + for (int dim : shape) { + value.mutable_tensor_shape()->add_dim()->set_size(dim); + } + return opts.WithAttr("value", value) + .WithAttr("dtype", DT_FLOAT) + .FinalizeBuilder(&node_builder); +} + +Node* RecvAtHost(const string& key, const gtl::ArraySlice& dtypes, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; NodeBuilder node_builder(opts.GetNameForOp("_XlaRecvAtHost"), "_XlaRecvAtHost", opts.op_registry()); - return opts.WithAttr("dtypes", dtypes).FinalizeBuilder(&node_builder); + return opts.WithAttr("Toutputs", dtypes) + .WithAttr("key", key) + .FinalizeBuilder(&node_builder); } -Node* SendFromHost(const std::vector& inputs, - const gtl::ArraySlice& dtypes, +Node* SendFromHost(const string& key, const std::vector& inputs, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; NodeBuilder node_builder(opts.GetNameForOp("_XlaSendFromHost"), 
"_XlaSendFromHost", opts.op_registry()); node_builder.Input(inputs); - return opts.WithAttr("dtypes", dtypes).FinalizeBuilder(&node_builder); + std::vector dtypes; + for (const auto& node : inputs) { + dtypes.push_back(node.dt); + } + return opts.WithAttr("key", key) + .WithAttr("Tinputs", dtypes) + .FinalizeBuilder(&node_builder); } Node* Unary(ops::NodeOut a, const GraphDefBuilder::Options& opts) { @@ -151,6 +376,11 @@ Node* Binary(ops::NodeOut a, ops::NodeOut b, return ops::BinaryOp("BinaryTest", std::move(a), std::move(b), opts); } +Node* BinaryUnknownShape(ops::NodeOut a, ops::NodeOut b, + const GraphDefBuilder::Options& opts) { + return ops::BinaryOp("BinaryTest2", std::move(a), std::move(b), opts); +} + Node* AddNLike(const std::vector& inputs, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; @@ -576,6 +806,21 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; + string shape_string_expected; + { + GraphDefBuilder shape(GraphDefBuilder::kFailImmediately); + Node* recv = + RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, + shape.opts().WithName("outside_compilation_F1_O1_recv")); + Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), + shape.opts().WithName("E")); + SendFromHost("host_compute_channel_F1_O1", {e}, + shape.opts().WithName("outside_compilation_F1_O1_send")); + GraphDef shape_graph; + TF_EXPECT_OK(shape.ToGraphDef(&shape_graph)); + EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected)); + } + *library_expected.add_function() = test::function::XTimesTwo(); *library_expected.add_function() = FunctionDefHelper::Create( "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {}, @@ -584,19 +829,18 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { {{"c"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}, {"C"}}, {{"F"}, "BinaryTest", - {"C:o:0", "outside_compilation_O1_recv:output:0"}, + {"C:o:0", 
"outside_compilation_O1_host_compute:outputs:0"}, {}, - {"outside_compilation_O1_recv"}}, - {{"outside_compilation_O1_send"}, - "_XlaSendToHost", + {"outside_compilation_O1_host_compute"}}, + {{"outside_compilation_O1_host_compute"}, + "_XlaHostCompute", {"C:o:0", "c:o:0"}, - {{"dtypes", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}}, + {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", shape_string_expected}, + {"shapes", gtl::ArraySlice({})}}, {"c"}}, - {{"outside_compilation_O1_recv"}, - "_XlaRecvFromHost", - {}, - {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}, - {"outside_compilation_O1_send"}}, }, {{"f_0_retval", "F:o:0"}}); @@ -612,11 +856,11 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { Node* call = b2.opts().FinalizeBuilder(&node_builder); Node* recv = - RecvAtHost({DT_FLOAT, DT_FLOAT}, + RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_recv")); Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), b2.opts().WithName("E").WithControlInputs({recv, b})); - Node* send = SendFromHost({e}, {DT_FLOAT}, + Node* send = SendFromHost("host_compute_channel_F1_O1", {e}, b2.opts() .WithName("outside_compilation_F1_O1_send") .WithControlInput(e)); @@ -674,37 +918,71 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; + string shape_string_expected_1; + { + GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); + Node* recv = + RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, + shape1.opts().WithName("outside_compilation_F1_O1_recv")); + Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), + shape1.opts().WithName("E")); + SendFromHost("host_compute_channel_F1_O1", {e}, + shape1.opts().WithName("outside_compilation_F1_O1_send")); + GraphDef shape1_graph; + 
TF_EXPECT_OK(shape1.ToGraphDef(&shape1_graph)); + EXPECT_TRUE(shape1_graph.SerializeToString(&shape_string_expected_1)); + } + + string shape_string_expected_2; + { + GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately); + Node* recv1 = + RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, + shape2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), + shape2.opts().WithName("E")); + Node* recv2 = + RecvAtHost("host_compute_channel_F1_O2", {DT_FLOAT, DT_FLOAT}, + shape2.opts().WithName("outside_compilation_F1_O2_recv")); + Node* h = Binary(ops::NodeOut(recv2, 0), e, shape2.opts().WithName("H")); + SendFromHost("host_compute_channel_F1_O2", {h}, + shape2.opts().WithName("outside_compilation_F1_O2_send")); + GraphDef shape2_graph; + TF_EXPECT_OK(shape2.ToGraphDef(&shape2_graph)); + EXPECT_TRUE(shape2_graph.SerializeToString(&shape_string_expected_2)); + } + *library_expected.add_function() = FunctionDefHelper::Create( "F1", {"a_0_arg:float", "b_0_arg:float"}, {"i_0_retval:float"}, {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}}, - {{"I"}, "UnaryTest", {"outside_compilation_O2_recv:output:0"}}, + {{"I"}, + "UnaryTest", + {"outside_compilation_O2_host_compute:outputs:0"}}, {{"F"}, "BinaryTest", - {"C:o:0", "outside_compilation_O1_recv:output:0"}, + {"C:o:0", "outside_compilation_O1_host_compute:outputs:0"}, {}, - {"outside_compilation_O1_recv"}}, - {{"outside_compilation_O2_send"}, - "_XlaSendToHost", + {"outside_compilation_O1_host_compute"}}, + {{"outside_compilation_O2_host_compute"}, + "_XlaHostCompute", {"D:o:0", "F:o:0"}, - {{"dtypes", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}}, + {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"key", "host_compute_channel_F1_O2"}, + {"shape_inference_graph", shape_string_expected_2}, + {"shapes", gtl::ArraySlice({})}}, {"F"}}, - 
{{"outside_compilation_O1_send"}, - "_XlaSendToHost", + {{"outside_compilation_O1_host_compute"}, + "_XlaHostCompute", {"C:o:0", "D:o:0"}, - {{"dtypes", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}}, + {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", shape_string_expected_1}, + {"shapes", gtl::ArraySlice({})}}, {"D"}}, - {{"outside_compilation_O2_recv"}, - "_XlaRecvFromHost", - {}, - {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}, - {"outside_compilation_O2_send"}}, - {{"outside_compilation_O1_recv"}, - "_XlaRecvFromHost", - {}, - {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}, - {"outside_compilation_O1_send"}}, }, {{"i_0_retval", "I:o:0"}}); @@ -720,23 +998,24 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { Node* call = b2.opts().FinalizeBuilder(&node_builder); Node* recv1 = - RecvAtHost({DT_FLOAT, DT_FLOAT}, + RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_recv")); Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), b2.opts().WithName("E").WithControlInputs({recv1, b})); - Node* send1 = SendFromHost({e}, {DT_FLOAT}, + Node* send1 = SendFromHost("host_compute_channel_F1_O1", {e}, b2.opts() .WithName("outside_compilation_F1_O1_send") .WithControlInput(e)); Node* recv2 = - RecvAtHost({DT_FLOAT, DT_FLOAT}, + RecvAtHost("host_compute_channel_F1_O2", {DT_FLOAT, DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O2_recv")); Node* g = Binary(e, ops::NodeOut(recv2, 1), b2.opts().WithName("G").WithControlInputs({recv2, e})); Node* h = Binary(ops::NodeOut(recv2, 0), e, b2.opts().WithName("H")); - Node* send2 = SendFromHost( - {h}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O2_send")); + Node* send2 = + SendFromHost("host_compute_channel_F1_O2", {h}, + b2.opts().WithName("outside_compilation_F1_O2_send")); Node* s = NoOp(b2.opts() .WithName("F1_sequencer") @@ -758,8 +1037,8 
@@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { { GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); - Node* a = Input(b1.opts().WithName("A")); - Node* b = Input(b1.opts().WithName("B")); + Node* a = InputShaped(b1.opts().WithName("A")); + Node* b = InputShaped(b1.opts().WithName("B")); Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); Node* d = Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1")); @@ -791,6 +1070,24 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; + string shape_string_expected; + { + GraphDefBuilder shape(GraphDefBuilder::kFailImmediately); + Node* recv = + RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, + shape.opts().WithName("outside_compilation_F1_O1_recv")); + Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), + shape.opts().WithName("E")); + SendFromHost("host_compute_channel_F1_O1", {e}, + shape.opts().WithName("outside_compilation_F1_O1_send")); + GraphDef shape_graph; + TF_EXPECT_OK(shape.ToGraphDef(&shape_graph)); + EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected)); + } + + TensorShapeProto shape_proto_expected; + shape_proto_expected.add_dim()->set_size(2); + *library_expected.add_function() = FunctionDefHelper::Create( "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float", "d_0_retval:float"}, {}, @@ -799,19 +1096,18 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, {{"F"}, "BinaryTest", - {"C:o:0", "outside_compilation_O1_recv:output:0"}, + {"C:o:0", "outside_compilation_O1_host_compute:outputs:0"}, {}, - {"outside_compilation_O1_recv"}}, - {{"outside_compilation_O1_send"}, - "_XlaSendToHost", + {"outside_compilation_O1_host_compute"}}, + {{"outside_compilation_O1_host_compute"}, + "_XlaHostCompute", {"C:o:0", "D:o:0"}, - {{"dtypes", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}}, + {{"Tinputs", 
gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", shape_string_expected}, + {"shapes", gtl::ArraySlice({})}}, {"D"}}, - {{"outside_compilation_O1_recv"}, - "_XlaRecvFromHost", - {}, - {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}, - {"outside_compilation_O1_send"}}, }, {{"d_0_retval", "D:o:0"}, {"f_0_retval", "F:o:0"}}); @@ -822,16 +1118,16 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { {{"G"}, "BinaryTest", {"e_0_arg", "f_0_arg"}}, {{"I"}, "BinaryTest", - {"f_0_arg", "outside_compilation_O1_recv:output:0"}}, - {{"outside_compilation_O1_send"}, - "_XlaSendToHost", + {"f_0_arg", "outside_compilation_O1_host_compute:outputs:0"}}, + {{"outside_compilation_O1_host_compute"}, + "_XlaHostCompute", {"G:o:0"}, - {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}}, - {{"outside_compilation_O1_recv"}, - "_XlaRecvFromHost", - {}, - {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}, - {"outside_compilation_O1_send"}}, + {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"key", "host_compute_channel_F2_O1"}, + {"shape_inference_graph", ""}, + {"shapes", + gtl::ArraySlice({shape_proto_expected})}}}, }, {{"g_0_retval", "G:o:0"}, {"i_0_retval", "I:o:0"}}); @@ -839,15 +1135,15 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { std::unique_ptr lib_def( new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); - Node* a = Input(b2.opts().WithName("A")); - Node* b = Input(b2.opts().WithName("B")); + Node* a = InputShaped(b2.opts().WithName("A")); + Node* b = InputShaped(b2.opts().WithName("B")); Node* recv1 = - RecvAtHost({DT_FLOAT, DT_FLOAT}, + RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_recv")); Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), 
b2.opts().WithName("E").WithControlInputs({recv1, b})); - Node* send1 = SendFromHost({e}, {DT_FLOAT}, + Node* send1 = SendFromHost("host_compute_channel_F1_O1", {e}, b2.opts() .WithName("outside_compilation_F1_O1_send") .WithControlInput(e)); @@ -857,12 +1153,14 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { Node* s1 = NoOp( b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1})); - Node* recv2 = RecvAtHost( - {DT_FLOAT}, b2.opts().WithName("outside_compilation_F2_O1_recv")); + Node* recv2 = + RecvAtHost("host_compute_channel_F2_O1", {DT_FLOAT}, + b2.opts().WithName("outside_compilation_F2_O1_recv")); Node* h = Binary(ops::NodeOut(call1, 1), recv2, b2.opts().WithName("H").WithControlInput(s1)); - Node* send2 = SendFromHost( - {h}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F2_O1_send")); + Node* send2 = + SendFromHost("host_compute_channel_F2_O1", {h}, + b2.opts().WithName("outside_compilation_F2_O1_send")); NodeBuilder node_builder2("F2", "F2", lib_def.get()); node_builder2.Input(e).Input(call1); @@ -888,7 +1186,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { { GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); - Node* a = Input(b1.opts().WithName("A")); + Node* a = InputShaped(b1.opts().WithName("A")); Node* b = Input(b1.opts().WithName("B")); Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); Node* d = @@ -908,6 +1206,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; + TensorShapeProto shape_proto_expected; + shape_proto_expected.add_dim()->set_size(2); + *library_expected.add_function() = FunctionDefHelper::Create( "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {}, { @@ -915,11 +1216,16 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, {{"F"}, "BinaryTest", - {"D:o:0", "outside_compilation_O1_recv:output:0"}}, - 
{{"outside_compilation_O1_recv"}, - "_XlaRecvFromHost", + {"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}}, + {{"outside_compilation_O1_host_compute"}, + "_XlaHostCompute", {}, - {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}}, + {{"Tinputs", gtl::ArraySlice({})}, + {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", ""}, + {"shapes", + gtl::ArraySlice({shape_proto_expected})}}}, }, {{"f_0_retval", "F:o:0"}}); @@ -927,12 +1233,13 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { std::unique_ptr lib_def( new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); - Node* a = Input(b2.opts().WithName("A")); + Node* a = InputShaped(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); Node* e = Unary(a, b2.opts().WithName("E")); - Node* send1 = SendFromHost( - {e}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_send")); + Node* send1 = + SendFromHost("host_compute_channel_F1_O1", {e}, + b2.opts().WithName("outside_compilation_F1_O1_send")); NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); @@ -954,7 +1261,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { { GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); - Node* a = Input(b1.opts().WithName("A")); + Node* a = InputShaped(b1.opts().WithName("A")); Node* b = Input(b1.opts().WithName("B")); Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); Node* d = @@ -975,6 +1282,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; + TensorShapeProto shape_proto_expected; + shape_proto_expected.add_dim()->set_size(2); + *library_expected.add_function() = FunctionDefHelper::Create( "F1", {"a_0_arg:float", "b_0_arg:float"}, 
{"f_0_retval:float"}, {}, { @@ -982,17 +1292,17 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, {{"F"}, "BinaryTest", - {"D:o:0", "outside_compilation_O1_recv:output:0"}}, - {{"outside_compilation_O1_send"}, - "_XlaSendToHost", + {"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}}, + {{"outside_compilation_O1_host_compute"}, + "_XlaHostCompute", {}, - {{"dtypes", gtl::ArraySlice({})}}, + {{"Tinputs", gtl::ArraySlice({})}, + {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", ""}, + {"shapes", + gtl::ArraySlice({shape_proto_expected})}}, {"D"}}, - {{"outside_compilation_O1_recv"}, - "_XlaRecvFromHost", - {}, - {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}, - {"outside_compilation_O1_send"}}, }, {{"f_0_retval", "F:o:0"}}); @@ -1000,14 +1310,16 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { std::unique_ptr lib_def( new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); - Node* a = Input(b2.opts().WithName("A")); + Node* a = InputShaped(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); Node* recv1 = - RecvAtHost({}, b2.opts().WithName("outside_compilation_F1_O1_recv")); + RecvAtHost("host_compute_channel_F1_O1", {}, + b2.opts().WithName("outside_compilation_F1_O1_recv")); Node* e = Unary(a, b2.opts().WithName("E").WithControlInput(recv1)); - Node* send1 = SendFromHost( - {e}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_send")); + Node* send1 = + SendFromHost("host_compute_channel_F1_O1", {e}, + b2.opts().WithName("outside_compilation_F1_O1_send")); NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); @@ -1055,10 +1367,14 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { {{"C"}, "UnaryTest", 
{"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, {{"F"}, "UnaryTest", {"D:o:0"}}, - {{"outside_compilation_O1_send"}, - "_XlaSendToHost", + {{"outside_compilation_O1_host_compute"}, + "_XlaHostCompute", {"D:o:0"}, - {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}}, + {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", ""}, + {"shapes", gtl::ArraySlice({})}}}, }, {{"f_0_retval", "F:o:0"}}); @@ -1069,8 +1385,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { Node* a = Input(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); - Node* recv1 = RecvAtHost( - {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* recv1 = + RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT}, + b2.opts().WithName("outside_compilation_F1_O1_recv")); Node* e = Unary(recv1, b2.opts().WithName("E")); NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); @@ -1118,16 +1435,19 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, - {{"F"}, "UnaryTest", {"D:o:0"}, {}, {"outside_compilation_O1_recv"}}, - {{"outside_compilation_O1_send"}, - "_XlaSendToHost", + {{"F"}, + "UnaryTest", {"D:o:0"}, - {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}}, - {{"outside_compilation_O1_recv"}, - "_XlaRecvFromHost", {}, - {{"dtypes", gtl::ArraySlice({})}}, - {"outside_compilation_O1_send"}}, + {"outside_compilation_O1_host_compute"}}, + {{"outside_compilation_O1_host_compute"}, + "_XlaHostCompute", + {"D:o:0"}, + {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", ""}, + {"shapes", gtl::ArraySlice({})}}}, }, {{"f_0_retval", "F:o:0"}}); @@ -1138,10 +1458,11 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { Node* a = 
Input(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); - Node* recv1 = RecvAtHost( - {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* recv1 = + RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT}, + b2.opts().WithName("outside_compilation_F1_O1_recv")); Node* e = Unary(recv1, b2.opts().WithName("E")); - Node* send1 = SendFromHost({}, {}, + Node* send1 = SendFromHost("host_compute_channel_F1_O1", {}, b2.opts() .WithName("outside_compilation_F1_O1_send") .WithControlInput(e)); @@ -1215,5 +1536,110 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) { TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); } +// Test for shape inference of outside compilation. +TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { + FunctionDefLibrary library; + GraphDef graphdef; + + { + *library.add_function() = test::function::XTimesTwo(); + + GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); + Node* a = InputShaped(b1.opts().WithName("A")); + Node* b = Input(b1.opts().WithName("B")); + // Give nodes 'c' and 'd' names that collide after lowercasing. 
+ Node* c = Unary(a, b1.opts().WithName("C")); + Node* d = Unary(b, b1.opts().WithName("c").WithControlInput(c).WithAttr( + "_encapsulate", "F1")); + Node* e = BinaryUnknownShape(c, d, + b1.opts() + .WithName("E") + .WithControlInputs({b, d}) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* f = Binary(c, e, + b1.opts().WithName("F").WithControlInput(e).WithAttr( + "_encapsulate", "F1")); + Binary(a, f, b1.opts().WithName("G").WithControlInput(e)); + TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); + } + + TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + + FunctionDefLibrary library_expected; + GraphDef graphdef_expected; + + string shape_string_expected; + { + GraphDefBuilder shape(GraphDefBuilder::kFailImmediately); + Node* known = KnownShape({2}, shape.opts().WithName("KnownShape/_0")); + Node* recv = + RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT}, + shape.opts().WithName("outside_compilation_F1_O1_recv")); + Node* e = BinaryUnknownShape(known, recv, shape.opts().WithName("E")); + SendFromHost("host_compute_channel_F1_O1", {e}, + shape.opts().WithName("outside_compilation_F1_O1_send")); + GraphDef shape_graph; + TF_EXPECT_OK(shape.ToGraphDef(&shape_graph)); + EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected)); + } + + *library_expected.add_function() = test::function::XTimesTwo(); + *library_expected.add_function() = FunctionDefHelper::Create( + "F1", {"b_0_arg:float", "c_0_arg:float"}, {"f_0_retval:float"}, {}, + { + {{"c"}, "UnaryTest", {"b_0_arg"}, {}, {}}, + {{"F"}, + "BinaryTest", + {"c_0_arg", "outside_compilation_O1_host_compute:outputs:0"}, + {}, + {"outside_compilation_O1_host_compute"}}, + {{"outside_compilation_O1_host_compute"}, + "_XlaHostCompute", + {"c:o:0"}, + {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", shape_string_expected}, + {"shapes", gtl::ArraySlice({})}}, + {"c"}}, + }, + {{"f_0_retval", 
"F:o:0"}}); + + { + std::unique_ptr lib_def( + new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); + GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); + Node* a = InputShaped(b2.opts().WithName("A")); + Node* b = Input(b2.opts().WithName("B")); + Node* c = Unary(a, b2.opts().WithName("C")); + + NodeBuilder node_builder("F1", "F1", lib_def.get()); + node_builder.Input(b).Input(c); + Node* call = + b2.opts().WithControlInputs({c}).FinalizeBuilder(&node_builder); + + Node* recv = + RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT}, + b2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* e = BinaryUnknownShape( + c, ops::NodeOut(recv, 0), + b2.opts().WithName("E").WithControlInputs({recv, b})); + Node* send = SendFromHost("host_compute_channel_F1_O1", {e}, + b2.opts() + .WithName("outside_compilation_F1_O1_send") + .WithControlInput(e)); + + Node* s = NoOp( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send})); + + Binary(a, call, b2.opts().WithName("G").WithControlInputs({s, e})); + TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); + } + + TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); + TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index d6b5761..eae8e6c 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -1064,26 +1064,36 @@ Status FunctionLibraryDefinition::AddLibrary( return Status::OK(); } -void FunctionLibraryDefinition::RemoveFunction(const string& func) { +Status FunctionLibraryDefinition::RemoveFunction(const string& func) { const auto& i = function_defs_.find(func); - DCHECK(i != function_defs_.end()); + if (i == function_defs_.end()) { + return errors::InvalidArgument("Tried to remove non-existent function ", + func); + } function_defs_.erase(i); + return Status::OK(); } -void 
FunctionLibraryDefinition::RemoveGradient(const string& func) { +Status FunctionLibraryDefinition::RemoveGradient(const string& func) { const auto& i = func_grad_.find(func); - DCHECK(i != func_grad_.end()); + if (i == func_grad_.end()) { + return errors::InvalidArgument("Tried to remove non-existent gradient ", + func); + } func_grad_.erase(i); + return Status::OK(); } void FunctionLibraryDefinition::Remove( const std::vector& funcs, const std::vector& funcs_with_grads) { for (const string& f : funcs) { - RemoveFunction(f); + Status s = RemoveFunction(f); + DCHECK(s.ok()); } for (const string& f : funcs_with_grads) { - RemoveGradient(f); + Status s = RemoveGradient(f); + DCHECK(s.ok()); } } diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index b933ee0..7d0e156 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -312,6 +312,14 @@ class FunctionLibraryDefinition : public OpRegistryInterface { // This operation is atomic. Status AddGradientDef(const GradientDef& grad); + // Remove function `func` from the library. Returns non-OK Status unless + // `func` is in the library. + Status RemoveFunction(const string& func); + + // Remove gradient of function `func` from the library. Returns non-OK Status + // unless `func` has a gradient. + Status RemoveGradient(const string& func); + // Adds the functions and gradients in 'other' to this function library. // Duplicate functions and gradients are ignored. // This operation is atomic. @@ -384,13 +392,6 @@ class FunctionLibraryDefinition : public OpRegistryInterface { // attr from. const FunctionDef* GetAttrImpl(const NodeDef& ndef) const; - // Remove function `func` from the library. `func` must be in the library. - void RemoveFunction(const string& func); - - // Remove gradient of function `func` from the library. `func` must have - // a gradient. 
- void RemoveGradient(const string& func); - // Remove all functions in `funcs` and all gradients of // functions in `funcs_with_grads` from this library. void Remove(const std::vector& funcs, -- 2.7.4