Add shape inference for outside_compilation graph rewrite. Pull out enough of the...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 2 Feb 2018 19:06:09 +0000 (11:06 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 2 Feb 2018 19:19:17 +0000 (11:19 -0800)
END_PUBLIC

Fixed open source build breaks.

BEGIN_PUBLIC
Automated g4 rollback of changelist 184169668

PiperOrigin-RevId: 184306845

tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
tensorflow/core/framework/function.cc
tensorflow/core/framework/function.h

index 0de163d..9c372a0 100644 (file)
@@ -30,12 +30,14 @@ limitations under the License.
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/common_runtime/optimization_registry.h"
+#include "tensorflow/core/common_runtime/shape_refiner.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/graph_def_util.h"
 #include "tensorflow/core/framework/node_def_builder.h"
 #include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/graph/algorithm.h"
 #include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
 #include "tensorflow/core/graph/tensor_id.h"
 #include "tensorflow/core/lib/gtl/flatset.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
@@ -141,8 +143,7 @@ struct NodeSlot {
 // everything to use it.
 static const char* const kArgOp = "_Arg";
 static const char* const kRetValOp = "_Retval";
-static const char* const kSendToHostOp = "_XlaSendToHost";
-static const char* const kRecvFromHostOp = "_XlaRecvFromHost";
+static const char* const kHostComputeOp = "_XlaHostCompute";
 static const char* const kSendFromHostOp = "_XlaSendFromHost";
 static const char* const kRecvAtHostOp = "_XlaRecvAtHost";
 
@@ -171,7 +172,8 @@ class Encapsulator {
 
   // Write a copy of the input graph to 'graph_out', where the subgraphs are
   // replaced with calls to the new functions.
-  Status BuildOutputGraph(bool parallel_checking, Graph* graph_out);
+  Status BuildOutputGraph(bool parallel_checking, Graph* graph_out,
+                          FunctionLibraryDefinition* library);
 
  private:
   // A subgraph of the input, all marked with a common 'group_attribute'
@@ -201,21 +203,29 @@ class Encapsulator {
   //     ..             .
   //  RAH -->  C  --> SFH
   //
-  // The compiled cluster is as follows. STH is a SendToHost node which is the
-  // source of a channel to the RAH node above. RFH is a RecvFromHost node which
-  // is the destination of a channel from the SFH node above. There is a control
-  // edge that ensures RFH follows STH, which is used in shape inference to
-  // ensure that the shapes on the STH host channel are known before the RFH
-  // channel is compiled.
+  // The compiled cluster is as follows. HC is a HostCompute node which is the
+  // source of a channel to the RAH node above and the destination of a channel
+  // from the SFH node above.
   //
-  //  Arg  --> B  --> STH  ..>  RFH  --> D --> Retval
+  //  Arg  --> B  --> HC  --> D --> Retval
   //
-  // The channels STH/RAH and SFH/RFH each transmit a tuple, so there is at most
-  // one RAH and SFH in each compiled cluster. This design is preferred over
-  // adding separate Arg/Retval nodes for each transmitted value because it
-  // simplifies the host code that would like to limit communication between
-  // host and device and, e.g., raise only one interrupt per channel rather than
-  // one per transmitted value.
+  // The channels HC/RAH and SFH/HC each transmit multiple tensors, so there is
+  // at most one RAH and SFH in each outside_compilation cluster. This design is
+  // preferred over adding separate Arg/Retval nodes for each transmitted value
+  // because it allows optimizations to the host code that would like to limit
+  // communication between host and device and, e.g., raise only one interrupt
+  // per channel rather than one per transmitted value.
+  //
+  // The shapes of the outputs from the HC node in general cannot be determined
+  // until the shapes of its inputs are known at compile time, since e.g.,
+  // above, the shape of C's outputs aren't known until the shape of its inputs
+  // are known. If the shapes of the HC's outputs can be determined during the
+  // rewrite, they are stored in the node's 'shapes' attr. Otherwise a minimal
+  // graph is stored in the shape_inference_graph attr. This graph can be used
+  // when compiling the HC Op to determined the shape of the SFH inputs given
+  // the shapes of any ancestor RAH outputs. If it can be determined that the
+  // shape of the SFH inputs will not be inferrable even once the shapes of the
+  // RAH outputs are known, an error is returned by the rewriter.
   class Subgraph {
    public:
     // Creates a graph to build the subgraph in, if it doesn't already exist,
@@ -246,6 +256,10 @@ class Encapsulator {
         const std::unordered_map<const Node*, Node*>& node_images,
         Graph* graph_out);
 
+    // Returns the names of all the outside_compilation subgraphs in this
+    // Subgraph.
+    void GetOutsideCompilationSubgraphNames(std::vector<string>* names) const;
+
     // Returns the Node that inputs to the function should be wired up to.
     Node* GetCallNodeForInputs() const;
 
@@ -305,15 +319,9 @@ class Encapsulator {
     void RecordOutsideCompilationOutputOrControl(
         const string& outside_compilation_id, const Edge* edge);
 
-    // Adds the SendToHost nodes for each outside_compilation subgraph once the
-    // edges have all been recorded via RecordOutsideCompilationInputOrControl.
-    Status AddSendsToOutsideCompilation(
-        const std::unordered_map<const Node*, Node*>& node_images);
-
-    // Adds the RecvFromHost nodes for each outside_compilation subgraph once
-    // the edges have all been recorded via
-    // RecordOutsideCompilationOutputOrControl.
-    Status AddRecvsFromOutsideCompilation(
+    // Adds the HostCompute nodes for each outside_compilation subgraph.
+    Status AddHostComputes(
+        const string& subgraph_name,
         const std::unordered_map<const Node*, Node*>& node_images);
 
     // Creates the sequencer node if it doesn't exist, adding it to graph_out.
@@ -323,10 +331,16 @@ class Encapsulator {
     // all the downstream nodes of call_node_outputs.
     void ConnectSequencerToOutputs(Graph* graph_out);
 
+    Status AddShapeInferenceInfo(
+        const string& outside_compilation_subgraph_name,
+        const std::vector<TensorShapeProto>& shapes, GraphDef* inference_graph);
+
+    Status ReplaceFunctionDef(FunctionLibraryDefinition* library);
+
    private:
     struct OutsideCompilationSubgraph {
       // Map from source (producer node/slot) tensors in the original graph to
-      // input index (slot number in the SendToHost/RecvAtHost nodes that will
+      // input index (slot number in the HostCompute/RecvAtHost nodes that will
       // be created) for the outside_compilation subgraph.
       std::unordered_map<NodeSlot, int, NodeSlot::Hasher> inputs;
 
@@ -335,14 +349,14 @@ class Encapsulator {
       // outside_compilation subgraph. These are recorded by
       // RecordOutsideCompilationInputOrControl while walking all the subgraph
       // edges, and lifted control edges within the subgraph are added by
-      // AddSendsToOutsideCompilation once the _SendToHost node has been
+      // AddSendsToOutsideCompilation once the _HostCompute node has been
       // created. The matching control edge from _RecvAtHost to the
       // destination is added by CopyEdgeToOutputGraph.
       std::unordered_set<const Node*> control_inputs;
 
       // Maps from source (producer node/slot) and destination (consumer
       // node/slot) tensors in the original graph to output index (slot number
-      // in the SendFromHost/RecvFromHost nodes that will be created) for the
+      // in the SendFromHost/HostCompute nodes that will be created) for the
       // outside_compilation subgraph.
       std::unordered_map<NodeSlot, int, NodeSlot::Hasher> outputs_by_src;
       std::unordered_map<NodeSlot, int, NodeSlot::Hasher> outputs_by_dst;
@@ -352,13 +366,13 @@ class Encapsulator {
       // containing compiled subgraph. These are recorded by
       // RecordOutsideCompilationOutputOrControl while walking all the subgraph
       // edges, and lifted control edges within the subgraph are added by
-      // AddRecvsFromToOutsideCompilation once the _RecvFromHost node has been
+      // AddRecvsFromToOutsideCompilation once the _HostCompute node has been
       // created. The matching control edge from the source to _SendFromHost to
       // the destination is added by CopyEdgeToOutputGraph.
       std::unordered_set<const Node*> control_outputs;
 
-      // _SendToHost node in the subgraph. Not owned.
-      Node* send_to_host = nullptr;
+      // Name of the _HostCompute node in the subgraph.
+      string host_compute_name;
 
       // _RecvAtHost node in the output graph. Not owned.
       Node* recv_at_host = nullptr;
@@ -516,6 +530,59 @@ class Encapsulator {
       const std::unordered_map<const Node*, Node*>& node_images,
       bool parallel_checking, Graph* graph_out);
 
+  // Constructs a minimal shape inference graph that can be used to determine
+  // the shape of send_node at the time that the subgraph is compiled.
+  // recv_at_host_nodes contains the names of all the recv_at_host nodes that
+  // send_node might depend on. These recv_at_host nodes have shapes that are
+  // not known during the rewrite pass, but will be known at compile time.
+  //
+  // If the shapes of all the inputs to send_node can be determined during the
+  // rewrite pass, on exit graphdef_out is empty and the shapes are returned in
+  // static_shape_out. Otherwise graphdef_out contains a graph that can be used
+  // for shape inference at compile time, where all the source nodes of the
+  // graph are either constants with known shapes, or nodes named in
+  // recv_at_host_nodes.
+  //
+  // A non-OK status is returned if neither of the above conditions can be
+  // satisfied, e.g., because send_node depends on a node that doesn't have a
+  // registered shape inference function.
+  Status DoStaticShapeInferenceForOutsideCompilationSend(
+      const Graph& graph_in, const ShapeRefiner& shape_refiner,
+      const std::unordered_set<string>& recv_at_host_nodes, Node* send_node,
+      FunctionLibraryDefinition* library,
+      std::vector<TensorShapeProto>* static_shape_out,
+      std::unique_ptr<GraphDef>* graphdef_out);
+
+  // Makes a copy of graph containing only nodes that are ancestors of at least
+  // one node in send_from_host_nodes and store it in pruned_graph. On exit
+  // nodes_images contains a mapping from nodes in graph to nodes in
+  // pruned_graph. All functions in the copied graph are inlined.
+  Status MakePrunedGraphCopyAndInline(
+      const Graph& graph, const std::vector<Node*>& sink_nodes,
+      std::unique_ptr<Graph>* pruned_graph,
+      std::unordered_map<const Node*, Node*>* node_images,
+      FunctionLibraryDefinition* library);
+
+  // Makes a copy of graph containing only nodes that are ancestors of a
+  // send_from_host node in an outside_compilation subgraph, and store it in
+  // pruned_graph. Also perform shape inference on the pruned graph, using
+  // shape_refiner. On exit node_images contains a mapping from nodes in graph
+  // to nodes in pruned_graph.
+  Status MakeGraphForOutsideCompilationSends(
+      const Graph& graph, std::unique_ptr<Graph>* pruned_graph,
+      ShapeRefiner* shape_refiner,
+      std::unordered_map<const Node*, Node*>* node_images,
+      FunctionLibraryDefinition* library);
+
+  // Performs static shape inference, as far as possible, for the send_from_host
+  // nodes in each outside_compilation subgraph. Where it is not possible to
+  // determine the shape statically, stores a serialized GraphDef in the
+  // HostCompute 'shape_inference_graph' attr, to be used at compile time for
+  // final inference. If the shapes are known statically they are stored in the
+  // HostCompute 'shapes' attr.
+  Status GetShapeInfoForOutsideCompilationSends(
+      Graph* graph_out, FunctionLibraryDefinition* library);
+
   const string group_attribute_;
   const string outside_compilation_attribute_;
   const Graph* graph_in_;
@@ -682,16 +749,20 @@ void Encapsulator::Subgraph::RecordOutsideCompilationOutputOrControl(
   }
 }
 
-Status Encapsulator::Subgraph::AddSendsToOutsideCompilation(
+Status Encapsulator::Subgraph::AddHostComputes(
+    const string& subgraph_name,
     const std::unordered_map<const Node*, Node*>& node_images) {
   for (auto& oc_subgraph_iter : outside_compilation_subgraphs_) {
     const string& oc_subgraph_name = oc_subgraph_iter.first;
     OutsideCompilationSubgraph& oc_subgraph = oc_subgraph_iter.second;
-    if (!oc_subgraph.inputs.empty() || !oc_subgraph.control_inputs.empty()) {
-      // Build a _SendToHost node sending all the args of the appropriate
-      // types.
-      std::vector<DataType> dtypes(oc_subgraph.inputs.size(), DT_INVALID);
+    if (!oc_subgraph.inputs.empty() || !oc_subgraph.control_inputs.empty() ||
+        !oc_subgraph.outputs_by_src.empty() ||
+        !oc_subgraph.control_outputs.empty()) {
+      // Build a _HostCompute node.
       std::vector<NodeDefBuilder::NodeOut> inputs(oc_subgraph.inputs.size());
+      std::vector<DataType> input_dtypes(oc_subgraph.inputs.size(), DT_INVALID);
+      std::vector<DataType> output_dtypes(oc_subgraph.outputs_by_src.size(),
+                                          DT_INVALID);
 
       for (const auto& input_src : oc_subgraph.inputs) {
         const Node* src_node = input_src.first.node;
@@ -700,94 +771,64 @@ Status Encapsulator::Subgraph::AddSendsToOutsideCompilation(
         int input_index = input_src.second;
 
         DataType dtype = src_node->output_type(src_slot);
-        dtypes[input_index] = dtype;
         inputs[input_index].Reset(src_image->name(), src_slot, dtype);
+        input_dtypes[input_index] = dtype;
       }
 
-      NodeDef send_def;
-      NodeDefBuilder builder(
-          strings::StrCat("outside_compilation_", oc_subgraph_name, "_send"),
-          kSendToHostOp);
-      builder.Attr("dtypes", dtypes);
+      for (const auto& output : oc_subgraph.outputs_by_src) {
+        DataType dtype = output.first.dtype;
+        int output_index = output.second;
+        output_dtypes[output_index] = dtype;
+      }
+
+      NodeDef host_compute_def;
+      NodeDefBuilder builder(strings::StrCat("outside_compilation_",
+                                             oc_subgraph_name, "_host_compute"),
+                             kHostComputeOp);
       builder.Input(inputs);
-      Status s = builder.Finalize(&send_def);
+      builder.Attr("Tinputs", input_dtypes);
+      builder.Attr("Toutputs", output_dtypes);
+      builder.Attr("key",
+                   strings::StrCat("host_compute_channel_", subgraph_name, "_",
+                                   oc_subgraph_name));
+      Status s = builder.Finalize(&host_compute_def);
       if (!s.ok()) return s;
 
-      oc_subgraph.send_to_host = graph_->AddNode(send_def, &s);
+      Node* host_compute = graph_->AddNode(host_compute_def, &s);
       if (!s.ok()) return s;
+      oc_subgraph.host_compute_name = host_compute->name();
 
-      // Connect the _SendToHost node to its producers in the subgraph.
+      // Connect the _HostCompute node to its producers in the subgraph.
       for (auto& input_src : oc_subgraph.inputs) {
         const Node* src_node = input_src.first.node;
         Node* src_image = node_images.at(src_node);
         int src_slot = input_src.first.slot;
         int input_index = input_src.second;
-        graph_->AddEdge(src_image, src_slot, oc_subgraph.send_to_host,
-                        input_index);
+        graph_->AddEdge(src_image, src_slot, host_compute, input_index);
       }
 
-      // Connect the _SendToHost node to its control edge producers in the
+      // Connect the _HostCompute node to its control edge producers in the
       // subgraph.
       for (const auto& src_node : oc_subgraph.control_inputs) {
         Node* src_image = node_images.at(src_node);
-        graph_->AddControlEdge(src_image, oc_subgraph.send_to_host);
-      }
-    }
-  }
-
-  return Status::OK();
-}
-
-Status Encapsulator::Subgraph::AddRecvsFromOutsideCompilation(
-    const std::unordered_map<const Node*, Node*>& node_images) {
-  for (auto& oc_subgraph_iter : outside_compilation_subgraphs_) {
-    const string& oc_subgraph_name = oc_subgraph_iter.first;
-    OutsideCompilationSubgraph& oc_subgraph = oc_subgraph_iter.second;
-    if (!oc_subgraph.outputs_by_src.empty() ||
-        !oc_subgraph.control_outputs.empty()) {
-      // Build a _RecvFromHost node producing all the outputs of the appropriate
-      // types.
-      std::vector<DataType> dtypes(oc_subgraph.outputs_by_src.size(),
-                                   DT_INVALID);
-
-      for (const auto& output : oc_subgraph.outputs_by_src) {
-        DataType dtype = output.first.dtype;
-        int output_index = output.second;
-        dtypes[output_index] = dtype;
+        graph_->AddControlEdge(src_image, host_compute);
       }
 
-      NodeDef recv_def;
-      NodeDefBuilder builder(
-          strings::StrCat("outside_compilation_", oc_subgraph_name, "_recv"),
-          kRecvFromHostOp);
-      builder.Attr("dtypes", dtypes);
-      Status s = builder.Finalize(&recv_def);
-      if (!s.ok()) return s;
-
-      Node* recv = graph_->AddNode(recv_def, &s);
-      if (!s.ok()) return s;
-
-      // Connect the consumers in the subgraph to the _RecvFromHost node.
+      // Connect the consumers in the subgraph to the _HostCompute node.
       for (const auto& output : oc_subgraph.outputs_by_dst) {
         const Node* dst_node = output.first.node;
         Node* dst_image = node_images.at(dst_node);
         int dst_slot = output.first.slot;
         int output_index = output.second;
 
-        graph_->AddEdge(recv, output_index, dst_image, dst_slot);
+        graph_->AddEdge(host_compute, output_index, dst_image, dst_slot);
       }
 
-      // Connect the control edge consumers in the subgraph to the _RecvFromHost
+      // Connect the control edge consumers in the subgraph to the _HostCompute
       // node.
       for (const auto& dst_node : oc_subgraph.control_outputs) {
         Node* dst_image = node_images.at(dst_node);
-        graph_->AddControlEdge(recv, dst_image);
-      }
-
-      // Add a control edge in the subgraph so that the _SendToHost node, if
-      // any, is compiled before the _RecvFromHost node.
-      if (oc_subgraph.send_to_host != nullptr) {
-        graph_->AddControlEdge(oc_subgraph.send_to_host, recv);
+        graph_->AddControlEdge(host_compute, dst_image);
       }
     }
   }
@@ -882,6 +923,63 @@ Status Encapsulator::Subgraph::BuildFunctionDef(
   return Status::OK();
 }
 
+Status Encapsulator::Subgraph::AddShapeInferenceInfo(
+    const string& outside_compilation_subgraph_name,
+    const std::vector<TensorShapeProto>& shapes, GraphDef* inference_graph) {
+  OutsideCompilationSubgraph& oc_subgraph =
+      outside_compilation_subgraphs_.at(outside_compilation_subgraph_name);
+
+  Node* host_compute = nullptr;
+  for (Node* n : graph_->nodes()) {
+    if (n->name() == oc_subgraph.host_compute_name) {
+      host_compute = n;
+      break;
+    }
+  }
+  if (host_compute == nullptr) {
+    return errors::InvalidArgument(
+        "After rewriting subgraph ", outside_compilation_subgraph_name,
+        " there is no HostCompute Op for outside compilation subgraph ",
+        oc_subgraph.host_compute_name);
+  }
+
+  if (inference_graph == nullptr) {
+    host_compute->AddAttr("shape_inference_graph", "");
+    host_compute->AddAttr("shapes", shapes);
+  } else {
+    string serialized_graph;
+    if (!inference_graph->SerializeToString(&serialized_graph)) {
+      return errors::Internal(
+          "Failed to serialize graph for outside compilation subgraph ",
+          oc_subgraph.host_compute_name);
+    }
+    host_compute->AddAttr("shape_inference_graph", serialized_graph);
+    host_compute->AddAttr("shapes", std::vector<TensorShapeProto>());
+  }
+  return Status::OK();
+}
+
+Status Encapsulator::Subgraph::ReplaceFunctionDef(
+    FunctionLibraryDefinition* library) {
+  const string& name = call_node_def_.name();
+
+  FunctionDef fdef;
+  TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, &fdef));
+
+  if (VLOG_IS_ON(1)) {
+    VLOG(2) << "Replace function def " << name;
+    dump_graph::DumpGraphToFile(
+        strings::StrCat("replace_encapsulate_fdef_graph_", name), *graph_,
+        library);
+    dump_graph::DumpFunctionDefToFile(
+        strings::StrCat("replace_encapsulate_fdef_", name), fdef);
+  }
+
+  TF_RETURN_IF_ERROR(library->RemoveFunction(name));
+  TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef));
+  return Status::OK();
+}
+
 Status Encapsulator::Subgraph::BuildParallelCheckOp(
     const std::unordered_map<const Node*, Node*>& node_images,
     Graph* graph_out) {
@@ -980,7 +1078,9 @@ Status Encapsulator::Subgraph::AddRecvAtHostNode(
   NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name,
                                          "_", oc_subgraph_name, "_recv"),
                          kRecvAtHostOp);
-  builder.Attr("dtypes", dtypes);
+  builder.Attr("Toutputs", dtypes);
+  builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name,
+                                      "_", oc_subgraph_name));
   Status s = builder.Finalize(&recv_def);
   if (!s.ok()) return s;
 
@@ -1020,7 +1120,9 @@ Status Encapsulator::Subgraph::AddSendFromHostNode(
   NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name,
                                          "_", oc_subgraph_name, "_send"),
                          kSendFromHostOp);
-  builder.Attr("dtypes", dtypes);
+  builder.Attr("Tinputs", dtypes);
+  builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name,
+                                      "_", oc_subgraph_name));
   builder.Input(inputs);
   Status s = builder.Finalize(&send_def);
   if (!s.ok()) return s;
@@ -1062,6 +1164,13 @@ Status Encapsulator::Subgraph::AddOutsideCompilationHostIONodes(
   return Status::OK();
 }
 
+void Encapsulator::Subgraph::GetOutsideCompilationSubgraphNames(
+    std::vector<string>* names) const {
+  for (auto& entry : outside_compilation_subgraphs_) {
+    names->push_back(entry.first);
+  }
+}
+
 Status Encapsulator::GetFunctionNameAttr(
     Node const* node, string* attr, string* outside_compilation_attr) const {
   Status s = GetNodeAttr(node->attrs(), group_attribute_, attr);
@@ -1220,8 +1329,7 @@ Status Encapsulator::SplitIntoSubgraphs() {
   // single input and output node for it.
   for (auto& entry : subgraphs_) {
     Subgraph& subgraph = entry.second;
-    TF_RETURN_IF_ERROR(subgraph.AddSendsToOutsideCompilation(node_images));
-    TF_RETURN_IF_ERROR(subgraph.AddRecvsFromOutsideCompilation(node_images));
+    TF_RETURN_IF_ERROR(subgraph.AddHostComputes(entry.first, node_images));
   }
 
   MarkGuaranteedConstants(*graph_in_, src_arg_pairs);
@@ -1509,8 +1617,346 @@ Status Encapsulator::AddEdgesToOutputGraph(
   return Status::OK();
 }
 
-Status Encapsulator::BuildOutputGraph(bool parallel_checking,
-                                      Graph* graph_out) {
+namespace {
+
+// Adds a dummy Const node to graph_out. The "constant" has the type of
+// data_type and the shape indicated in 'shape'. The dummy node is not a valid
+// Const node because it does not have any value defined, but this doesn't
+// matter because it will only be used subsequently for shape inference. (It
+// would be possible to add a switch statement over data_type to create a value
+// for the constant, but that would entail maintaining the logic as new types
+// are added, and is not necessary.)
+Node* AddDummyShapedNode(DataType data_type, const TensorShapeProto& shape,
+                         Graph* graph_out) {
+  TensorProto dummy_proto;
+  dummy_proto.set_dtype(data_type);
+  *dummy_proto.mutable_tensor_shape() = shape;
+  // Don't set any value field in the proto, since it is only going to be used
+  // for shape inference.
+
+  GraphDefBuilder::Options options(graph_out, /*status=*/nullptr);
+  NodeBuilder node_builder(options.GetNameForOp("KnownShape"), "Const",
+                           options.op_registry());
+  node_builder.Attr("dtype", data_type).Attr("value", dummy_proto);
+  return options.FinalizeBuilder(&node_builder);
+}
+
+// Adds a copy of node_in to graph_out and adds the mapping to
+// copied_node_images.
+Status CopyShapeInferenceNodeToGraph(
+    Node* node_in, const Node* send_node,
+    const std::unordered_map<Node*, Node*>& dummy_node_images,
+    FunctionLibraryDefinition* library,
+    std::unordered_map<Node*, Node*>* copied_node_images, Graph* graph_out) {
+  // Once all the ancestor nodes have been added to graph_out, add this node
+  // and connect it to its ancestors.
+  Node* node_out = graph_out->CopyNode(node_in);
+  (*copied_node_images)[node_in] = node_out;
+  // Don't bother to build the shape inference graph if there's a node with no
+  // shape inference function, since it would just result in an error later at
+  // compile time.
+  const OpRegistrationData* op_reg_data;
+  TF_RETURN_IF_ERROR(library->LookUp(node_in->type_string(), &op_reg_data));
+  if (op_reg_data->shape_inference_fn == nullptr) {
+    return errors::InvalidArgument(
+        "Shape inference is not possible for outside_compilation "
+        "SendFromHost node ",
+        send_node->name(), " because it depends on node ", node_in->name(),
+        " which does not have a shape inference function registered.");
+  }
+  // Add all the edges to the newly copied node.
+  for (const Edge* in_edge : node_in->in_edges()) {
+    if (!in_edge->IsControlEdge()) {
+      Node* src = in_edge->src();
+      const auto iter = dummy_node_images.find(src);
+      if (iter == dummy_node_images.end()) {
+        // The src is a copied node so use the original output port.
+        graph_out->AddEdge((*copied_node_images)[in_edge->src()],
+                           in_edge->src_output(), node_out,
+                           in_edge->dst_input());
+      } else {
+        // The src is a dummy node so use output port 0.
+        graph_out->AddEdge(iter->second, 0, node_out, in_edge->dst_input());
+      }
+    }
+  }
+  return Status::OK();
+}
+
+}  // namespace
+
+Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend(
+    const Graph& graph_in, const ShapeRefiner& shape_refiner,
+    const std::unordered_set<string>& recv_at_host_nodes, Node* send_node,
+    FunctionLibraryDefinition* library,
+    std::vector<TensorShapeProto>* static_shape_out,
+    std::unique_ptr<GraphDef>* graphdef_out) {
+  // Maps from nodes in graph_in to nodes in graph_out.
+  //
+  // When an edge has fully defined shape the source node in graph_in is
+  // replaced in graph_out by a dummy constant node. The mapping from nodes
+  // in graph_in to dummy nodes is stored in dummy_node_images.
+  //
+  // When a node in graph_in has at least one ancestor that doesn't have fully
+  // defined shape, it is copied into graph_out. The mapping from nodes in
+  // graph_in to copied nodes is stored in copied_node_images.
+  //
+  // The two types of node are treated differently because, when adding edges to
+  // graph_out, an output from a dummy node always uses port 0, whereas an
+  // output from a copied node uses the same port that was used in graph_in.
+  std::unordered_map<Node*, Node*> dummy_node_images;
+  std::unordered_map<Node*, Node*> copied_node_images;
+
+  std::unique_ptr<Graph> graph_out(new Graph(graph_in.op_registry()));
+  graph_out->set_versions(graph_in.versions());
+  static_shape_out->resize(send_node->num_inputs());
+
+  // We don't use the standard ReverseDFS because we want to cut off traversal
+  // whenever we find an output with fully defined shape.
+  // TODO(misard) make this work properly in the presence of control flow.
+  struct Work {
+    Node* node;
+    bool leave;  // Are we entering or leaving node?
+  };
+  std::vector<Work> stack({{send_node, false}});
+  std::vector<bool> visited(graph_in.num_node_ids(), false);
+  while (!stack.empty()) {
+    Work w = stack.back();
+    stack.pop_back();
+    Node* n = w.node;
+
+    if (w.leave) {
+      TF_RETURN_IF_ERROR(CopyShapeInferenceNodeToGraph(
+          n, send_node, dummy_node_images, library, &copied_node_images,
+          graph_out.get()));
+    } else {
+      if (visited[n->id()]) continue;
+      visited[n->id()] = true;
+
+      // Arrange to revisit when all done with all inputs.
+      stack.push_back(Work{n, true});
+
+      bool has_parent_with_unknown_shape = false;
+      for (const Edge* in_edge : n->in_edges()) {
+        if (!in_edge->IsControlEdge()) {
+          Node* src_node = in_edge->src();
+          int src_port = in_edge->src_output();
+          shape_inference::InferenceContext* context =
+              shape_refiner.GetContext(src_node);
+          shape_inference::ShapeHandle shape = context->output(src_port);
+          if (context->FullyDefined(shape)) {
+            // This ancestor has known shape, so instead of adding it to the
+            // stack, add a dummy node with that shape to graph_out and
+            // continue.
+            TensorShapeProto proto;
+            context->ShapeHandleToProto(shape, &proto);
+            dummy_node_images[src_node] = AddDummyShapedNode(
+                src_node->output_type(src_port), proto, graph_out.get());
+            if (n == send_node) {
+              (*static_shape_out)[in_edge->dst_input()] = proto;
+            }
+          } else {
+            if (!visited[src_node->id()]) {
+              has_parent_with_unknown_shape = true;
+              stack.push_back({src_node, false});
+            }
+          }
+        }
+      }
+      if (!has_parent_with_unknown_shape) {
+        if (n == send_node) {
+          // The shapes of all the inputs to send_node are statically known. We
+          // won't have to do any inference at compile time so return now: the
+          // shapes were stored in static_shape_out above.
+          graphdef_out->reset();
+          return Status::OK();
+        } else {
+          // Any shape that is being processed is either the original send node
+          // or has at least one output with statically-unknown shape. If the
+          // latter and it doesn't have any inputs with statically-unknown
+          // shape, then check that it is of the recv nodes that we can fill in
+          // the shape of at run-time later. If it isn't one of those, then we
+          // won't have any additional knowledge at compile time, so we already
+          // know we won't be able to do shape inference and we can return an
+          // error now.
+          if (recv_at_host_nodes.find(n->name()) == recv_at_host_nodes.end()) {
+            return errors::InvalidArgument(
+                "Shape inference is not possible for outside_compilation "
+                "SendFromHost node ",
+                send_node->name(), " because shape of node ", n->name(),
+                " will not be known at compilation time.");
+          }
+        }
+      }
+    }
+  }
+
+  graphdef_out->reset(new GraphDef());
+  graph_out->ToGraphDef(graphdef_out->get());
+
+  return Status::OK();
+}
+
+Status Encapsulator::MakePrunedGraphCopyAndInline(
+    const Graph& graph, const std::vector<Node*>& sink_nodes,
+    std::unique_ptr<Graph>* pruned_graph,
+    std::unordered_map<const Node*, Node*>* node_images,
+    FunctionLibraryDefinition* library) {
+  // First copy all ancestor nodes of sink_nodes into a new graph.
+  pruned_graph->reset(new Graph(library));
+  (*pruned_graph)->set_versions(graph.versions());
+  ReverseDFSFrom(graph, sink_nodes,
+                 /*enter=*/nullptr,
+                 /*leave=*/[&](Node* n) {
+                   if (!n->IsSource()) {
+                     Node* copied = (*pruned_graph)->CopyNode(n);
+                     node_images->emplace(n, copied);
+                   }
+                 });
+
+  // Add all the edges between copied nodes.
+  for (auto entry : *node_images) {
+    const Node* orig = entry.first;
+    Node* image = entry.second;
+    for (const Edge* out_edge : orig->out_edges()) {
+      auto iter = node_images->find(out_edge->dst());
+      if (iter != node_images->end()) {
+        // The source and destination are both in the copied graph.
+        (*pruned_graph)
+            ->AddEdge(image, out_edge->src_output(), iter->second,
+                      out_edge->dst_input());
+      }
+    }
+  }
+
+  // Find all the function call nodes, and inline them.
+  std::vector<Node*> function_nodes;
+  for (auto node : (*pruned_graph)->nodes()) {
+    const OpRegistrationData* op_reg_data;
+    TF_RETURN_IF_ERROR(library->LookUp(node->type_string(), &op_reg_data));
+    if (op_reg_data->is_function_op) {
+      function_nodes.push_back(node);
+    }
+  }
+  for (auto node : function_nodes) {
+    VLOG(2) << "Inlining function " << node->name();
+    const FunctionDef* fdef = library->Find(node->type_string());
+    if (fdef == nullptr) {
+      return errors::Internal("Failed to find function ", node->type_string(),
+                              " in function library.");
+    }
+    FunctionBody* fbody = nullptr;
+    TF_RETURN_IF_ERROR(
+        FunctionDefToBodyHelper(*fdef, node->attrs(), library,
+                                [library](const string& op, const OpDef** sig) {
+                                  return library->LookUpOpDef(op, sig);
+                                },
+                                &fbody));
+    InlineFunctionBody(*library, pruned_graph->get(), node, fbody);
+    delete fbody;
+  }
+
+  return Status::OK();
+}
+
+Status Encapsulator::MakeGraphForOutsideCompilationSends(
+    const Graph& graph, std::unique_ptr<Graph>* pruned_graph,
+    ShapeRefiner* shape_refiner,
+    std::unordered_map<const Node*, Node*>* node_images,
+    FunctionLibraryDefinition* library) {
+  // Find all the send_from_host nodes in all subgraphs, to use as roots for the
+  // pruning.
+  std::vector<Node*> send_from_host_nodes;
+  for (auto& subgraph_entry : subgraphs_) {
+    Subgraph& subgraph = subgraph_entry.second;
+    std::vector<string> outside_compilation_names;
+    subgraph.GetOutsideCompilationSubgraphNames(&outside_compilation_names);
+    for (const auto& name : outside_compilation_names) {
+      Node* send_node = subgraph.GetSendFromHostNode(name);
+      if (send_node != nullptr) {
+        send_from_host_nodes.push_back(send_node);
+      }
+    }
+  }
+
+  // Make a copy of all the graph nodes needed to evaluate the send_from_host
+  // nodes, inlining any functions as needed.
+  TF_RETURN_IF_ERROR(MakePrunedGraphCopyAndInline(
+      graph, send_from_host_nodes, pruned_graph, node_images, library));
+
+  // Perform shape inference on the pruned graph.
+  shape_refiner->set_require_shape_inference_fns(false);
+  FixupSourceAndSinkEdges(pruned_graph->get());
+  std::vector<Node*> post_order;
+  GetReversePostOrder(*(*pruned_graph), &post_order);
+  for (auto node : post_order) {
+    // Ignore the status returned by the shape_refiner. At this point we want
+    // the best effort shapes, even if no shape function is registered for a
+    // node.
+    Status status = shape_refiner->AddNode(node);
+    if (!status.ok()) {
+      VLOG(1) << "Shape inference failed for node: " << status;
+    }
+  }
+
+  return Status::OK();
+}
+
+Status Encapsulator::GetShapeInfoForOutsideCompilationSends(
+    Graph* graph_out, FunctionLibraryDefinition* library) {
+  std::unique_ptr<Graph> pruned_graph;
+  ShapeRefiner shape_refiner(graph_out->versions(), graph_out->op_registry());
+  std::unordered_map<const Node*, Node*> node_images;
+  TF_RETURN_IF_ERROR(MakeGraphForOutsideCompilationSends(
+      *graph_out, &pruned_graph, &shape_refiner, &node_images, library));
+
+  for (auto& subgraph_entry : subgraphs_) {
+    Subgraph& subgraph = subgraph_entry.second;
+    // Find all the recv_at_host nodes in this subgraph.
+    std::vector<string> outside_compilation_names;
+    subgraph.GetOutsideCompilationSubgraphNames(&outside_compilation_names);
+    std::unordered_set<string> recv_at_host_names;
+    for (const auto& name : outside_compilation_names) {
+      Node* recv_node = subgraph.GetRecvAtHostNode(name);
+      if (recv_node != nullptr) {
+        recv_at_host_names.insert(recv_node->name());
+      }
+    }
+    // For each send_from_host node, do as much shape inference as possible
+    // without knowing the shape of the recv_at_host nodes, and store the
+    // result, along with enough information to complete the job at compile time
+    // once the recv_at_host shapes are known.
+    for (const auto& name : outside_compilation_names) {
+      Node* send_node = subgraph.GetSendFromHostNode(name);
+      std::vector<TensorShapeProto> static_shape;
+      std::unique_ptr<GraphDef> graphdef;
+      if (send_node != nullptr) {
+        TF_RETURN_IF_ERROR(DoStaticShapeInferenceForOutsideCompilationSend(
+            *pruned_graph, shape_refiner, recv_at_host_names,
+            node_images[send_node], library, &static_shape, &graphdef));
+        if (graphdef == nullptr) {
+          VLOG(2) << "Send node  " << send_node->name() << " shapes";
+          for (int i = 0; i < static_shape.size(); ++i) {
+            VLOG(2) << static_shape[i].DebugString();
+          }
+        } else {
+          VLOG(2) << "Send node " << send_node->name() << " graph\n"
+                  << graphdef->DebugString();
+        }
+      }
+      TF_RETURN_IF_ERROR(
+          subgraph.AddShapeInferenceInfo(name, static_shape, graphdef.get()));
+    }
+    if (!outside_compilation_names.empty()) {
+      TF_RETURN_IF_ERROR(subgraph.ReplaceFunctionDef(library));
+    }
+  }
+
+  return Status::OK();
+}
+
+Status Encapsulator::BuildOutputGraph(bool parallel_checking, Graph* graph_out,
+                                      FunctionLibraryDefinition* library) {
   // Map from nodes in the input graph to nodes in the output graph.
   std::unordered_map<const Node*, Node*> node_images;
 
@@ -1522,6 +1968,9 @@ Status Encapsulator::BuildOutputGraph(bool parallel_checking,
   TF_RETURN_IF_ERROR(
       AddEdgesToOutputGraph(node_images, parallel_checking, graph_out));
 
+  TF_RETURN_IF_ERROR(
+      GetShapeInfoForOutsideCompilationSends(graph_out, library));
+
   return Status::OK();
 }
 
@@ -1545,7 +1994,7 @@ Status EncapsulateSubgraphsInFunctions(
   std::unique_ptr<Graph> out(new Graph(library));
   out->set_versions(graph_in.versions());
   TF_RETURN_IF_ERROR(
-      encapsulator.BuildOutputGraph(parallel_checking, out.get()));
+      encapsulator.BuildOutputGraph(parallel_checking, out.get(), library));
 
   *graph_out = std::move(out);
   return Status::OK();
index b100861..aed9cae 100644 (file)
@@ -29,17 +29,181 @@ limitations under the License.
 namespace tensorflow {
 namespace {
 
+template <class Tkey, class Tvalue>
+bool EqualProtoMap(const ::tensorflow::protobuf::Map<Tkey, Tvalue>& a,
+                   const ::tensorflow::protobuf::Map<Tkey, Tvalue>& b,
+                   const std::function<string(const Tkey&)>& key_to_string,
+                   const std::function<string(const Tvalue&)>& value_to_string,
+                   const std::function<bool(const Tkey&, const Tvalue&,
+                                            const Tvalue&)>& compare,
+                   const string& map_name, string* diff) {
+  for (const auto& elt_a : a) {
+    const auto iter = b.find(elt_a.first);
+    if (iter == b.end()) {
+      if (diff) {
+        *diff = strings::StrCat(
+            map_name, " expected: contains element with key '",
+            key_to_string(elt_a.first), "' got: map has no such element");
+      }
+      return false;
+    }
+    if (!compare(elt_a.first, elt_a.second, iter->second)) {
+      if (diff) {
+        *diff = strings::StrCat(map_name, " expected: element with key '",
+                                key_to_string(elt_a.first), " has value '",
+                                value_to_string(elt_a.second), "' got: '",
+                                value_to_string(iter->second), "'");
+      }
+      return false;
+    }
+  }
+  for (const auto& elt_b : b) {
+    const auto iter = a.find(elt_b.first);
+    if (iter == a.end()) {
+      if (diff) {
+        *diff = strings::StrCat(map_name, " got: contains element with key '",
+                                key_to_string(elt_b.first),
+                                "' expected: map has no such element");
+      }
+      return false;
+    }
+  }
+  return true;
+}
+
+bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b,
+                          const string& diff_preamble, string* diff) {
+  if (a.op() != b.op()) {
+    if (diff) {
+      *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(),
+                              ", expected op '", a.op(), "' got '", b.op());
+    }
+    return false;
+  }
+  if (a.device() != b.device()) {
+    if (diff) {
+      *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(),
+                              ", expected device '", a.device(), "' got '",
+                              b.device());
+    }
+    return false;
+  }
+  if (a.input_size() != b.input_size()) {
+    if (diff) {
+      *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(),
+                              ", expected ", a.input_size(), " inputs got ",
+                              b.input_size(), " expected:\n", a.DebugString(),
+                              "\ngot:\n", b.DebugString());
+    }
+    return false;
+  }
+  for (int i = 0; i < a.input_size(); ++i) {
+    if (a.input(i) != b.input(i)) {
+      if (diff) {
+        *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(),
+                                " input ", i, ", expected ", a.input(i),
+                                " got ", b.input(i), " expected:\n",
+                                a.DebugString(), "\ngot:\n", b.DebugString());
+      }
+      return false;
+    }
+  }
+  return EqualProtoMap<string, AttrValue>(
+      a.attr(), b.attr(), [](const string& s) { return s; },
+      [](const AttrValue& v) { return v.DebugString(); },
+      [](const string& key, const AttrValue& av, const AttrValue& bv) {
+        if (key == "shape_inference_graph") {
+          // Default serialization of GraphDef is unstable because maps don't
+          // serialize deterministically. Rather than go through the hoops to
+          // turn on deterministic serialization of this attr just for this
+          // test, add logic here to compare determinstically.
+          GraphDef ga;
+          if (!ga.ParseFromString(av.s())) {
+            return false;
+          }
+          GraphDef gb;
+          if (!gb.ParseFromString(bv.s())) {
+            return false;
+          }
+          return EqualGraphDef(ga, gb, nullptr);
+        } else {
+          return av.DebugString() == bv.DebugString();
+        }
+      },
+      strings::StrCat(diff_preamble, " attr mismatch for node ", a.name()),
+      diff);
+}
+
 bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b,
                       string* diff) {
-  // TODO(phawkins) use a more sophisticated equality test.
-  if (a.DebugString() != b.DebugString()) {
+  if (a.signature().DebugString() != b.signature().DebugString()) {
     if (diff) {
-      *diff = strings::StrCat("Definition mismatch for function ",
+      *diff = strings::StrCat("Signature mismatch for function ",
                               a.signature().name(), ", expected:\n",
-                              a.DebugString(), "\ngot:\n", b.DebugString());
+                              a.signature().DebugString(), "\ngot:\n",
+                              b.signature().DebugString());
     }
     return false;
   }
+  if (!EqualProtoMap<string, AttrValue>(
+          a.attr(), b.attr(), [](const string& s) { return s; },
+          [](const AttrValue& v) { return v.DebugString(); },
+          [](const string& key, const AttrValue& av, const AttrValue& bv) {
+            return av.DebugString() == bv.DebugString();
+          },
+          strings::StrCat("attr mismatch for function ", a.signature().name()),
+          diff)) {
+    return false;
+  }
+  if (!EqualProtoMap<string, string>(
+          a.ret(), b.ret(), [](const string& s) { return s; },
+          [](const string& s) { return s; },
+          [](const string& key, const string& av, const string& bv) {
+            return av == bv;
+          },
+          strings::StrCat("ret mismatch for function ", a.signature().name()),
+          diff)) {
+    return false;
+  }
+  for (int i = 0; i < a.node_def_size(); ++i) {
+    bool found = false;
+    for (int j = 0; j < b.node_def_size(); ++j) {
+      if (a.node_def(i).name() == b.node_def(j).name()) {
+        if (!EqualFunctionNodeDef(
+                a.node_def(i), b.node_def(j),
+                strings::StrCat("Function ", a.signature().name()), diff)) {
+          return false;
+        }
+        found = true;
+        break;
+      }
+    }
+    if (!found) {
+      if (diff) {
+        *diff = strings::StrCat("Function ", a.signature().name(),
+                                ", expected: has node '", a.node_def(i).name(),
+                                "' got: no node of that name");
+      }
+      return false;
+    }
+  }
+  for (int i = 0; i < b.node_def_size(); ++i) {
+    bool found = false;
+    for (int j = 0; j < a.node_def_size(); ++j) {
+      if (b.node_def(i).name() == a.node_def(j).name()) {
+        found = true;
+        break;
+      }
+    }
+    if (!found) {
+      if (diff) {
+        *diff = strings::StrCat("Function ", a.signature().name(),
+                                ", got: has node '", b.node_def(i).name(),
+                                "' expected: no node of that name");
+      }
+      return false;
+    }
+  }
   return true;
 }
 
@@ -84,29 +248,64 @@ bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected,
 
 // TODO(misard): remove these fake registrations once there are real Ops to be
 // compiled.
-REGISTER_OP("_XlaSendToHost")
-    .Input("input: dtypes")
-    .Attr("dtypes: list(type) >= 0");
-
-REGISTER_OP("_XlaRecvFromHost")
-    .Output("output: dtypes")
-    .Attr("dtypes: list(type) >= 0");
+REGISTER_OP("_XlaHostCompute")
+    .Input("inputs: Tinputs")
+    .Output("outputs: Toutputs")
+    .Attr("Tinputs: list(type) >= 0")
+    .Attr("Toutputs: list(type) >= 0")
+    .Attr("key: string")
+    .SetShapeFn(::tensorflow::shape_inference::UnknownShape);
 
 REGISTER_OP("_XlaSendFromHost")
-    .Input("input: dtypes")
-    .Attr("dtypes: list(type) >= 0");
+    .Input("input: Tinputs")
+    .Attr("Tinputs: list(type) >= 0")
+    .Attr("key: string")
+    .SetShapeFn(::tensorflow::shape_inference::UnknownShape);
 
 REGISTER_OP("_XlaRecvAtHost")
-    .Output("output: dtypes")
-    .Attr("dtypes: list(type) >= 0");
-
-REGISTER_OP("InputTest").Output("o: float");
-
-REGISTER_OP("UnaryTest").Input("a: float").Output("o: float");
+    .Output("output: Toutputs")
+    .Attr("Toutputs: list(type) >= 0")
+    .Attr("key: string")
+    .SetShapeFn(::tensorflow::shape_inference::UnknownShape);
+
+REGISTER_OP("InputTest")
+    .Output("o: float")
+    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
+      c->set_output(0, c->UnknownShape());
+      return Status::OK();
+    });
+
+REGISTER_OP("InputTestShaped")
+    .Output("o: float")
+    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
+      c->set_output(0, c->Vector(2));
+      return Status::OK();
+    });
+
+REGISTER_OP("UnaryTest")
+    .Input("a: float")
+    .Output("o: float")
+    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
+      ::tensorflow::shape_inference::ShapeHandle o;
+      TF_RETURN_IF_ERROR(c->Merge(c->UnknownShape(), c->input(0), &o));
+      c->set_output(0, o);
+      return Status::OK();
+    });
 REGISTER_OP("BinaryTest")
     .Input("a: float")
     .Input("b: float")
-    .Output("o: float");
+    .Output("o: float")
+    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
+      ::tensorflow::shape_inference::ShapeHandle o;
+      TF_RETURN_IF_ERROR(c->Merge(c->UnknownShape(), c->input(0), &o));
+      c->set_output(0, o);
+      return Status::OK();
+    });
+REGISTER_OP("BinaryTest2")
+    .Input("a: float")
+    .Input("b: float")
+    .Output("o: float")
+    .SetShapeFn(::tensorflow::shape_inference::UnknownShape);
 
 REGISTER_OP("AddNLikeTest")
     .Input("inputs: N * T")
@@ -124,22 +323,48 @@ Node* Input(const GraphDefBuilder::Options& opts) {
   return ops::SourceOp("InputTest", opts);
 }
 
-Node* RecvAtHost(const gtl::ArraySlice<DataType>& dtypes,
+Node* InputShaped(const GraphDefBuilder::Options& opts) {
+  return ops::SourceOp("InputTestShaped", opts);
+}
+
+Node* KnownShape(const gtl::ArraySlice<int>& shape,
+                 const GraphDefBuilder::Options& opts) {
+  if (opts.HaveError()) return nullptr;
+  NodeBuilder node_builder(opts.GetNameForOp("Const"), "Const",
+                           opts.op_registry());
+  TensorProto value;
+  value.set_dtype(DT_FLOAT);
+  for (int dim : shape) {
+    value.mutable_tensor_shape()->add_dim()->set_size(dim);
+  }
+  return opts.WithAttr("value", value)
+      .WithAttr("dtype", DT_FLOAT)
+      .FinalizeBuilder(&node_builder);
+}
+
+Node* RecvAtHost(const string& key, const gtl::ArraySlice<DataType>& dtypes,
                  const GraphDefBuilder::Options& opts) {
   if (opts.HaveError()) return nullptr;
   NodeBuilder node_builder(opts.GetNameForOp("_XlaRecvAtHost"),
                            "_XlaRecvAtHost", opts.op_registry());
-  return opts.WithAttr("dtypes", dtypes).FinalizeBuilder(&node_builder);
+  return opts.WithAttr("Toutputs", dtypes)
+      .WithAttr("key", key)
+      .FinalizeBuilder(&node_builder);
 }
 
-Node* SendFromHost(const std::vector<ops::NodeOut>& inputs,
-                   const gtl::ArraySlice<DataType>& dtypes,
+Node* SendFromHost(const string& key, const std::vector<ops::NodeOut>& inputs,
                    const GraphDefBuilder::Options& opts) {
   if (opts.HaveError()) return nullptr;
   NodeBuilder node_builder(opts.GetNameForOp("_XlaSendFromHost"),
                            "_XlaSendFromHost", opts.op_registry());
   node_builder.Input(inputs);
-  return opts.WithAttr("dtypes", dtypes).FinalizeBuilder(&node_builder);
+  std::vector<DataType> dtypes;
+  for (const auto& node : inputs) {
+    dtypes.push_back(node.dt);
+  }
+  return opts.WithAttr("key", key)
+      .WithAttr("Tinputs", dtypes)
+      .FinalizeBuilder(&node_builder);
 }
 
 Node* Unary(ops::NodeOut a, const GraphDefBuilder::Options& opts) {
@@ -151,6 +376,11 @@ Node* Binary(ops::NodeOut a, ops::NodeOut b,
   return ops::BinaryOp("BinaryTest", std::move(a), std::move(b), opts);
 }
 
+Node* BinaryUnknownShape(ops::NodeOut a, ops::NodeOut b,
+                         const GraphDefBuilder::Options& opts) {
+  return ops::BinaryOp("BinaryTest2", std::move(a), std::move(b), opts);
+}
+
 Node* AddNLike(const std::vector<ops::NodeOut>& inputs,
                const GraphDefBuilder::Options& opts) {
   if (opts.HaveError()) return nullptr;
@@ -576,6 +806,21 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) {
   FunctionDefLibrary library_expected;
   GraphDef graphdef_expected;
 
+  string shape_string_expected;
+  {
+    GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
+    Node* recv =
+        RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
+                   shape.opts().WithName("outside_compilation_F1_O1_recv"));
+    Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
+                     shape.opts().WithName("E"));
+    SendFromHost("host_compute_channel_F1_O1", {e},
+                 shape.opts().WithName("outside_compilation_F1_O1_send"));
+    GraphDef shape_graph;
+    TF_EXPECT_OK(shape.ToGraphDef(&shape_graph));
+    EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected));
+  }
+
   *library_expected.add_function() = test::function::XTimesTwo();
   *library_expected.add_function() = FunctionDefHelper::Create(
       "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
@@ -584,19 +829,18 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) {
           {{"c"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}, {"C"}},
           {{"F"},
            "BinaryTest",
-           {"C:o:0", "outside_compilation_O1_recv:output:0"},
+           {"C:o:0", "outside_compilation_O1_host_compute:outputs:0"},
            {},
-           {"outside_compilation_O1_recv"}},
-          {{"outside_compilation_O1_send"},
-           "_XlaSendToHost",
+           {"outside_compilation_O1_host_compute"}},
+          {{"outside_compilation_O1_host_compute"},
+           "_XlaHostCompute",
            {"C:o:0", "c:o:0"},
-           {{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})}},
+           {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
+            {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
+            {"key", "host_compute_channel_F1_O1"},
+            {"shape_inference_graph", shape_string_expected},
+            {"shapes", gtl::ArraySlice<DataType>({})}},
            {"c"}},
-          {{"outside_compilation_O1_recv"},
-           "_XlaRecvFromHost",
-           {},
-           {{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
-           {"outside_compilation_O1_send"}},
       },
       {{"f_0_retval", "F:o:0"}});
 
@@ -612,11 +856,11 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) {
     Node* call = b2.opts().FinalizeBuilder(&node_builder);
 
     Node* recv =
-        RecvAtHost({DT_FLOAT, DT_FLOAT},
+        RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
                    b2.opts().WithName("outside_compilation_F1_O1_recv"));
     Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
                      b2.opts().WithName("E").WithControlInputs({recv, b}));
-    Node* send = SendFromHost({e}, {DT_FLOAT},
+    Node* send = SendFromHost("host_compute_channel_F1_O1", {e},
                               b2.opts()
                                   .WithName("outside_compilation_F1_O1_send")
                                   .WithControlInput(e));
@@ -674,37 +918,71 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
   FunctionDefLibrary library_expected;
   GraphDef graphdef_expected;
 
+  string shape_string_expected_1;
+  {
+    GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
+    Node* recv =
+        RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
+                   shape1.opts().WithName("outside_compilation_F1_O1_recv"));
+    Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
+                     shape1.opts().WithName("E"));
+    SendFromHost("host_compute_channel_F1_O1", {e},
+                 shape1.opts().WithName("outside_compilation_F1_O1_send"));
+    GraphDef shape1_graph;
+    TF_EXPECT_OK(shape1.ToGraphDef(&shape1_graph));
+    EXPECT_TRUE(shape1_graph.SerializeToString(&shape_string_expected_1));
+  }
+
+  string shape_string_expected_2;
+  {
+    GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately);
+    Node* recv1 =
+        RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
+                   shape2.opts().WithName("outside_compilation_F1_O1_recv"));
+    Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
+                     shape2.opts().WithName("E"));
+    Node* recv2 =
+        RecvAtHost("host_compute_channel_F1_O2", {DT_FLOAT, DT_FLOAT},
+                   shape2.opts().WithName("outside_compilation_F1_O2_recv"));
+    Node* h = Binary(ops::NodeOut(recv2, 0), e, shape2.opts().WithName("H"));
+    SendFromHost("host_compute_channel_F1_O2", {h},
+                 shape2.opts().WithName("outside_compilation_F1_O2_send"));
+    GraphDef shape2_graph;
+    TF_EXPECT_OK(shape2.ToGraphDef(&shape2_graph));
+    EXPECT_TRUE(shape2_graph.SerializeToString(&shape_string_expected_2));
+  }
+
   *library_expected.add_function() = FunctionDefHelper::Create(
       "F1", {"a_0_arg:float", "b_0_arg:float"}, {"i_0_retval:float"}, {},
       {
           {{"C"}, "UnaryTest", {"a_0_arg"}},
           {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}},
-          {{"I"}, "UnaryTest", {"outside_compilation_O2_recv:output:0"}},
+          {{"I"},
+           "UnaryTest",
+           {"outside_compilation_O2_host_compute:outputs:0"}},
           {{"F"},
            "BinaryTest",
-           {"C:o:0", "outside_compilation_O1_recv:output:0"},
+           {"C:o:0", "outside_compilation_O1_host_compute:outputs:0"},
            {},
-           {"outside_compilation_O1_recv"}},
-          {{"outside_compilation_O2_send"},
-           "_XlaSendToHost",
+           {"outside_compilation_O1_host_compute"}},
+          {{"outside_compilation_O2_host_compute"},
+           "_XlaHostCompute",
            {"D:o:0", "F:o:0"},
-           {{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})}},
+           {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
+            {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
+            {"key", "host_compute_channel_F1_O2"},
+            {"shape_inference_graph", shape_string_expected_2},
+            {"shapes", gtl::ArraySlice<DataType>({})}},
            {"F"}},
-          {{"outside_compilation_O1_send"},
-           "_XlaSendToHost",
+          {{"outside_compilation_O1_host_compute"},
+           "_XlaHostCompute",
            {"C:o:0", "D:o:0"},
-           {{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})}},
+           {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
+            {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
+            {"key", "host_compute_channel_F1_O1"},
+            {"shape_inference_graph", shape_string_expected_1},
+            {"shapes", gtl::ArraySlice<DataType>({})}},
            {"D"}},
-          {{"outside_compilation_O2_recv"},
-           "_XlaRecvFromHost",
-           {},
-           {{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
-           {"outside_compilation_O2_send"}},
-          {{"outside_compilation_O1_recv"},
-           "_XlaRecvFromHost",
-           {},
-           {{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
-           {"outside_compilation_O1_send"}},
       },
       {{"i_0_retval", "I:o:0"}});
 
@@ -720,23 +998,24 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
     Node* call = b2.opts().FinalizeBuilder(&node_builder);
 
     Node* recv1 =
-        RecvAtHost({DT_FLOAT, DT_FLOAT},
+        RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
                    b2.opts().WithName("outside_compilation_F1_O1_recv"));
     Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
                      b2.opts().WithName("E").WithControlInputs({recv1, b}));
-    Node* send1 = SendFromHost({e}, {DT_FLOAT},
+    Node* send1 = SendFromHost("host_compute_channel_F1_O1", {e},
                                b2.opts()
                                    .WithName("outside_compilation_F1_O1_send")
                                    .WithControlInput(e));
 
     Node* recv2 =
-        RecvAtHost({DT_FLOAT, DT_FLOAT},
+        RecvAtHost("host_compute_channel_F1_O2", {DT_FLOAT, DT_FLOAT},
                    b2.opts().WithName("outside_compilation_F1_O2_recv"));
     Node* g = Binary(e, ops::NodeOut(recv2, 1),
                      b2.opts().WithName("G").WithControlInputs({recv2, e}));
     Node* h = Binary(ops::NodeOut(recv2, 0), e, b2.opts().WithName("H"));
-    Node* send2 = SendFromHost(
-        {h}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O2_send"));
+    Node* send2 =
+        SendFromHost("host_compute_channel_F1_O2", {h},
+                     b2.opts().WithName("outside_compilation_F1_O2_send"));
 
     Node* s = NoOp(b2.opts()
                        .WithName("F1_sequencer")
@@ -758,8 +1037,8 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
 
   {
     GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
-    Node* a = Input(b1.opts().WithName("A"));
-    Node* b = Input(b1.opts().WithName("B"));
+    Node* a = InputShaped(b1.opts().WithName("A"));
+    Node* b = InputShaped(b1.opts().WithName("B"));
     Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
     Node* d =
         Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
@@ -791,6 +1070,24 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
   FunctionDefLibrary library_expected;
   GraphDef graphdef_expected;
 
+  string shape_string_expected;
+  {
+    GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
+    Node* recv =
+        RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
+                   shape.opts().WithName("outside_compilation_F1_O1_recv"));
+    Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
+                     shape.opts().WithName("E"));
+    SendFromHost("host_compute_channel_F1_O1", {e},
+                 shape.opts().WithName("outside_compilation_F1_O1_send"));
+    GraphDef shape_graph;
+    TF_EXPECT_OK(shape.ToGraphDef(&shape_graph));
+    EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected));
+  }
+
+  TensorShapeProto shape_proto_expected;
+  shape_proto_expected.add_dim()->set_size(2);
+
   *library_expected.add_function() = FunctionDefHelper::Create(
       "F1", {"a_0_arg:float", "b_0_arg:float"},
       {"f_0_retval:float", "d_0_retval:float"}, {},
@@ -799,19 +1096,18 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
           {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
           {{"F"},
            "BinaryTest",
-           {"C:o:0", "outside_compilation_O1_recv:output:0"},
+           {"C:o:0", "outside_compilation_O1_host_compute:outputs:0"},
            {},
-           {"outside_compilation_O1_recv"}},
-          {{"outside_compilation_O1_send"},
-           "_XlaSendToHost",
+           {"outside_compilation_O1_host_compute"}},
+          {{"outside_compilation_O1_host_compute"},
+           "_XlaHostCompute",
            {"C:o:0", "D:o:0"},
-           {{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})}},
+           {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
+            {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
+            {"key", "host_compute_channel_F1_O1"},
+            {"shape_inference_graph", shape_string_expected},
+            {"shapes", gtl::ArraySlice<DataType>({})}},
            {"D"}},
-          {{"outside_compilation_O1_recv"},
-           "_XlaRecvFromHost",
-           {},
-           {{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
-           {"outside_compilation_O1_send"}},
       },
       {{"d_0_retval", "D:o:0"}, {"f_0_retval", "F:o:0"}});
 
@@ -822,16 +1118,16 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
           {{"G"}, "BinaryTest", {"e_0_arg", "f_0_arg"}},
           {{"I"},
            "BinaryTest",
-           {"f_0_arg", "outside_compilation_O1_recv:output:0"}},
-          {{"outside_compilation_O1_send"},
-           "_XlaSendToHost",
+           {"f_0_arg", "outside_compilation_O1_host_compute:outputs:0"}},
+          {{"outside_compilation_O1_host_compute"},
+           "_XlaHostCompute",
            {"G:o:0"},
-           {{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}}},
-          {{"outside_compilation_O1_recv"},
-           "_XlaRecvFromHost",
-           {},
-           {{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
-           {"outside_compilation_O1_send"}},
+           {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
+            {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
+            {"key", "host_compute_channel_F2_O1"},
+            {"shape_inference_graph", ""},
+            {"shapes",
+             gtl::ArraySlice<TensorShapeProto>({shape_proto_expected})}}},
       },
       {{"g_0_retval", "G:o:0"}, {"i_0_retval", "I:o:0"}});
 
@@ -839,15 +1135,15 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
     std::unique_ptr<FunctionLibraryDefinition> lib_def(
         new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
     GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
-    Node* a = Input(b2.opts().WithName("A"));
-    Node* b = Input(b2.opts().WithName("B"));
+    Node* a = InputShaped(b2.opts().WithName("A"));
+    Node* b = InputShaped(b2.opts().WithName("B"));
 
     Node* recv1 =
-        RecvAtHost({DT_FLOAT, DT_FLOAT},
+        RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
                    b2.opts().WithName("outside_compilation_F1_O1_recv"));
     Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
                      b2.opts().WithName("E").WithControlInputs({recv1, b}));
-    Node* send1 = SendFromHost({e}, {DT_FLOAT},
+    Node* send1 = SendFromHost("host_compute_channel_F1_O1", {e},
                                b2.opts()
                                    .WithName("outside_compilation_F1_O1_send")
                                    .WithControlInput(e));
@@ -857,12 +1153,14 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
     Node* s1 = NoOp(
         b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}));
 
-    Node* recv2 = RecvAtHost(
-        {DT_FLOAT}, b2.opts().WithName("outside_compilation_F2_O1_recv"));
+    Node* recv2 =
+        RecvAtHost("host_compute_channel_F2_O1", {DT_FLOAT},
+                   b2.opts().WithName("outside_compilation_F2_O1_recv"));
     Node* h = Binary(ops::NodeOut(call1, 1), recv2,
                      b2.opts().WithName("H").WithControlInput(s1));
-    Node* send2 = SendFromHost(
-        {h}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F2_O1_send"));
+    Node* send2 =
+        SendFromHost("host_compute_channel_F2_O1", {h},
+                     b2.opts().WithName("outside_compilation_F2_O1_send"));
 
     NodeBuilder node_builder2("F2", "F2", lib_def.get());
     node_builder2.Input(e).Input(call1);
@@ -888,7 +1186,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) {
 
   {
     GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
-    Node* a = Input(b1.opts().WithName("A"));
+    Node* a = InputShaped(b1.opts().WithName("A"));
     Node* b = Input(b1.opts().WithName("B"));
     Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
     Node* d =
@@ -908,6 +1206,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) {
   FunctionDefLibrary library_expected;
   GraphDef graphdef_expected;
 
+  TensorShapeProto shape_proto_expected;
+  shape_proto_expected.add_dim()->set_size(2);
+
   *library_expected.add_function() = FunctionDefHelper::Create(
       "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
       {
@@ -915,11 +1216,16 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) {
           {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
           {{"F"},
            "BinaryTest",
-           {"D:o:0", "outside_compilation_O1_recv:output:0"}},
-          {{"outside_compilation_O1_recv"},
-           "_XlaRecvFromHost",
+           {"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}},
+          {{"outside_compilation_O1_host_compute"},
+           "_XlaHostCompute",
            {},
-           {{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}}},
+           {{"Tinputs", gtl::ArraySlice<DataType>({})},
+            {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
+            {"key", "host_compute_channel_F1_O1"},
+            {"shape_inference_graph", ""},
+            {"shapes",
+             gtl::ArraySlice<TensorShapeProto>({shape_proto_expected})}}},
       },
       {{"f_0_retval", "F:o:0"}});
 
@@ -927,12 +1233,13 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) {
     std::unique_ptr<FunctionLibraryDefinition> lib_def(
         new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
     GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
-    Node* a = Input(b2.opts().WithName("A"));
+    Node* a = InputShaped(b2.opts().WithName("A"));
     Node* b = Input(b2.opts().WithName("B"));
 
     Node* e = Unary(a, b2.opts().WithName("E"));
-    Node* send1 = SendFromHost(
-        {e}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_send"));
+    Node* send1 =
+        SendFromHost("host_compute_channel_F1_O1", {e},
+                     b2.opts().WithName("outside_compilation_F1_O1_send"));
     NodeBuilder node_builder1("F1", "F1", lib_def.get());
     node_builder1.Input(a).Input(b);
     Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
@@ -954,7 +1261,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) {
 
   {
     GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
-    Node* a = Input(b1.opts().WithName("A"));
+    Node* a = InputShaped(b1.opts().WithName("A"));
     Node* b = Input(b1.opts().WithName("B"));
     Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
     Node* d =
@@ -975,6 +1282,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) {
   FunctionDefLibrary library_expected;
   GraphDef graphdef_expected;
 
+  TensorShapeProto shape_proto_expected;
+  shape_proto_expected.add_dim()->set_size(2);
+
   *library_expected.add_function() = FunctionDefHelper::Create(
       "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
       {
@@ -982,17 +1292,17 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) {
           {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
           {{"F"},
            "BinaryTest",
-           {"D:o:0", "outside_compilation_O1_recv:output:0"}},
-          {{"outside_compilation_O1_send"},
-           "_XlaSendToHost",
+           {"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}},
+          {{"outside_compilation_O1_host_compute"},
+           "_XlaHostCompute",
            {},
-           {{"dtypes", gtl::ArraySlice<DataType>({})}},
+           {{"Tinputs", gtl::ArraySlice<DataType>({})},
+            {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
+            {"key", "host_compute_channel_F1_O1"},
+            {"shape_inference_graph", ""},
+            {"shapes",
+             gtl::ArraySlice<TensorShapeProto>({shape_proto_expected})}},
            {"D"}},
-          {{"outside_compilation_O1_recv"},
-           "_XlaRecvFromHost",
-           {},
-           {{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
-           {"outside_compilation_O1_send"}},
       },
       {{"f_0_retval", "F:o:0"}});
 
@@ -1000,14 +1310,16 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) {
     std::unique_ptr<FunctionLibraryDefinition> lib_def(
         new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
     GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
-    Node* a = Input(b2.opts().WithName("A"));
+    Node* a = InputShaped(b2.opts().WithName("A"));
     Node* b = Input(b2.opts().WithName("B"));
 
     Node* recv1 =
-        RecvAtHost({}, b2.opts().WithName("outside_compilation_F1_O1_recv"));
+        RecvAtHost("host_compute_channel_F1_O1", {},
+                   b2.opts().WithName("outside_compilation_F1_O1_recv"));
     Node* e = Unary(a, b2.opts().WithName("E").WithControlInput(recv1));
-    Node* send1 = SendFromHost(
-        {e}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_send"));
+    Node* send1 =
+        SendFromHost("host_compute_channel_F1_O1", {e},
+                     b2.opts().WithName("outside_compilation_F1_O1_send"));
     NodeBuilder node_builder1("F1", "F1", lib_def.get());
     node_builder1.Input(a).Input(b);
     Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
@@ -1055,10 +1367,14 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) {
           {{"C"}, "UnaryTest", {"a_0_arg"}},
           {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
           {{"F"}, "UnaryTest", {"D:o:0"}},
-          {{"outside_compilation_O1_send"},
-           "_XlaSendToHost",
+          {{"outside_compilation_O1_host_compute"},
+           "_XlaHostCompute",
            {"D:o:0"},
-           {{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}}},
+           {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
+            {"Toutputs", gtl::ArraySlice<DataType>({})},
+            {"key", "host_compute_channel_F1_O1"},
+            {"shape_inference_graph", ""},
+            {"shapes", gtl::ArraySlice<TensorShapeProto>({})}}},
       },
       {{"f_0_retval", "F:o:0"}});
 
@@ -1069,8 +1385,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) {
     Node* a = Input(b2.opts().WithName("A"));
     Node* b = Input(b2.opts().WithName("B"));
 
-    Node* recv1 = RecvAtHost(
-        {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_recv"));
+    Node* recv1 =
+        RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT},
+                   b2.opts().WithName("outside_compilation_F1_O1_recv"));
     Node* e = Unary(recv1, b2.opts().WithName("E"));
     NodeBuilder node_builder1("F1", "F1", lib_def.get());
     node_builder1.Input(a).Input(b);
@@ -1118,16 +1435,19 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) {
       {
           {{"C"}, "UnaryTest", {"a_0_arg"}},
           {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
-          {{"F"}, "UnaryTest", {"D:o:0"}, {}, {"outside_compilation_O1_recv"}},
-          {{"outside_compilation_O1_send"},
-           "_XlaSendToHost",
+          {{"F"},
+           "UnaryTest",
            {"D:o:0"},
-           {{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}}},
-          {{"outside_compilation_O1_recv"},
-           "_XlaRecvFromHost",
            {},
-           {{"dtypes", gtl::ArraySlice<DataType>({})}},
-           {"outside_compilation_O1_send"}},
+           {"outside_compilation_O1_host_compute"}},
+          {{"outside_compilation_O1_host_compute"},
+           "_XlaHostCompute",
+           {"D:o:0"},
+           {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
+            {"Toutputs", gtl::ArraySlice<DataType>({})},
+            {"key", "host_compute_channel_F1_O1"},
+            {"shape_inference_graph", ""},
+            {"shapes", gtl::ArraySlice<TensorShapeProto>({})}}},
       },
       {{"f_0_retval", "F:o:0"}});
 
@@ -1138,10 +1458,11 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) {
     Node* a = Input(b2.opts().WithName("A"));
     Node* b = Input(b2.opts().WithName("B"));
 
-    Node* recv1 = RecvAtHost(
-        {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_recv"));
+    Node* recv1 =
+        RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT},
+                   b2.opts().WithName("outside_compilation_F1_O1_recv"));
     Node* e = Unary(recv1, b2.opts().WithName("E"));
-    Node* send1 = SendFromHost({}, {},
+    Node* send1 = SendFromHost("host_compute_channel_F1_O1", {},
                                b2.opts()
                                    .WithName("outside_compilation_F1_O1_send")
                                    .WithControlInput(e));
@@ -1215,5 +1536,110 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) {
   TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
 }
 
+// Test for shape inference of outside compilation.
+TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) {
+  FunctionDefLibrary library;
+  GraphDef graphdef;
+
+  {
+    *library.add_function() = test::function::XTimesTwo();
+
+    GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
+    Node* a = InputShaped(b1.opts().WithName("A"));
+    Node* b = Input(b1.opts().WithName("B"));
+    // Give nodes 'c' and 'd' names that collide after lowercasing.
+    Node* c = Unary(a, b1.opts().WithName("C"));
+    Node* d = Unary(b, b1.opts().WithName("c").WithControlInput(c).WithAttr(
+                           "_encapsulate", "F1"));
+    Node* e = BinaryUnknownShape(c, d,
+                                 b1.opts()
+                                     .WithName("E")
+                                     .WithControlInputs({b, d})
+                                     .WithAttr("_encapsulate", "F1")
+                                     .WithAttr("_outside", "O1"));
+    Node* f = Binary(c, e,
+                     b1.opts().WithName("F").WithControlInput(e).WithAttr(
+                         "_encapsulate", "F1"));
+    Binary(a, f, b1.opts().WithName("G").WithControlInput(e));
+    TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
+  }
+
+  TF_EXPECT_OK(Encapsulate(&graphdef, &library));
+
+  FunctionDefLibrary library_expected;
+  GraphDef graphdef_expected;
+
+  string shape_string_expected;
+  {
+    GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
+    Node* known = KnownShape({2}, shape.opts().WithName("KnownShape/_0"));
+    Node* recv =
+        RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT},
+                   shape.opts().WithName("outside_compilation_F1_O1_recv"));
+    Node* e = BinaryUnknownShape(known, recv, shape.opts().WithName("E"));
+    SendFromHost("host_compute_channel_F1_O1", {e},
+                 shape.opts().WithName("outside_compilation_F1_O1_send"));
+    GraphDef shape_graph;
+    TF_EXPECT_OK(shape.ToGraphDef(&shape_graph));
+    EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected));
+  }
+
+  *library_expected.add_function() = test::function::XTimesTwo();
+  *library_expected.add_function() = FunctionDefHelper::Create(
+      "F1", {"b_0_arg:float", "c_0_arg:float"}, {"f_0_retval:float"}, {},
+      {
+          {{"c"}, "UnaryTest", {"b_0_arg"}, {}, {}},
+          {{"F"},
+           "BinaryTest",
+           {"c_0_arg", "outside_compilation_O1_host_compute:outputs:0"},
+           {},
+           {"outside_compilation_O1_host_compute"}},
+          {{"outside_compilation_O1_host_compute"},
+           "_XlaHostCompute",
+           {"c:o:0"},
+           {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
+            {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
+            {"key", "host_compute_channel_F1_O1"},
+            {"shape_inference_graph", shape_string_expected},
+            {"shapes", gtl::ArraySlice<DataType>({})}},
+           {"c"}},
+      },
+      {{"f_0_retval", "F:o:0"}});
+
+  {
+    std::unique_ptr<FunctionLibraryDefinition> lib_def(
+        new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
+    GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
+    Node* a = InputShaped(b2.opts().WithName("A"));
+    Node* b = Input(b2.opts().WithName("B"));
+    Node* c = Unary(a, b2.opts().WithName("C"));
+
+    NodeBuilder node_builder("F1", "F1", lib_def.get());
+    node_builder.Input(b).Input(c);
+    Node* call =
+        b2.opts().WithControlInputs({c}).FinalizeBuilder(&node_builder);
+
+    Node* recv =
+        RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT},
+                   b2.opts().WithName("outside_compilation_F1_O1_recv"));
+    Node* e = BinaryUnknownShape(
+        c, ops::NodeOut(recv, 0),
+        b2.opts().WithName("E").WithControlInputs({recv, b}));
+    Node* send = SendFromHost("host_compute_channel_F1_O1", {e},
+                              b2.opts()
+                                  .WithName("outside_compilation_F1_O1_send")
+                                  .WithControlInput(e));
+
+    Node* s = NoOp(
+        b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}));
+
+    Binary(a, call, b2.opts().WithName("G").WithControlInputs({s, e}));
+    TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
+  }
+
+  TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
+  TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
+}
+
 }  // namespace
 }  // namespace tensorflow
index d6b5761..eae8e6c 100644 (file)
@@ -1064,26 +1064,36 @@ Status FunctionLibraryDefinition::AddLibrary(
   return Status::OK();
 }
 
-void FunctionLibraryDefinition::RemoveFunction(const string& func) {
+Status FunctionLibraryDefinition::RemoveFunction(const string& func) {
   const auto& i = function_defs_.find(func);
-  DCHECK(i != function_defs_.end());
+  if (i == function_defs_.end()) {
+    return errors::InvalidArgument("Tried to remove non-existent function ",
+                                   func);
+  }
   function_defs_.erase(i);
+  return Status::OK();
 }
 
-void FunctionLibraryDefinition::RemoveGradient(const string& func) {
+Status FunctionLibraryDefinition::RemoveGradient(const string& func) {
   const auto& i = func_grad_.find(func);
-  DCHECK(i != func_grad_.end());
+  if (i == func_grad_.end()) {
+    return errors::InvalidArgument("Tried to remove non-existent gradient ",
+                                   func);
+  }
   func_grad_.erase(i);
+  return Status::OK();
 }
 
 void FunctionLibraryDefinition::Remove(
     const std::vector<string>& funcs,
     const std::vector<string>& funcs_with_grads) {
   for (const string& f : funcs) {
-    RemoveFunction(f);
+    Status s = RemoveFunction(f);
+    DCHECK(s.ok());
   }
   for (const string& f : funcs_with_grads) {
-    RemoveGradient(f);
+    Status s = RemoveGradient(f);
+    DCHECK(s.ok());
   }
 }
 
index b933ee0..7d0e156 100644 (file)
@@ -312,6 +312,14 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
   // This operation is atomic.
   Status AddGradientDef(const GradientDef& grad);
 
+  // Remove function `func` from the library. Returns non-OK Status unless
+  // `func` is in the library.
+  Status RemoveFunction(const string& func);
+
+  // Remove gradient of function `func` from the library. Returns non-OK Status
+  // unless `func` has a gradient.
+  Status RemoveGradient(const string& func);
+
   // Adds the functions and gradients in 'other' to this function library.
   // Duplicate functions and gradients are ignored.
   // This operation is atomic.
@@ -384,13 +392,6 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
   // attr from.
   const FunctionDef* GetAttrImpl(const NodeDef& ndef) const;
 
-  // Remove function `func` from the library. `func` must be in the library.
-  void RemoveFunction(const string& func);
-
-  // Remove gradient of function `func` from the library. `func` must have
-  // a gradient.
-  void RemoveGradient(const string& func);
-
   // Remove all functions in `funcs` and all gradients of
   // functions in `funcs_with_grads` from this library.
   void Remove(const std::vector<string>& funcs,