#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
+#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/map_util.h"
// everything to use it.
static const char* const kArgOp = "_Arg";
static const char* const kRetValOp = "_Retval";
-static const char* const kSendToHostOp = "_XlaSendToHost";
-static const char* const kRecvFromHostOp = "_XlaRecvFromHost";
+static const char* const kHostComputeOp = "_XlaHostCompute";
static const char* const kSendFromHostOp = "_XlaSendFromHost";
static const char* const kRecvAtHostOp = "_XlaRecvAtHost";
// Write a copy of the input graph to 'graph_out', where the subgraphs are
// replaced with calls to the new functions.
- Status BuildOutputGraph(bool parallel_checking, Graph* graph_out);
+ Status BuildOutputGraph(bool parallel_checking, Graph* graph_out,
+ FunctionLibraryDefinition* library);
private:
// A subgraph of the input, all marked with a common 'group_attribute'
// ...
// RAH --> C --> SFH
//
- // The compiled cluster is as follows. STH is a SendToHost node which is the
- // source of a channel to the RAH node above. RFH is a RecvFromHost node which
- // is the destination of a channel from the SFH node above. There is a control
- // edge that ensures RFH follows STH, which is used in shape inference to
- // ensure that the shapes on the STH host channel are known before the RFH
- // channel is compiled.
+ // The compiled cluster is as follows. HC is a HostCompute node which is the
+ // source of a channel to the RAH node above and the destination of a channel
+ // from the SFH node above.
//
- // Arg --> B --> STH ..> RFH --> D --> Retval
+ // Arg --> B --> HC --> D --> Retval
//
- // The channels STH/RAH and SFH/RFH each transmit a tuple, so there is at most
- // one RAH and SFH in each compiled cluster. This design is preferred over
- // adding separate Arg/Retval nodes for each transmitted value because it
- // simplifies the host code that would like to limit communication between
- // host and device and, e.g., raise only one interrupt per channel rather than
- // one per transmitted value.
+ // The channels HC/RAH and SFH/HC each transmit multiple tensors, so there is
+ // at most one RAH and SFH in each outside_compilation cluster. This design
+ // is preferred over adding separate Arg/Retval nodes for each transmitted
+ // value because it allows host-side optimizations that limit communication
+ // between host and device, e.g., raising only one interrupt per channel
+ // rather than one per transmitted value.
+ //
+ // The shapes of the outputs from the HC node in general cannot be determined
+ // until the shapes of its inputs are known at compile time, since e.g.,
+ // above, the shapes of C's outputs aren't known until the shapes of its
+ // inputs are known. If the shapes of the HC's outputs can be determined
+ // during the rewrite, they are stored in the node's 'shapes' attr. Otherwise
+ // a minimal graph is stored in the 'shape_inference_graph' attr. This graph
+ // can be used when compiling the HC Op to determine the shape of the SFH
+ // inputs given the shapes of any ancestor RAH outputs. If it can be
+ // determined that the shape of the SFH inputs will not be inferable even
+ // once the shapes of the RAH outputs are known, an error is returned by the
+ // rewriter.
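+ //
+ // For illustration only (the concrete values here are hypothetical), a
+ // HostCompute node produced by this rewrite might carry attrs such as:
+ //   Tinputs  = [DT_FLOAT, DT_FLOAT]
+ //   Toutputs = [DT_FLOAT]
+ //   key      = "host_compute_channel_F1_O1"
+ //   shapes   = []   (empty when the output shapes are not statically known)
+ //   shape_inference_graph = <serialized minimal GraphDef>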
class Subgraph {
public:
// Creates a graph to build the subgraph in, if it doesn't already exist,
const std::unordered_map<const Node*, Node*>& node_images,
Graph* graph_out);
+ // Returns the names of all the outside_compilation subgraphs in this
+ // Subgraph.
+ void GetOutsideCompilationSubgraphNames(std::vector<string>* names) const;
+
// Returns the Node that inputs to the function should be wired up to.
Node* GetCallNodeForInputs() const;
void RecordOutsideCompilationOutputOrControl(
const string& outside_compilation_id, const Edge* edge);
- // Adds the SendToHost nodes for each outside_compilation subgraph once the
- // edges have all been recorded via RecordOutsideCompilationInputOrControl.
- Status AddSendsToOutsideCompilation(
- const std::unordered_map<const Node*, Node*>& node_images);
-
- // Adds the RecvFromHost nodes for each outside_compilation subgraph once
- // the edges have all been recorded via
- // RecordOutsideCompilationOutputOrControl.
- Status AddRecvsFromOutsideCompilation(
+ // Adds the HostCompute nodes for each outside_compilation subgraph.
+ Status AddHostComputes(
+ const string& subgraph_name,
const std::unordered_map<const Node*, Node*>& node_images);
// Creates the sequencer node if it doesn't exist, adding it to graph_out.
// all the downstream nodes of call_node_outputs.
void ConnectSequencerToOutputs(Graph* graph_out);
+ Status AddShapeInferenceInfo(
+ const string& outside_compilation_subgraph_name,
+ const std::vector<TensorShapeProto>& shapes, GraphDef* inference_graph);
+
+ Status ReplaceFunctionDef(FunctionLibraryDefinition* library);
+
private:
struct OutsideCompilationSubgraph {
// Map from source (producer node/slot) tensors in the original graph to
- // input index (slot number in the SendToHost/RecvAtHost nodes that will
+ // input index (slot number in the HostCompute/RecvAtHost nodes that will
// be created) for the outside_compilation subgraph.
std::unordered_map<NodeSlot, int, NodeSlot::Hasher> inputs;
// outside_compilation subgraph. These are recorded by
// RecordOutsideCompilationInputOrControl while walking all the subgraph
// edges, and lifted control edges within the subgraph are added by
- // AddSendsToOutsideCompilation once the _SendToHost node has been
+ // AddHostComputes once the _HostCompute node has been
// created. The matching control edge from _RecvAtHost to the
// destination is added by CopyEdgeToOutputGraph.
std::unordered_set<const Node*> control_inputs;
// Maps from source (producer node/slot) and destination (consumer
// node/slot) tensors in the original graph to output index (slot number
- // in the SendFromHost/RecvFromHost nodes that will be created) for the
+ // in the SendFromHost/HostCompute nodes that will be created) for the
// outside_compilation subgraph.
std::unordered_map<NodeSlot, int, NodeSlot::Hasher> outputs_by_src;
std::unordered_map<NodeSlot, int, NodeSlot::Hasher> outputs_by_dst;
// containing compiled subgraph. These are recorded by
// RecordOutsideCompilationOutputOrControl while walking all the subgraph
// edges, and lifted control edges within the subgraph are added by
- // AddRecvsFromToOutsideCompilation once the _RecvFromHost node has been
+ // AddHostComputes once the _HostCompute node has been
// created. The matching control edge from the source to _SendFromHost to
// the destination is added by CopyEdgeToOutputGraph.
std::unordered_set<const Node*> control_outputs;
- // _SendToHost node in the subgraph. Not owned.
- Node* send_to_host = nullptr;
+ // Name of the _HostCompute node in the subgraph.
+ string host_compute_name;
// _RecvAtHost node in the output graph. Not owned.
Node* recv_at_host = nullptr;
const std::unordered_map<const Node*, Node*>& node_images,
bool parallel_checking, Graph* graph_out);
+  // Constructs a minimal shape inference graph that can be used to determine
+  // the shapes of the inputs to send_node at the time that the subgraph is
+  // compiled.
+ // recv_at_host_nodes contains the names of all the recv_at_host nodes that
+ // send_node might depend on. These recv_at_host nodes have shapes that are
+ // not known during the rewrite pass, but will be known at compile time.
+ //
+ // If the shapes of all the inputs to send_node can be determined during the
+ // rewrite pass, on exit graphdef_out is empty and the shapes are returned in
+ // static_shape_out. Otherwise graphdef_out contains a graph that can be used
+ // for shape inference at compile time, where all the source nodes of the
+ // graph are either constants with known shapes, or nodes named in
+ // recv_at_host_nodes.
+ //
+ // A non-OK status is returned if neither of the above conditions can be
+ // satisfied, e.g., because send_node depends on a node that doesn't have a
+ // registered shape inference function.
+ Status DoStaticShapeInferenceForOutsideCompilationSend(
+ const Graph& graph_in, const ShapeRefiner& shape_refiner,
+ const std::unordered_set<string>& recv_at_host_nodes, Node* send_node,
+ FunctionLibraryDefinition* library,
+ std::vector<TensorShapeProto>* static_shape_out,
+ std::unique_ptr<GraphDef>* graphdef_out);
+
+  // Makes a copy of graph containing only nodes that are ancestors of at least
+  // one node in sink_nodes, and stores it in pruned_graph. On exit node_images
+  // contains a mapping from nodes in graph to nodes in pruned_graph. All
+  // functions in the copied graph are inlined.
+ Status MakePrunedGraphCopyAndInline(
+ const Graph& graph, const std::vector<Node*>& sink_nodes,
+ std::unique_ptr<Graph>* pruned_graph,
+ std::unordered_map<const Node*, Node*>* node_images,
+ FunctionLibraryDefinition* library);
+
+  // Makes a copy of graph containing only nodes that are ancestors of a
+  // send_from_host node in an outside_compilation subgraph, and stores it in
+  // pruned_graph. Also performs shape inference on the pruned graph, using
+  // shape_refiner. On exit node_images contains a mapping from nodes in graph
+  // to nodes in pruned_graph.
+ Status MakeGraphForOutsideCompilationSends(
+ const Graph& graph, std::unique_ptr<Graph>* pruned_graph,
+ ShapeRefiner* shape_refiner,
+ std::unordered_map<const Node*, Node*>* node_images,
+ FunctionLibraryDefinition* library);
+
+ // Performs static shape inference, as far as possible, for the send_from_host
+ // nodes in each outside_compilation subgraph. Where it is not possible to
+ // determine the shape statically, stores a serialized GraphDef in the
+ // HostCompute 'shape_inference_graph' attr, to be used at compile time for
+ // final inference. If the shapes are known statically they are stored in the
+ // HostCompute 'shapes' attr.
+ Status GetShapeInfoForOutsideCompilationSends(
+ Graph* graph_out, FunctionLibraryDefinition* library);
+
const string group_attribute_;
const string outside_compilation_attribute_;
const Graph* graph_in_;
}
}
-Status Encapsulator::Subgraph::AddSendsToOutsideCompilation(
+Status Encapsulator::Subgraph::AddHostComputes(
+ const string& subgraph_name,
const std::unordered_map<const Node*, Node*>& node_images) {
for (auto& oc_subgraph_iter : outside_compilation_subgraphs_) {
const string& oc_subgraph_name = oc_subgraph_iter.first;
OutsideCompilationSubgraph& oc_subgraph = oc_subgraph_iter.second;
- if (!oc_subgraph.inputs.empty() || !oc_subgraph.control_inputs.empty()) {
- // Build a _SendToHost node sending all the args of the appropriate
- // types.
- std::vector<DataType> dtypes(oc_subgraph.inputs.size(), DT_INVALID);
+ if (!oc_subgraph.inputs.empty() || !oc_subgraph.control_inputs.empty() ||
+ !oc_subgraph.outputs_by_src.empty() ||
+ !oc_subgraph.control_outputs.empty()) {
+ // Build a _HostCompute node.
std::vector<NodeDefBuilder::NodeOut> inputs(oc_subgraph.inputs.size());
+ std::vector<DataType> input_dtypes(oc_subgraph.inputs.size(), DT_INVALID);
+ std::vector<DataType> output_dtypes(oc_subgraph.outputs_by_src.size(),
+ DT_INVALID);
for (const auto& input_src : oc_subgraph.inputs) {
      const Node* src_node = input_src.first.node;
      Node* src_image = node_images.at(src_node);
      int src_slot = input_src.first.slot;
      int input_index = input_src.second;
      DataType dtype = src_node->output_type(src_slot);
- dtypes[input_index] = dtype;
inputs[input_index].Reset(src_image->name(), src_slot, dtype);
+ input_dtypes[input_index] = dtype;
}
- NodeDef send_def;
- NodeDefBuilder builder(
- strings::StrCat("outside_compilation_", oc_subgraph_name, "_send"),
- kSendToHostOp);
- builder.Attr("dtypes", dtypes);
+ for (const auto& output : oc_subgraph.outputs_by_src) {
+ DataType dtype = output.first.dtype;
+ int output_index = output.second;
+ output_dtypes[output_index] = dtype;
+ }
+
+ NodeDef host_compute_def;
+ NodeDefBuilder builder(strings::StrCat("outside_compilation_",
+ oc_subgraph_name, "_host_compute"),
+ kHostComputeOp);
builder.Input(inputs);
- Status s = builder.Finalize(&send_def);
+ builder.Attr("Tinputs", input_dtypes);
+ builder.Attr("Toutputs", output_dtypes);
+ builder.Attr("key",
+ strings::StrCat("host_compute_channel_", subgraph_name, "_",
+ oc_subgraph_name));
+ Status s = builder.Finalize(&host_compute_def);
if (!s.ok()) return s;
- oc_subgraph.send_to_host = graph_->AddNode(send_def, &s);
+ Node* host_compute = graph_->AddNode(host_compute_def, &s);
if (!s.ok()) return s;
+ oc_subgraph.host_compute_name = host_compute->name();
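+      // The 'shapes' and 'shape_inference_graph' attrs are not set here; they
+      // are filled in later by AddShapeInferenceInfo once shape inference has
+      // run over the output graph.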
- // Connect the _SendToHost node to its producers in the subgraph.
+ // Connect the _HostCompute node to its producers in the subgraph.
for (auto& input_src : oc_subgraph.inputs) {
const Node* src_node = input_src.first.node;
Node* src_image = node_images.at(src_node);
int src_slot = input_src.first.slot;
int input_index = input_src.second;
- graph_->AddEdge(src_image, src_slot, oc_subgraph.send_to_host,
- input_index);
+ graph_->AddEdge(src_image, src_slot, host_compute, input_index);
}
- // Connect the _SendToHost node to its control edge producers in the
+ // Connect the _HostCompute node to its control edge producers in the
// subgraph.
for (const auto& src_node : oc_subgraph.control_inputs) {
Node* src_image = node_images.at(src_node);
- graph_->AddControlEdge(src_image, oc_subgraph.send_to_host);
- }
- }
- }
-
- return Status::OK();
-}
-
-Status Encapsulator::Subgraph::AddRecvsFromOutsideCompilation(
- const std::unordered_map<const Node*, Node*>& node_images) {
- for (auto& oc_subgraph_iter : outside_compilation_subgraphs_) {
- const string& oc_subgraph_name = oc_subgraph_iter.first;
- OutsideCompilationSubgraph& oc_subgraph = oc_subgraph_iter.second;
- if (!oc_subgraph.outputs_by_src.empty() ||
- !oc_subgraph.control_outputs.empty()) {
- // Build a _RecvFromHost node producing all the outputs of the appropriate
- // types.
- std::vector<DataType> dtypes(oc_subgraph.outputs_by_src.size(),
- DT_INVALID);
-
- for (const auto& output : oc_subgraph.outputs_by_src) {
- DataType dtype = output.first.dtype;
- int output_index = output.second;
- dtypes[output_index] = dtype;
+ graph_->AddControlEdge(src_image, host_compute);
}
- NodeDef recv_def;
- NodeDefBuilder builder(
- strings::StrCat("outside_compilation_", oc_subgraph_name, "_recv"),
- kRecvFromHostOp);
- builder.Attr("dtypes", dtypes);
- Status s = builder.Finalize(&recv_def);
- if (!s.ok()) return s;
-
- Node* recv = graph_->AddNode(recv_def, &s);
- if (!s.ok()) return s;
-
- // Connect the consumers in the subgraph to the _RecvFromHost node.
+ // Connect the consumers in the subgraph to the _HostCompute node.
for (const auto& output : oc_subgraph.outputs_by_dst) {
const Node* dst_node = output.first.node;
Node* dst_image = node_images.at(dst_node);
int dst_slot = output.first.slot;
int output_index = output.second;
- graph_->AddEdge(recv, output_index, dst_image, dst_slot);
+ graph_->AddEdge(host_compute, output_index, dst_image, dst_slot);
}
- // Connect the control edge consumers in the subgraph to the _RecvFromHost
+ // Connect the control edge consumers in the subgraph to the _HostCompute
// node.
for (const auto& dst_node : oc_subgraph.control_outputs) {
Node* dst_image = node_images.at(dst_node);
- graph_->AddControlEdge(recv, dst_image);
- }
-
- // Add a control edge in the subgraph so that the _SendToHost node, if
- // any, is compiled before the _RecvFromHost node.
- if (oc_subgraph.send_to_host != nullptr) {
- graph_->AddControlEdge(oc_subgraph.send_to_host, recv);
+ graph_->AddControlEdge(host_compute, dst_image);
}
}
}
return Status::OK();
}
+Status Encapsulator::Subgraph::AddShapeInferenceInfo(
+ const string& outside_compilation_subgraph_name,
+ const std::vector<TensorShapeProto>& shapes, GraphDef* inference_graph) {
+ OutsideCompilationSubgraph& oc_subgraph =
+ outside_compilation_subgraphs_.at(outside_compilation_subgraph_name);
+
+ Node* host_compute = nullptr;
+ for (Node* n : graph_->nodes()) {
+ if (n->name() == oc_subgraph.host_compute_name) {
+ host_compute = n;
+ break;
+ }
+ }
+ if (host_compute == nullptr) {
+ return errors::InvalidArgument(
+ "After rewriting subgraph ", outside_compilation_subgraph_name,
+ " there is no HostCompute Op for outside compilation subgraph ",
+ oc_subgraph.host_compute_name);
+ }
+
+ if (inference_graph == nullptr) {
+ host_compute->AddAttr("shape_inference_graph", "");
+ host_compute->AddAttr("shapes", shapes);
+ } else {
+ string serialized_graph;
+ if (!inference_graph->SerializeToString(&serialized_graph)) {
+ return errors::Internal(
+ "Failed to serialize graph for outside compilation subgraph ",
+ oc_subgraph.host_compute_name);
+ }
+ host_compute->AddAttr("shape_inference_graph", serialized_graph);
+ host_compute->AddAttr("shapes", std::vector<TensorShapeProto>());
+ }
+ return Status::OK();
+}
+
+Status Encapsulator::Subgraph::ReplaceFunctionDef(
+ FunctionLibraryDefinition* library) {
+ const string& name = call_node_def_.name();
+
+ FunctionDef fdef;
+ TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, &fdef));
+
+ if (VLOG_IS_ON(1)) {
+    VLOG(1) << "Replace function def " << name;
+ dump_graph::DumpGraphToFile(
+ strings::StrCat("replace_encapsulate_fdef_graph_", name), *graph_,
+ library);
+ dump_graph::DumpFunctionDefToFile(
+ strings::StrCat("replace_encapsulate_fdef_", name), fdef);
+ }
+
+ TF_RETURN_IF_ERROR(library->RemoveFunction(name));
+ TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef));
+ return Status::OK();
+}
+
Status Encapsulator::Subgraph::BuildParallelCheckOp(
const std::unordered_map<const Node*, Node*>& node_images,
Graph* graph_out) {
NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name,
"_", oc_subgraph_name, "_recv"),
kRecvAtHostOp);
- builder.Attr("dtypes", dtypes);
+ builder.Attr("Toutputs", dtypes);
+ builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name,
+ "_", oc_subgraph_name));
Status s = builder.Finalize(&recv_def);
if (!s.ok()) return s;
NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name,
"_", oc_subgraph_name, "_send"),
kSendFromHostOp);
- builder.Attr("dtypes", dtypes);
+ builder.Attr("Tinputs", dtypes);
+ builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name,
+ "_", oc_subgraph_name));
builder.Input(inputs);
Status s = builder.Finalize(&send_def);
if (!s.ok()) return s;
return Status::OK();
}
+void Encapsulator::Subgraph::GetOutsideCompilationSubgraphNames(
+ std::vector<string>* names) const {
+ for (auto& entry : outside_compilation_subgraphs_) {
+ names->push_back(entry.first);
+ }
+}
+
Status Encapsulator::GetFunctionNameAttr(
Node const* node, string* attr, string* outside_compilation_attr) const {
Status s = GetNodeAttr(node->attrs(), group_attribute_, attr);
// single input and output node for it.
for (auto& entry : subgraphs_) {
Subgraph& subgraph = entry.second;
- TF_RETURN_IF_ERROR(subgraph.AddSendsToOutsideCompilation(node_images));
- TF_RETURN_IF_ERROR(subgraph.AddRecvsFromOutsideCompilation(node_images));
+ TF_RETURN_IF_ERROR(subgraph.AddHostComputes(entry.first, node_images));
}
MarkGuaranteedConstants(*graph_in_, src_arg_pairs);
return Status::OK();
}
-Status Encapsulator::BuildOutputGraph(bool parallel_checking,
- Graph* graph_out) {
+namespace {
+
+// Adds a dummy Const node to graph_out. The "constant" has dtype data_type
+// and the shape indicated in 'shape'. The dummy node is not a valid Const
+// node because it does not have any value defined, but this doesn't
+// matter because it will only be used subsequently for shape inference. (It
+// would be possible to add a switch statement over data_type to create a value
+// for the constant, but that would entail maintaining the logic as new types
+// are added, and is not necessary.)
+Node* AddDummyShapedNode(DataType data_type, const TensorShapeProto& shape,
+ Graph* graph_out) {
+ TensorProto dummy_proto;
+ dummy_proto.set_dtype(data_type);
+ *dummy_proto.mutable_tensor_shape() = shape;
+ // Don't set any value field in the proto, since it is only going to be used
+ // for shape inference.
+
+ GraphDefBuilder::Options options(graph_out, /*status=*/nullptr);
+ NodeBuilder node_builder(options.GetNameForOp("KnownShape"), "Const",
+ options.op_registry());
+ node_builder.Attr("dtype", data_type).Attr("value", dummy_proto);
+ return options.FinalizeBuilder(&node_builder);
+}
+
+// Adds a copy of node_in to graph_out and adds the mapping to
+// copied_node_images.
+Status CopyShapeInferenceNodeToGraph(
+ Node* node_in, const Node* send_node,
+ const std::unordered_map<Node*, Node*>& dummy_node_images,
+ FunctionLibraryDefinition* library,
+ std::unordered_map<Node*, Node*>* copied_node_images, Graph* graph_out) {
+ // Once all the ancestor nodes have been added to graph_out, add this node
+ // and connect it to its ancestors.
+ Node* node_out = graph_out->CopyNode(node_in);
+ (*copied_node_images)[node_in] = node_out;
+ // Don't bother to build the shape inference graph if there's a node with no
+ // shape inference function, since it would just result in an error later at
+ // compile time.
+ const OpRegistrationData* op_reg_data;
+ TF_RETURN_IF_ERROR(library->LookUp(node_in->type_string(), &op_reg_data));
+ if (op_reg_data->shape_inference_fn == nullptr) {
+ return errors::InvalidArgument(
+ "Shape inference is not possible for outside_compilation "
+ "SendFromHost node ",
+ send_node->name(), " because it depends on node ", node_in->name(),
+ " which does not have a shape inference function registered.");
+ }
+ // Add all the edges to the newly copied node.
+ for (const Edge* in_edge : node_in->in_edges()) {
+ if (!in_edge->IsControlEdge()) {
+ Node* src = in_edge->src();
+ const auto iter = dummy_node_images.find(src);
+ if (iter == dummy_node_images.end()) {
+ // The src is a copied node so use the original output port.
+ graph_out->AddEdge((*copied_node_images)[in_edge->src()],
+ in_edge->src_output(), node_out,
+ in_edge->dst_input());
+ } else {
+ // The src is a dummy node so use output port 0.
+ graph_out->AddEdge(iter->second, 0, node_out, in_edge->dst_input());
+ }
+ }
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend(
+ const Graph& graph_in, const ShapeRefiner& shape_refiner,
+ const std::unordered_set<string>& recv_at_host_nodes, Node* send_node,
+ FunctionLibraryDefinition* library,
+ std::vector<TensorShapeProto>* static_shape_out,
+ std::unique_ptr<GraphDef>* graphdef_out) {
+ // Maps from nodes in graph_in to nodes in graph_out.
+ //
+ // When an edge has fully defined shape the source node in graph_in is
+ // replaced in graph_out by a dummy constant node. The mapping from nodes
+ // in graph_in to dummy nodes is stored in dummy_node_images.
+ //
+ // When a node in graph_in has at least one ancestor that doesn't have fully
+ // defined shape, it is copied into graph_out. The mapping from nodes in
+ // graph_in to copied nodes is stored in copied_node_images.
+ //
+ // The two types of node are treated differently because, when adding edges to
+ // graph_out, an output from a dummy node always uses port 0, whereas an
+ // output from a copied node uses the same port that was used in graph_in.
+ std::unordered_map<Node*, Node*> dummy_node_images;
+ std::unordered_map<Node*, Node*> copied_node_images;
+
+ std::unique_ptr<Graph> graph_out(new Graph(graph_in.op_registry()));
+ graph_out->set_versions(graph_in.versions());
+ static_shape_out->resize(send_node->num_inputs());
+
+ // We don't use the standard ReverseDFS because we want to cut off traversal
+ // whenever we find an output with fully defined shape.
+ // TODO(misard) make this work properly in the presence of control flow.
+ struct Work {
+ Node* node;
+ bool leave; // Are we entering or leaving node?
+ };
+ std::vector<Work> stack({{send_node, false}});
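+  // Each node is pushed twice: once to visit its inputs (leave == false) and
+  // again, after all of its inputs have been handled, to copy it into
+  // graph_out (leave == true), so ancestors are always copied first.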
+ std::vector<bool> visited(graph_in.num_node_ids(), false);
+ while (!stack.empty()) {
+ Work w = stack.back();
+ stack.pop_back();
+ Node* n = w.node;
+
+ if (w.leave) {
+ TF_RETURN_IF_ERROR(CopyShapeInferenceNodeToGraph(
+ n, send_node, dummy_node_images, library, &copied_node_images,
+ graph_out.get()));
+ } else {
+ if (visited[n->id()]) continue;
+ visited[n->id()] = true;
+
+ // Arrange to revisit when all done with all inputs.
+ stack.push_back(Work{n, true});
+
+ bool has_parent_with_unknown_shape = false;
+ for (const Edge* in_edge : n->in_edges()) {
+ if (!in_edge->IsControlEdge()) {
+ Node* src_node = in_edge->src();
+ int src_port = in_edge->src_output();
+ shape_inference::InferenceContext* context =
+ shape_refiner.GetContext(src_node);
+ shape_inference::ShapeHandle shape = context->output(src_port);
+ if (context->FullyDefined(shape)) {
+ // This ancestor has known shape, so instead of adding it to the
+ // stack, add a dummy node with that shape to graph_out and
+ // continue.
+ TensorShapeProto proto;
+ context->ShapeHandleToProto(shape, &proto);
+ dummy_node_images[src_node] = AddDummyShapedNode(
+ src_node->output_type(src_port), proto, graph_out.get());
+ if (n == send_node) {
+ (*static_shape_out)[in_edge->dst_input()] = proto;
+ }
+ } else {
+ if (!visited[src_node->id()]) {
+ has_parent_with_unknown_shape = true;
+ stack.push_back({src_node, false});
+ }
+ }
+ }
+ }
+ if (!has_parent_with_unknown_shape) {
+ if (n == send_node) {
+ // The shapes of all the inputs to send_node are statically known. We
+ // won't have to do any inference at compile time so return now: the
+ // shapes were stored in static_shape_out above.
+ graphdef_out->reset();
+ return Status::OK();
+ } else {
+        // Any node that is being processed here is either the original send
+        // node or has at least one output with statically-unknown shape. If
+        // the latter and it doesn't have any inputs with statically-unknown
+        // shape, then check that it is one of the recv nodes whose shape we
+        // can fill in at compile time. If it isn't one of those, then we
+        // won't have any additional knowledge at compile time, so we already
+        // know we won't be able to do shape inference and we can return an
+        // error now.
+ if (recv_at_host_nodes.find(n->name()) == recv_at_host_nodes.end()) {
+ return errors::InvalidArgument(
+ "Shape inference is not possible for outside_compilation "
+ "SendFromHost node ",
+ send_node->name(), " because shape of node ", n->name(),
+ " will not be known at compilation time.");
+ }
+ }
+ }
+ }
+ }
+
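+  // Reaching here means at least one input shape could not be determined
+  // statically, so emit the pruned graph for shape inference at compile time.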
+ graphdef_out->reset(new GraphDef());
+ graph_out->ToGraphDef(graphdef_out->get());
+
+ return Status::OK();
+}
+
+Status Encapsulator::MakePrunedGraphCopyAndInline(
+ const Graph& graph, const std::vector<Node*>& sink_nodes,
+ std::unique_ptr<Graph>* pruned_graph,
+ std::unordered_map<const Node*, Node*>* node_images,
+ FunctionLibraryDefinition* library) {
+ // First copy all ancestor nodes of sink_nodes into a new graph.
+ pruned_graph->reset(new Graph(library));
+ (*pruned_graph)->set_versions(graph.versions());
+ ReverseDFSFrom(graph, sink_nodes,
+ /*enter=*/nullptr,
+ /*leave=*/[&](Node* n) {
+ if (!n->IsSource()) {
+ Node* copied = (*pruned_graph)->CopyNode(n);
+ node_images->emplace(n, copied);
+ }
+ });
+
+ // Add all the edges between copied nodes.
+ for (auto entry : *node_images) {
+ const Node* orig = entry.first;
+ Node* image = entry.second;
+ for (const Edge* out_edge : orig->out_edges()) {
+ auto iter = node_images->find(out_edge->dst());
+ if (iter != node_images->end()) {
+ // The source and destination are both in the copied graph.
+ (*pruned_graph)
+ ->AddEdge(image, out_edge->src_output(), iter->second,
+ out_edge->dst_input());
+ }
+ }
+ }
+
+ // Find all the function call nodes, and inline them.
+ std::vector<Node*> function_nodes;
+ for (auto node : (*pruned_graph)->nodes()) {
+ const OpRegistrationData* op_reg_data;
+ TF_RETURN_IF_ERROR(library->LookUp(node->type_string(), &op_reg_data));
+ if (op_reg_data->is_function_op) {
+ function_nodes.push_back(node);
+ }
+ }
+ for (auto node : function_nodes) {
+ VLOG(2) << "Inlining function " << node->name();
+ const FunctionDef* fdef = library->Find(node->type_string());
+ if (fdef == nullptr) {
+ return errors::Internal("Failed to find function ", node->type_string(),
+ " in function library.");
+ }
+ FunctionBody* fbody = nullptr;
+ TF_RETURN_IF_ERROR(
+ FunctionDefToBodyHelper(*fdef, node->attrs(), library,
+ [library](const string& op, const OpDef** sig) {
+ return library->LookUpOpDef(op, sig);
+ },
+ &fbody));
+ InlineFunctionBody(*library, pruned_graph->get(), node, fbody);
+ delete fbody;
+ }
+
+ return Status::OK();
+}
+
+Status Encapsulator::MakeGraphForOutsideCompilationSends(
+ const Graph& graph, std::unique_ptr<Graph>* pruned_graph,
+ ShapeRefiner* shape_refiner,
+ std::unordered_map<const Node*, Node*>* node_images,
+ FunctionLibraryDefinition* library) {
+ // Find all the send_from_host nodes in all subgraphs, to use as roots for the
+ // pruning.
+ std::vector<Node*> send_from_host_nodes;
+ for (auto& subgraph_entry : subgraphs_) {
+ Subgraph& subgraph = subgraph_entry.second;
+ std::vector<string> outside_compilation_names;
+ subgraph.GetOutsideCompilationSubgraphNames(&outside_compilation_names);
+ for (const auto& name : outside_compilation_names) {
+ Node* send_node = subgraph.GetSendFromHostNode(name);
+ if (send_node != nullptr) {
+ send_from_host_nodes.push_back(send_node);
+ }
+ }
+ }
+
+ // Make a copy of all the graph nodes needed to evaluate the send_from_host
+ // nodes, inlining any functions as needed.
+ TF_RETURN_IF_ERROR(MakePrunedGraphCopyAndInline(
+ graph, send_from_host_nodes, pruned_graph, node_images, library));
+
+ // Perform shape inference on the pruned graph.
+ shape_refiner->set_require_shape_inference_fns(false);
+ FixupSourceAndSinkEdges(pruned_graph->get());
+ std::vector<Node*> post_order;
+ GetReversePostOrder(*(*pruned_graph), &post_order);
+ for (auto node : post_order) {
+ // Ignore the status returned by the shape_refiner. At this point we want
+ // the best effort shapes, even if no shape function is registered for a
+ // node.
+ Status status = shape_refiner->AddNode(node);
+ if (!status.ok()) {
+ VLOG(1) << "Shape inference failed for node: " << status;
+ }
+ }
+
+ return Status::OK();
+}
+
+Status Encapsulator::GetShapeInfoForOutsideCompilationSends(
+ Graph* graph_out, FunctionLibraryDefinition* library) {
+ std::unique_ptr<Graph> pruned_graph;
+ ShapeRefiner shape_refiner(graph_out->versions(), graph_out->op_registry());
+ std::unordered_map<const Node*, Node*> node_images;
+ TF_RETURN_IF_ERROR(MakeGraphForOutsideCompilationSends(
+ *graph_out, &pruned_graph, &shape_refiner, &node_images, library));
+
+ for (auto& subgraph_entry : subgraphs_) {
+ Subgraph& subgraph = subgraph_entry.second;
+ // Find all the recv_at_host nodes in this subgraph.
+ std::vector<string> outside_compilation_names;
+ subgraph.GetOutsideCompilationSubgraphNames(&outside_compilation_names);
+ std::unordered_set<string> recv_at_host_names;
+ for (const auto& name : outside_compilation_names) {
+ Node* recv_node = subgraph.GetRecvAtHostNode(name);
+ if (recv_node != nullptr) {
+ recv_at_host_names.insert(recv_node->name());
+ }
+ }
+ // For each send_from_host node, do as much shape inference as possible
+ // without knowing the shape of the recv_at_host nodes, and store the
+ // result, along with enough information to complete the job at compile time
+ // once the recv_at_host shapes are known.
+ for (const auto& name : outside_compilation_names) {
+ Node* send_node = subgraph.GetSendFromHostNode(name);
+ std::vector<TensorShapeProto> static_shape;
+ std::unique_ptr<GraphDef> graphdef;
+ if (send_node != nullptr) {
+ TF_RETURN_IF_ERROR(DoStaticShapeInferenceForOutsideCompilationSend(
+ *pruned_graph, shape_refiner, recv_at_host_names,
+ node_images[send_node], library, &static_shape, &graphdef));
+ if (graphdef == nullptr) {
+ VLOG(2) << "Send node " << send_node->name() << " shapes";
+ for (int i = 0; i < static_shape.size(); ++i) {
+ VLOG(2) << static_shape[i].DebugString();
+ }
+ } else {
+ VLOG(2) << "Send node " << send_node->name() << " graph\n"
+ << graphdef->DebugString();
+ }
+ }
+ TF_RETURN_IF_ERROR(
+ subgraph.AddShapeInferenceInfo(name, static_shape, graphdef.get()));
+ }
+ if (!outside_compilation_names.empty()) {
+ TF_RETURN_IF_ERROR(subgraph.ReplaceFunctionDef(library));
+ }
+ }
+
+ return Status::OK();
+}
+
+Status Encapsulator::BuildOutputGraph(bool parallel_checking, Graph* graph_out,
+ FunctionLibraryDefinition* library) {
// Map from nodes in the input graph to nodes in the output graph.
std::unordered_map<const Node*, Node*> node_images;
TF_RETURN_IF_ERROR(
AddEdgesToOutputGraph(node_images, parallel_checking, graph_out));
+ TF_RETURN_IF_ERROR(
+ GetShapeInfoForOutsideCompilationSends(graph_out, library));
+
return Status::OK();
}
std::unique_ptr<Graph> out(new Graph(library));
out->set_versions(graph_in.versions());
TF_RETURN_IF_ERROR(
- encapsulator.BuildOutputGraph(parallel_checking, out.get()));
+ encapsulator.BuildOutputGraph(parallel_checking, out.get(), library));
*graph_out = std::move(out);
return Status::OK();
namespace tensorflow {
namespace {
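+// Returns true if and only if 'a' and 'b' contain the same keys and 'compare'
+// holds for each key's pair of values. On mismatch, writes a human-readable
+// description of the first difference found to *diff if diff is non-null.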
+template <class Tkey, class Tvalue>
+bool EqualProtoMap(const ::tensorflow::protobuf::Map<Tkey, Tvalue>& a,
+ const ::tensorflow::protobuf::Map<Tkey, Tvalue>& b,
+ const std::function<string(const Tkey&)>& key_to_string,
+ const std::function<string(const Tvalue&)>& value_to_string,
+ const std::function<bool(const Tkey&, const Tvalue&,
+ const Tvalue&)>& compare,
+ const string& map_name, string* diff) {
+ for (const auto& elt_a : a) {
+ const auto iter = b.find(elt_a.first);
+ if (iter == b.end()) {
+ if (diff) {
+ *diff = strings::StrCat(
+ map_name, " expected: contains element with key '",
+ key_to_string(elt_a.first), "' got: map has no such element");
+ }
+ return false;
+ }
+ if (!compare(elt_a.first, elt_a.second, iter->second)) {
+ if (diff) {
+        *diff = strings::StrCat(map_name, " expected: element with key '",
+                                key_to_string(elt_a.first), "' has value '",
+                                value_to_string(elt_a.second), "' got: '",
+                                value_to_string(iter->second), "'");
+ }
+ return false;
+ }
+ }
+ for (const auto& elt_b : b) {
+ const auto iter = a.find(elt_b.first);
+ if (iter == a.end()) {
+ if (diff) {
+ *diff = strings::StrCat(map_name, " got: contains element with key '",
+ key_to_string(elt_b.first),
+ "' expected: map has no such element");
+ }
+ return false;
+ }
+ }
+ return true;
+}
+
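+// Returns true if and only if NodeDefs 'a' and 'b' match in op, device,
+// inputs, and attrs. The shape_inference_graph attr is compared as a parsed
+// GraphDef rather than as a serialized string, since its default
+// serialization is not deterministic.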
+bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b,
+ const string& diff_preamble, string* diff) {
+ if (a.op() != b.op()) {
+ if (diff) {
+      *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(),
+                              ", expected op '", a.op(), "' got '", b.op(),
+                              "'");
+ }
+ return false;
+ }
+ if (a.device() != b.device()) {
+ if (diff) {
+      *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(),
+                              ", expected device '", a.device(), "' got '",
+                              b.device(), "'");
+ }
+ return false;
+ }
+ if (a.input_size() != b.input_size()) {
+ if (diff) {
+ *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(),
+ ", expected ", a.input_size(), " inputs got ",
+ b.input_size(), " expected:\n", a.DebugString(),
+ "\ngot:\n", b.DebugString());
+ }
+ return false;
+ }
+ for (int i = 0; i < a.input_size(); ++i) {
+ if (a.input(i) != b.input(i)) {
+ if (diff) {
+ *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(),
+ " input ", i, ", expected ", a.input(i),
+ " got ", b.input(i), " expected:\n",
+ a.DebugString(), "\ngot:\n", b.DebugString());
+ }
+ return false;
+ }
+ }
+ return EqualProtoMap<string, AttrValue>(
+ a.attr(), b.attr(), [](const string& s) { return s; },
+ [](const AttrValue& v) { return v.DebugString(); },
+ [](const string& key, const AttrValue& av, const AttrValue& bv) {
+ if (key == "shape_inference_graph") {
+ // Default serialization of GraphDef is unstable because maps don't
+ // serialize deterministically. Rather than go through the hoops to
+ // turn on deterministic serialization of this attr just for this
+          // test, add logic here to compare deterministically.
+ GraphDef ga;
+ if (!ga.ParseFromString(av.s())) {
+ return false;
+ }
+ GraphDef gb;
+ if (!gb.ParseFromString(bv.s())) {
+ return false;
+ }
+ return EqualGraphDef(ga, gb, nullptr);
+ } else {
+ return av.DebugString() == bv.DebugString();
+ }
+ },
+ strings::StrCat(diff_preamble, " attr mismatch for node ", a.name()),
+ diff);
+}
+
bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b,
string* diff) {
- // TODO(phawkins) use a more sophisticated equality test.
- if (a.DebugString() != b.DebugString()) {
+ if (a.signature().DebugString() != b.signature().DebugString()) {
if (diff) {
- *diff = strings::StrCat("Definition mismatch for function ",
+ *diff = strings::StrCat("Signature mismatch for function ",
a.signature().name(), ", expected:\n",
- a.DebugString(), "\ngot:\n", b.DebugString());
+ a.signature().DebugString(), "\ngot:\n",
+ b.signature().DebugString());
}
return false;
}
+ if (!EqualProtoMap<string, AttrValue>(
+ a.attr(), b.attr(), [](const string& s) { return s; },
+ [](const AttrValue& v) { return v.DebugString(); },
+ [](const string& key, const AttrValue& av, const AttrValue& bv) {
+ return av.DebugString() == bv.DebugString();
+ },
+ strings::StrCat("attr mismatch for function ", a.signature().name()),
+ diff)) {
+ return false;
+ }
+ if (!EqualProtoMap<string, string>(
+ a.ret(), b.ret(), [](const string& s) { return s; },
+ [](const string& s) { return s; },
+ [](const string& key, const string& av, const string& bv) {
+ return av == bv;
+ },
+ strings::StrCat("ret mismatch for function ", a.signature().name()),
+ diff)) {
+ return false;
+ }
+ for (int i = 0; i < a.node_def_size(); ++i) {
+ bool found = false;
+ for (int j = 0; j < b.node_def_size(); ++j) {
+ if (a.node_def(i).name() == b.node_def(j).name()) {
+ if (!EqualFunctionNodeDef(
+ a.node_def(i), b.node_def(j),
+ strings::StrCat("Function ", a.signature().name()), diff)) {
+ return false;
+ }
+ found = true;
+ break;
+ }
+ }
+ if (!found) {
+ if (diff) {
+ *diff = strings::StrCat("Function ", a.signature().name(),
+ ", expected: has node '", a.node_def(i).name(),
+ "' got: no node of that name");
+ }
+ return false;
+ }
+ }
+ for (int i = 0; i < b.node_def_size(); ++i) {
+ bool found = false;
+ for (int j = 0; j < a.node_def_size(); ++j) {
+ if (b.node_def(i).name() == a.node_def(j).name()) {
+ found = true;
+ break;
+ }
+ }
+ if (!found) {
+ if (diff) {
+ *diff = strings::StrCat("Function ", a.signature().name(),
+ ", got: has node '", b.node_def(i).name(),
+ "' expected: no node of that name");
+ }
+ return false;
+ }
+ }
return true;
}
// TODO(misard): remove these fake registrations once there are real Ops to be
// compiled.
-REGISTER_OP("_XlaSendToHost")
- .Input("input: dtypes")
- .Attr("dtypes: list(type) >= 0");
-
-REGISTER_OP("_XlaRecvFromHost")
- .Output("output: dtypes")
- .Attr("dtypes: list(type) >= 0");
+REGISTER_OP("_XlaHostCompute")
+ .Input("inputs: Tinputs")
+ .Output("outputs: Toutputs")
+ .Attr("Tinputs: list(type) >= 0")
+ .Attr("Toutputs: list(type) >= 0")
+ .Attr("key: string")
+ .SetShapeFn(::tensorflow::shape_inference::UnknownShape);
REGISTER_OP("_XlaSendFromHost")
- .Input("input: dtypes")
- .Attr("dtypes: list(type) >= 0");
+ .Input("input: Tinputs")
+ .Attr("Tinputs: list(type) >= 0")
+ .Attr("key: string")
+ .SetShapeFn(::tensorflow::shape_inference::UnknownShape);
REGISTER_OP("_XlaRecvAtHost")
- .Output("output: dtypes")
- .Attr("dtypes: list(type) >= 0");
-
-REGISTER_OP("InputTest").Output("o: float");
-
-REGISTER_OP("UnaryTest").Input("a: float").Output("o: float");
+ .Output("output: Toutputs")
+ .Attr("Toutputs: list(type) >= 0")
+ .Attr("key: string")
+ .SetShapeFn(::tensorflow::shape_inference::UnknownShape);
+
+REGISTER_OP("InputTest")
+ .Output("o: float")
+ .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
+ c->set_output(0, c->UnknownShape());
+ return Status::OK();
+ });
+
+REGISTER_OP("InputTestShaped")
+ .Output("o: float")
+ .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
+ c->set_output(0, c->Vector(2));
+ return Status::OK();
+ });
+
+REGISTER_OP("UnaryTest")
+ .Input("a: float")
+ .Output("o: float")
+ .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
+ ::tensorflow::shape_inference::ShapeHandle o;
+ TF_RETURN_IF_ERROR(c->Merge(c->UnknownShape(), c->input(0), &o));
+ c->set_output(0, o);
+ return Status::OK();
+ });
REGISTER_OP("BinaryTest")
.Input("a: float")
.Input("b: float")
- .Output("o: float");
+ .Output("o: float")
+ .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
+ ::tensorflow::shape_inference::ShapeHandle o;
+ TF_RETURN_IF_ERROR(c->Merge(c->UnknownShape(), c->input(0), &o));
+ c->set_output(0, o);
+ return Status::OK();
+ });
+REGISTER_OP("BinaryTest2")
+ .Input("a: float")
+ .Input("b: float")
+ .Output("o: float")
+ .SetShapeFn(::tensorflow::shape_inference::UnknownShape);
REGISTER_OP("AddNLikeTest")
.Input("inputs: N * T")
return ops::SourceOp("InputTest", opts);
}
-Node* RecvAtHost(const gtl::ArraySlice<DataType>& dtypes,
+Node* InputShaped(const GraphDefBuilder::Options& opts) {
+ return ops::SourceOp("InputTestShaped", opts);
+}
+
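+// Builds a Const node with a fully-defined float shape but no value,
+// mirroring the dummy constants that AddDummyShapedNode creates in the pass
+// for shape inference.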
+Node* KnownShape(const gtl::ArraySlice<int>& shape,
+ const GraphDefBuilder::Options& opts) {
+ if (opts.HaveError()) return nullptr;
+ NodeBuilder node_builder(opts.GetNameForOp("Const"), "Const",
+ opts.op_registry());
+ TensorProto value;
+ value.set_dtype(DT_FLOAT);
+ for (int dim : shape) {
+ value.mutable_tensor_shape()->add_dim()->set_size(dim);
+ }
+ return opts.WithAttr("value", value)
+ .WithAttr("dtype", DT_FLOAT)
+ .FinalizeBuilder(&node_builder);
+}
+
+Node* RecvAtHost(const string& key, const gtl::ArraySlice<DataType>& dtypes,
const GraphDefBuilder::Options& opts) {
if (opts.HaveError()) return nullptr;
NodeBuilder node_builder(opts.GetNameForOp("_XlaRecvAtHost"),
"_XlaRecvAtHost", opts.op_registry());
- return opts.WithAttr("dtypes", dtypes).FinalizeBuilder(&node_builder);
+ return opts.WithAttr("Toutputs", dtypes)
+ .WithAttr("key", key)
+ .FinalizeBuilder(&node_builder);
}
-Node* SendFromHost(const std::vector<ops::NodeOut>& inputs,
- const gtl::ArraySlice<DataType>& dtypes,
+Node* SendFromHost(const string& key, const std::vector<ops::NodeOut>& inputs,
const GraphDefBuilder::Options& opts) {
if (opts.HaveError()) return nullptr;
NodeBuilder node_builder(opts.GetNameForOp("_XlaSendFromHost"),
"_XlaSendFromHost", opts.op_registry());
node_builder.Input(inputs);
- return opts.WithAttr("dtypes", dtypes).FinalizeBuilder(&node_builder);
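+  // Derive the Tinputs attr from the dtypes of the supplied inputs rather
+  // than requiring callers to pass a separate dtype list.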
+ std::vector<DataType> dtypes;
+ for (const auto& node : inputs) {
+ dtypes.push_back(node.dt);
+ }
+ return opts.WithAttr("key", key)
+ .WithAttr("Tinputs", dtypes)
+ .FinalizeBuilder(&node_builder);
}
Node* Unary(ops::NodeOut a, const GraphDefBuilder::Options& opts) {
return ops::BinaryOp("BinaryTest", std::move(a), std::move(b), opts);
}
+Node* BinaryUnknownShape(ops::NodeOut a, ops::NodeOut b,
+ const GraphDefBuilder::Options& opts) {
+ return ops::BinaryOp("BinaryTest2", std::move(a), std::move(b), opts);
+}
+
Node* AddNLike(const std::vector<ops::NodeOut>& inputs,
const GraphDefBuilder::Options& opts) {
if (opts.HaveError()) return nullptr;
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
+ string shape_string_expected;
+ {
+ GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
+ Node* recv =
+ RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
+ shape.opts().WithName("outside_compilation_F1_O1_recv"));
+ Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
+ shape.opts().WithName("E"));
+ SendFromHost("host_compute_channel_F1_O1", {e},
+ shape.opts().WithName("outside_compilation_F1_O1_send"));
+ GraphDef shape_graph;
+ TF_EXPECT_OK(shape.ToGraphDef(&shape_graph));
+ EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected));
+ }
+
*library_expected.add_function() = test::function::XTimesTwo();
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
{{"c"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}, {"C"}},
{{"F"},
"BinaryTest",
- {"C:o:0", "outside_compilation_O1_recv:output:0"},
+ {"C:o:0", "outside_compilation_O1_host_compute:outputs:0"},
{},
- {"outside_compilation_O1_recv"}},
- {{"outside_compilation_O1_send"},
- "_XlaSendToHost",
+ {"outside_compilation_O1_host_compute"}},
+ {{"outside_compilation_O1_host_compute"},
+ "_XlaHostCompute",
{"C:o:0", "c:o:0"},
- {{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})}},
+ {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
+ {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
+ {"key", "host_compute_channel_F1_O1"},
+ {"shape_inference_graph", shape_string_expected},
+      {"shapes", gtl::ArraySlice<TensorShapeProto>({})}},
{"c"}},
- {{"outside_compilation_O1_recv"},
- "_XlaRecvFromHost",
- {},
- {{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
- {"outside_compilation_O1_send"}},
},
{{"f_0_retval", "F:o:0"}});
Node* call = b2.opts().FinalizeBuilder(&node_builder);
Node* recv =
- RecvAtHost({DT_FLOAT, DT_FLOAT},
+ RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
b2.opts().WithName("outside_compilation_F1_O1_recv"));
Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
b2.opts().WithName("E").WithControlInputs({recv, b}));
- Node* send = SendFromHost({e}, {DT_FLOAT},
+ Node* send = SendFromHost("host_compute_channel_F1_O1", {e},
b2.opts()
.WithName("outside_compilation_F1_O1_send")
.WithControlInput(e));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
+ string shape_string_expected_1;
+ {
+ GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
+ Node* recv =
+ RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
+ shape1.opts().WithName("outside_compilation_F1_O1_recv"));
+ Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
+ shape1.opts().WithName("E"));
+ SendFromHost("host_compute_channel_F1_O1", {e},
+ shape1.opts().WithName("outside_compilation_F1_O1_send"));
+ GraphDef shape1_graph;
+ TF_EXPECT_OK(shape1.ToGraphDef(&shape1_graph));
+ EXPECT_TRUE(shape1_graph.SerializeToString(&shape_string_expected_1));
+ }
+
+ string shape_string_expected_2;
+ {
+ GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately);
+ Node* recv1 =
+ RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
+ shape2.opts().WithName("outside_compilation_F1_O1_recv"));
+ Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
+ shape2.opts().WithName("E"));
+ Node* recv2 =
+ RecvAtHost("host_compute_channel_F1_O2", {DT_FLOAT, DT_FLOAT},
+ shape2.opts().WithName("outside_compilation_F1_O2_recv"));
+ Node* h = Binary(ops::NodeOut(recv2, 0), e, shape2.opts().WithName("H"));
+ SendFromHost("host_compute_channel_F1_O2", {h},
+ shape2.opts().WithName("outside_compilation_F1_O2_send"));
+ GraphDef shape2_graph;
+ TF_EXPECT_OK(shape2.ToGraphDef(&shape2_graph));
+ EXPECT_TRUE(shape2_graph.SerializeToString(&shape_string_expected_2));
+ }
+
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"i_0_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}},
- {{"I"}, "UnaryTest", {"outside_compilation_O2_recv:output:0"}},
+ {{"I"},
+ "UnaryTest",
+ {"outside_compilation_O2_host_compute:outputs:0"}},
{{"F"},
"BinaryTest",
- {"C:o:0", "outside_compilation_O1_recv:output:0"},
+ {"C:o:0", "outside_compilation_O1_host_compute:outputs:0"},
{},
- {"outside_compilation_O1_recv"}},
- {{"outside_compilation_O2_send"},
- "_XlaSendToHost",
+ {"outside_compilation_O1_host_compute"}},
+ {{"outside_compilation_O2_host_compute"},
+ "_XlaHostCompute",
{"D:o:0", "F:o:0"},
- {{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})}},
+ {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
+ {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
+ {"key", "host_compute_channel_F1_O2"},
+ {"shape_inference_graph", shape_string_expected_2},
+       {"shapes", gtl::ArraySlice<TensorShapeProto>({})}},
{"F"}},
- {{"outside_compilation_O1_send"},
- "_XlaSendToHost",
+ {{"outside_compilation_O1_host_compute"},
+ "_XlaHostCompute",
{"C:o:0", "D:o:0"},
- {{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})}},
+ {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
+ {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
+ {"key", "host_compute_channel_F1_O1"},
+ {"shape_inference_graph", shape_string_expected_1},
+       {"shapes", gtl::ArraySlice<TensorShapeProto>({})}},
{"D"}},
- {{"outside_compilation_O2_recv"},
- "_XlaRecvFromHost",
- {},
- {{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
- {"outside_compilation_O2_send"}},
- {{"outside_compilation_O1_recv"},
- "_XlaRecvFromHost",
- {},
- {{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
- {"outside_compilation_O1_send"}},
},
{{"i_0_retval", "I:o:0"}});
Node* call = b2.opts().FinalizeBuilder(&node_builder);
Node* recv1 =
- RecvAtHost({DT_FLOAT, DT_FLOAT},
+ RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
b2.opts().WithName("outside_compilation_F1_O1_recv"));
Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
b2.opts().WithName("E").WithControlInputs({recv1, b}));
- Node* send1 = SendFromHost({e}, {DT_FLOAT},
+ Node* send1 = SendFromHost("host_compute_channel_F1_O1", {e},
b2.opts()
.WithName("outside_compilation_F1_O1_send")
.WithControlInput(e));
Node* recv2 =
- RecvAtHost({DT_FLOAT, DT_FLOAT},
+ RecvAtHost("host_compute_channel_F1_O2", {DT_FLOAT, DT_FLOAT},
b2.opts().WithName("outside_compilation_F1_O2_recv"));
Node* g = Binary(e, ops::NodeOut(recv2, 1),
b2.opts().WithName("G").WithControlInputs({recv2, e}));
Node* h = Binary(ops::NodeOut(recv2, 0), e, b2.opts().WithName("H"));
- Node* send2 = SendFromHost(
- {h}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O2_send"));
+ Node* send2 =
+ SendFromHost("host_compute_channel_F1_O2", {h},
+ b2.opts().WithName("outside_compilation_F1_O2_send"));
Node* s = NoOp(b2.opts()
.WithName("F1_sequencer")
{
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
- Node* a = Input(b1.opts().WithName("A"));
- Node* b = Input(b1.opts().WithName("B"));
+ Node* a = InputShaped(b1.opts().WithName("A"));
+ Node* b = InputShaped(b1.opts().WithName("B"));
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
Node* d =
Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
+ string shape_string_expected;
+ {
+ GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
+ Node* recv =
+ RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
+ shape.opts().WithName("outside_compilation_F1_O1_recv"));
+ Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
+ shape.opts().WithName("E"));
+ SendFromHost("host_compute_channel_F1_O1", {e},
+ shape.opts().WithName("outside_compilation_F1_O1_send"));
+ GraphDef shape_graph;
+ TF_EXPECT_OK(shape.ToGraphDef(&shape_graph));
+ EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected));
+ }
+
+ TensorShapeProto shape_proto_expected;
+ shape_proto_expected.add_dim()->set_size(2);
+
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"},
{"f_0_retval:float", "d_0_retval:float"}, {},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
{{"F"},
"BinaryTest",
- {"C:o:0", "outside_compilation_O1_recv:output:0"},
+ {"C:o:0", "outside_compilation_O1_host_compute:outputs:0"},
{},
- {"outside_compilation_O1_recv"}},
- {{"outside_compilation_O1_send"},
- "_XlaSendToHost",
+ {"outside_compilation_O1_host_compute"}},
+ {{"outside_compilation_O1_host_compute"},
+ "_XlaHostCompute",
{"C:o:0", "D:o:0"},
- {{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})}},
+ {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
+ {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
+ {"key", "host_compute_channel_F1_O1"},
+ {"shape_inference_graph", shape_string_expected},
+     {"shapes", gtl::ArraySlice<TensorShapeProto>({})}},
{"D"}},
- {{"outside_compilation_O1_recv"},
- "_XlaRecvFromHost",
- {},
- {{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
- {"outside_compilation_O1_send"}},
},
{{"d_0_retval", "D:o:0"}, {"f_0_retval", "F:o:0"}});
{{"G"}, "BinaryTest", {"e_0_arg", "f_0_arg"}},
{{"I"},
"BinaryTest",
- {"f_0_arg", "outside_compilation_O1_recv:output:0"}},
- {{"outside_compilation_O1_send"},
- "_XlaSendToHost",
+ {"f_0_arg", "outside_compilation_O1_host_compute:outputs:0"}},
+ {{"outside_compilation_O1_host_compute"},
+ "_XlaHostCompute",
{"G:o:0"},
- {{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}}},
- {{"outside_compilation_O1_recv"},
- "_XlaRecvFromHost",
- {},
- {{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
- {"outside_compilation_O1_send"}},
+ {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
+ {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
+ {"key", "host_compute_channel_F2_O1"},
+ {"shape_inference_graph", ""},
+ {"shapes",
+ gtl::ArraySlice<TensorShapeProto>({shape_proto_expected})}}},
},
{{"g_0_retval", "G:o:0"}, {"i_0_retval", "I:o:0"}});
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
- Node* a = Input(b2.opts().WithName("A"));
- Node* b = Input(b2.opts().WithName("B"));
+ Node* a = InputShaped(b2.opts().WithName("A"));
+ Node* b = InputShaped(b2.opts().WithName("B"));
Node* recv1 =
- RecvAtHost({DT_FLOAT, DT_FLOAT},
+ RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
b2.opts().WithName("outside_compilation_F1_O1_recv"));
Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
b2.opts().WithName("E").WithControlInputs({recv1, b}));
- Node* send1 = SendFromHost({e}, {DT_FLOAT},
+ Node* send1 = SendFromHost("host_compute_channel_F1_O1", {e},
b2.opts()
.WithName("outside_compilation_F1_O1_send")
.WithControlInput(e));
Node* s1 = NoOp(
b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}));
- Node* recv2 = RecvAtHost(
- {DT_FLOAT}, b2.opts().WithName("outside_compilation_F2_O1_recv"));
+ Node* recv2 =
+ RecvAtHost("host_compute_channel_F2_O1", {DT_FLOAT},
+ b2.opts().WithName("outside_compilation_F2_O1_recv"));
Node* h = Binary(ops::NodeOut(call1, 1), recv2,
b2.opts().WithName("H").WithControlInput(s1));
- Node* send2 = SendFromHost(
- {h}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F2_O1_send"));
+ Node* send2 =
+ SendFromHost("host_compute_channel_F2_O1", {h},
+ b2.opts().WithName("outside_compilation_F2_O1_send"));
NodeBuilder node_builder2("F2", "F2", lib_def.get());
node_builder2.Input(e).Input(call1);
{
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
- Node* a = Input(b1.opts().WithName("A"));
+ Node* a = InputShaped(b1.opts().WithName("A"));
Node* b = Input(b1.opts().WithName("B"));
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
Node* d =
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
+ TensorShapeProto shape_proto_expected;
+ shape_proto_expected.add_dim()->set_size(2);
+
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
{
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
{{"F"},
"BinaryTest",
- {"D:o:0", "outside_compilation_O1_recv:output:0"}},
- {{"outside_compilation_O1_recv"},
- "_XlaRecvFromHost",
+ {"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}},
+ {{"outside_compilation_O1_host_compute"},
+ "_XlaHostCompute",
{},
- {{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}}},
+ {{"Tinputs", gtl::ArraySlice<DataType>({})},
+ {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
+ {"key", "host_compute_channel_F1_O1"},
+ {"shape_inference_graph", ""},
+ {"shapes",
+ gtl::ArraySlice<TensorShapeProto>({shape_proto_expected})}}},
},
{{"f_0_retval", "F:o:0"}});
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
- Node* a = Input(b2.opts().WithName("A"));
+ Node* a = InputShaped(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
Node* e = Unary(a, b2.opts().WithName("E"));
- Node* send1 = SendFromHost(
- {e}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_send"));
+ Node* send1 =
+ SendFromHost("host_compute_channel_F1_O1", {e},
+ b2.opts().WithName("outside_compilation_F1_O1_send"));
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b);
Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
{
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
- Node* a = Input(b1.opts().WithName("A"));
+ Node* a = InputShaped(b1.opts().WithName("A"));
Node* b = Input(b1.opts().WithName("B"));
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
Node* d =
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
+ TensorShapeProto shape_proto_expected;
+ shape_proto_expected.add_dim()->set_size(2);
+
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
{
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
{{"F"},
"BinaryTest",
- {"D:o:0", "outside_compilation_O1_recv:output:0"}},
- {{"outside_compilation_O1_send"},
- "_XlaSendToHost",
+ {"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}},
+ {{"outside_compilation_O1_host_compute"},
+ "_XlaHostCompute",
{},
- {{"dtypes", gtl::ArraySlice<DataType>({})}},
+ {{"Tinputs", gtl::ArraySlice<DataType>({})},
+ {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
+ {"key", "host_compute_channel_F1_O1"},
+ {"shape_inference_graph", ""},
+ {"shapes",
+ gtl::ArraySlice<TensorShapeProto>({shape_proto_expected})}},
{"D"}},
- {{"outside_compilation_O1_recv"},
- "_XlaRecvFromHost",
- {},
- {{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
- {"outside_compilation_O1_send"}},
},
{{"f_0_retval", "F:o:0"}});
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
- Node* a = Input(b2.opts().WithName("A"));
+ Node* a = InputShaped(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
Node* recv1 =
- RecvAtHost({}, b2.opts().WithName("outside_compilation_F1_O1_recv"));
+ RecvAtHost("host_compute_channel_F1_O1", {},
+ b2.opts().WithName("outside_compilation_F1_O1_recv"));
Node* e = Unary(a, b2.opts().WithName("E").WithControlInput(recv1));
- Node* send1 = SendFromHost(
- {e}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_send"));
+ Node* send1 =
+ SendFromHost("host_compute_channel_F1_O1", {e},
+ b2.opts().WithName("outside_compilation_F1_O1_send"));
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b);
Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
{{"F"}, "UnaryTest", {"D:o:0"}},
- {{"outside_compilation_O1_send"},
- "_XlaSendToHost",
+ {{"outside_compilation_O1_host_compute"},
+ "_XlaHostCompute",
{"D:o:0"},
- {{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}}},
+ {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
+ {"Toutputs", gtl::ArraySlice<DataType>({})},
+ {"key", "host_compute_channel_F1_O1"},
+ {"shape_inference_graph", ""},
+ {"shapes", gtl::ArraySlice<TensorShapeProto>({})}}},
},
{{"f_0_retval", "F:o:0"}});
Node* a = Input(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
- Node* recv1 = RecvAtHost(
- {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_recv"));
+ Node* recv1 =
+ RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT},
+ b2.opts().WithName("outside_compilation_F1_O1_recv"));
Node* e = Unary(recv1, b2.opts().WithName("E"));
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b);
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
- {{"F"}, "UnaryTest", {"D:o:0"}, {}, {"outside_compilation_O1_recv"}},
- {{"outside_compilation_O1_send"},
- "_XlaSendToHost",
+ {{"F"},
+ "UnaryTest",
{"D:o:0"},
- {{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}}},
- {{"outside_compilation_O1_recv"},
- "_XlaRecvFromHost",
{},
- {{"dtypes", gtl::ArraySlice<DataType>({})}},
- {"outside_compilation_O1_send"}},
+ {"outside_compilation_O1_host_compute"}},
+ {{"outside_compilation_O1_host_compute"},
+ "_XlaHostCompute",
+ {"D:o:0"},
+ {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
+ {"Toutputs", gtl::ArraySlice<DataType>({})},
+ {"key", "host_compute_channel_F1_O1"},
+ {"shape_inference_graph", ""},
+ {"shapes", gtl::ArraySlice<TensorShapeProto>({})}}},
},
{{"f_0_retval", "F:o:0"}});
Node* a = Input(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
- Node* recv1 = RecvAtHost(
- {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_recv"));
+ Node* recv1 =
+ RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT},
+ b2.opts().WithName("outside_compilation_F1_O1_recv"));
Node* e = Unary(recv1, b2.opts().WithName("E"));
- Node* send1 = SendFromHost({}, {},
+ Node* send1 = SendFromHost("host_compute_channel_F1_O1", {},
b2.opts()
.WithName("outside_compilation_F1_O1_send")
.WithControlInput(e));
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
+// Tests an outside_compilation cluster whose output shape cannot be
+// determined at rewrite time: the host_compute node must carry a serialized
+// shape inference graph in its 'shape_inference_graph' attr instead of a
+// static 'shapes' attr.
+TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) {
+ FunctionDefLibrary library;
+ GraphDef graphdef;
+
+ {
+ *library.add_function() = test::function::XTimesTwo();
+
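+    // Build the input graph. E is a BinaryUnknownShape node inside
+    // outside_compilation cluster O1, so its output shape cannot be
+    // determined when the cluster is rewritten.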
+ GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
+ Node* a = InputShaped(b1.opts().WithName("A"));
+ Node* b = Input(b1.opts().WithName("B"));
+    // Give the nodes built for 'c' and 'd' the names "C" and "c", which
+    // collide after lowercasing.
+ Node* c = Unary(a, b1.opts().WithName("C"));
+ Node* d = Unary(b, b1.opts().WithName("c").WithControlInput(c).WithAttr(
+ "_encapsulate", "F1"));
+ Node* e = BinaryUnknownShape(c, d,
+ b1.opts()
+ .WithName("E")
+ .WithControlInputs({b, d})
+ .WithAttr("_encapsulate", "F1")
+ .WithAttr("_outside", "O1"));
+ Node* f = Binary(c, e,
+ b1.opts().WithName("F").WithControlInput(e).WithAttr(
+ "_encapsulate", "F1"));
+ Binary(a, f, b1.opts().WithName("G").WithControlInput(e));
+ TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
+ }
+
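+  // Rewrite: cluster F1 becomes a function call, and the outside_compilation
+  // nodes move to the host side as communication ops.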
+ TF_EXPECT_OK(Encapsulate(&graphdef, &library));
+
+ FunctionDefLibrary library_expected;
+ GraphDef graphdef_expected;
+
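+  // Build and serialize the graph expected in the host_compute node's
+  // 'shape_inference_graph' attr: C's statically known shape is captured as a
+  // KnownShape node, while E's output shape remains to be inferred from the
+  // tensors received at the host.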
+ string shape_string_expected;
+ {
+ GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
+ Node* known = KnownShape({2}, shape.opts().WithName("KnownShape/_0"));
+ Node* recv =
+ RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT},
+ shape.opts().WithName("outside_compilation_F1_O1_recv"));
+ Node* e = BinaryUnknownShape(known, recv, shape.opts().WithName("E"));
+ SendFromHost("host_compute_channel_F1_O1", {e},
+ shape.opts().WithName("outside_compilation_F1_O1_send"));
+ GraphDef shape_graph;
+ TF_EXPECT_OK(shape.ToGraphDef(&shape_graph));
+ EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected));
+ }
+
+ *library_expected.add_function() = test::function::XTimesTwo();
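+  // In the rewritten F1, the host_compute node carries the serialized shape
+  // inference graph and an empty 'shapes' attr.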
+ *library_expected.add_function() = FunctionDefHelper::Create(
+ "F1", {"b_0_arg:float", "c_0_arg:float"}, {"f_0_retval:float"}, {},
+ {
+ {{"c"}, "UnaryTest", {"b_0_arg"}, {}, {}},
+ {{"F"},
+ "BinaryTest",
+ {"c_0_arg", "outside_compilation_O1_host_compute:outputs:0"},
+ {},
+ {"outside_compilation_O1_host_compute"}},
+ {{"outside_compilation_O1_host_compute"},
+ "_XlaHostCompute",
+ {"c:o:0"},
+ {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
+ {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
+ {"key", "host_compute_channel_F1_O1"},
+ {"shape_inference_graph", shape_string_expected},
+ {"shapes", gtl::ArraySlice<DataType>({})}},
+ {"c"}},
+ },
+ {{"f_0_retval", "F:o:0"}});
+
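+  // Expected rewritten top-level graph: the F1 call, the host-side
+  // recv -> E -> send chain, and a NoOp sequencer carrying control edges from
+  // the recv/send pair.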
+ {
+ std::unique_ptr<FunctionLibraryDefinition> lib_def(
+ new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
+ GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
+ Node* a = InputShaped(b2.opts().WithName("A"));
+ Node* b = Input(b2.opts().WithName("B"));
+ Node* c = Unary(a, b2.opts().WithName("C"));
+
+ NodeBuilder node_builder("F1", "F1", lib_def.get());
+ node_builder.Input(b).Input(c);
+ Node* call =
+ b2.opts().WithControlInputs({c}).FinalizeBuilder(&node_builder);
+
+ Node* recv =
+ RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT},
+ b2.opts().WithName("outside_compilation_F1_O1_recv"));
+ Node* e = BinaryUnknownShape(
+ c, ops::NodeOut(recv, 0),
+ b2.opts().WithName("E").WithControlInputs({recv, b}));
+ Node* send = SendFromHost("host_compute_channel_F1_O1", {e},
+ b2.opts()
+ .WithName("outside_compilation_F1_O1_send")
+ .WithControlInput(e));
+
+ Node* s = NoOp(
+ b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}));
+
+ Binary(a, call, b2.opts().WithName("G").WithControlInputs({s, e}));
+ TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
+ }
+
+ TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
+ TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
+}
+
} // namespace
} // namespace tensorflow