Leaves attributes on outside_compilation nodes so they can be replicated in a later...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 29 Mar 2018 19:02:50 +0000 (12:02 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 29 Mar 2018 19:05:12 +0000 (12:05 -0700)
PiperOrigin-RevId: 190965218

tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
tensorflow/contrib/tpu/ops/replication_ops.cc

index 7fc43fb..53ec6c1 100644 (file)
@@ -254,7 +254,8 @@ class Encapsulator {
 
     // Adds _RecvAtHost and _SendFromHost nodes, where needed, to graph_out.
     Status AddOutsideCompilationHostIONodes(
-        const string& subgraph_name,
+        const string& group_attribute, const string& subgraph_name,
+        const string& outside_compilation_attribute,
         const std::unordered_map<const Node*, Node*>& node_images,
         Graph* graph_out);
 
@@ -405,7 +406,9 @@ class Encapsulator {
 
     // Builds a _RecvAtHost node producing all the inputs of an
     // outside_compilation subgraph and stores it in oc_subgraph.recv_at_host.
-    Status AddRecvAtHostNode(const string& subgraph_name,
+    Status AddRecvAtHostNode(const string& group_attribute,
+                             const string& subgraph_name,
+                             const string& outside_compilation_attribute,
                              const string& oc_subgraph_name,
                              OutsideCompilationSubgraph* oc_subgraph,
                              Graph* graph_out);
@@ -414,8 +417,10 @@ class Encapsulator {
     // outside_compilation subgraph and stores it in oc_subgraph.send_from_host.
     Status AddSendFromHostNode(
         const std::unordered_map<const Node*, Node*>& node_images,
-        const string& subgraph_name, const string& oc_subgraph_name,
-        OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out);
+        const string& group_attribute, const string& subgraph_name,
+        const string& outside_compilation_attribute,
+        const string& oc_subgraph_name, OutsideCompilationSubgraph* oc_subgraph,
+        Graph* graph_out);
 
     // The subgraph extracted from the input graph, suitable for being turned
     // into a FunctionDef. Inputs are fed by _Arg nodes, and outputs are
@@ -1114,7 +1119,8 @@ Status Encapsulator::Subgraph::AddHostComputeKeyPlaceholder(
 }
 
 Status Encapsulator::Subgraph::AddRecvAtHostNode(
-    const string& subgraph_name, const string& oc_subgraph_name,
+    const string& group_attribute, const string& subgraph_name,
+    const string& outside_compilation_attribute, const string& oc_subgraph_name,
     OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out) {
   if (host_compute_key_placeholder_ == nullptr) {
     TF_RETURN_IF_ERROR(AddHostComputeKeyPlaceholder(oc_subgraph, graph_out));
@@ -1135,14 +1141,15 @@ Status Encapsulator::Subgraph::AddRecvAtHostNode(
   NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name,
                                          "_", oc_subgraph_name, "_recv"),
                          kRecvAtHostOp);
-  // TODO(misard) When we add replication the device placement will have to be
-  // redone.
   builder.Device(device_);
   builder.Attr("Toutputs", dtypes);
-  // TODO(misard) For now we only support TPU device 0.
+  // The correct device_ordinal will be inserted during replication in a
+  // subsequent rewrite.
   builder.Attr("device_ordinal", 0);
   builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name,
                                       "_", oc_subgraph_name));
+  builder.Attr(group_attribute, subgraph_name);
+  builder.Attr(outside_compilation_attribute, oc_subgraph_name);
   builder.Input(host_compute_key_placeholder_->name(), 0, DT_STRING);
   Status s = builder.Finalize(&recv_def);
   if (!s.ok()) return s;
@@ -1163,7 +1170,8 @@ Status Encapsulator::Subgraph::AddRecvAtHostNode(
 
 Status Encapsulator::Subgraph::AddSendFromHostNode(
     const std::unordered_map<const Node*, Node*>& node_images,
-    const string& subgraph_name, const string& oc_subgraph_name,
+    const string& group_attribute, const string& subgraph_name,
+    const string& outside_compilation_attribute, const string& oc_subgraph_name,
     OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out) {
   if (host_compute_key_placeholder_ == nullptr) {
     TF_RETURN_IF_ERROR(AddHostComputeKeyPlaceholder(oc_subgraph, graph_out));
@@ -1188,14 +1196,15 @@ Status Encapsulator::Subgraph::AddSendFromHostNode(
   NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name,
                                          "_", oc_subgraph_name, "_send"),
                          kSendFromHostOp);
-  // TODO(misard) When we add replication the device placement will have to be
-  // redone.
   builder.Device(device_);
   builder.Attr("Tinputs", dtypes);
   builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name,
                                       "_", oc_subgraph_name));
-  // TODO(misard) For now we only support TPU device 0.
+  // The correct device_ordinal will be inserted during replication in a
+  // subsequent rewrite.
   builder.Attr("device_ordinal", 0);
+  builder.Attr(group_attribute, subgraph_name);
+  builder.Attr(outside_compilation_attribute, oc_subgraph_name);
   builder.Input(inputs);
   builder.Input(host_compute_key_placeholder_->name(), 0, DT_STRING);
   Status s = builder.Finalize(&send_def);
@@ -1216,7 +1225,8 @@ Status Encapsulator::Subgraph::AddSendFromHostNode(
 }
 
 Status Encapsulator::Subgraph::AddOutsideCompilationHostIONodes(
-    const string& subgraph_name,
+    const string& group_attribute, const string& subgraph_name,
+    const string& outside_compilation_attribute,
     const std::unordered_map<const Node*, Node*>& node_images,
     Graph* graph_out) {
   for (auto& outside_compilation_subgraph_entry :
@@ -1226,14 +1236,16 @@ Status Encapsulator::Subgraph::AddOutsideCompilationHostIONodes(
         outside_compilation_subgraph_entry.second;
 
     if (!oc_subgraph.inputs.empty() || !oc_subgraph.control_inputs.empty()) {
-      TF_RETURN_IF_ERROR(
-          AddRecvAtHostNode(subgraph_name, oc_name, &oc_subgraph, graph_out));
+      TF_RETURN_IF_ERROR(AddRecvAtHostNode(group_attribute, subgraph_name,
+                                           outside_compilation_attribute,
+                                           oc_name, &oc_subgraph, graph_out));
     }
 
     if (!oc_subgraph.outputs_by_src.empty() ||
         !oc_subgraph.control_outputs.empty()) {
-      TF_RETURN_IF_ERROR(AddSendFromHostNode(node_images, subgraph_name,
-                                             oc_name, &oc_subgraph, graph_out));
+      TF_RETURN_IF_ERROR(AddSendFromHostNode(
+          node_images, group_attribute, subgraph_name,
+          outside_compilation_attribute, oc_name, &oc_subgraph, graph_out));
     }
   }
   return Status::OK();
@@ -1450,8 +1462,6 @@ Status Encapsulator::CopyNodesToOutputGraph(
             "Parallel checking is not supported when outside_compilation "
             "clusters are present.");
       }
-      image->ClearAttr(group_attribute_);
-      image->ClearAttr(outside_compilation_attribute_);
     }
     (*node_images)[node] = image;
   }
@@ -1477,7 +1487,8 @@ Status Encapsulator::AddOutsideCompilationHostIONodes(
     const string& subgraph_name = subgraph_entry.first;
     Subgraph& subgraph = subgraph_entry.second;
     TF_RETURN_IF_ERROR(subgraph.AddOutsideCompilationHostIONodes(
-        subgraph_name, node_images, graph_out));
+        group_attribute_, subgraph_name, outside_compilation_attribute_,
+        node_images, graph_out));
   }
   return Status::OK();
 }
index 94481a1..7899b5d 100644 (file)
@@ -382,24 +382,36 @@ Node* KeyPlaceholder(const string& call_node,
       .FinalizeBuilder(&node_builder);
 }
 
-Node* RecvAtHost(ops::NodeOut key_input, const string& key,
+Node* RecvAtHost(ops::NodeOut key_input, const string& cluster,
+                 const string& oc_cluster,
                  const gtl::ArraySlice<DataType>& dtypes,
                  const GraphDefBuilder::Options& opts) {
   if (opts.HaveError()) return nullptr;
-  NodeBuilder node_builder(opts.GetNameForOp("_XlaRecvAtHost"),
+  string key =
+      strings::StrCat("host_compute_channel_", cluster, "_", oc_cluster);
+  string name = strings::StrCat("outside_compilation_", cluster, "_",
+                                oc_cluster, "_recv");
+  NodeBuilder node_builder(opts.WithName(name).GetNameForOp("_XlaRecvAtHost"),
                            "_XlaRecvAtHost", opts.op_registry());
   node_builder.Input(std::move(key_input));
   return opts.WithAttr("Toutputs", dtypes)
       .WithAttr("key", key)
       .WithAttr("device_ordinal", 0)
+      .WithAttr("_encapsulate", cluster)
+      .WithAttr("_outside", oc_cluster)
       .FinalizeBuilder(&node_builder);
 }
 
-Node* SendFromHost(ops::NodeOut key_input, const string& key,
+Node* SendFromHost(ops::NodeOut key_input, const string& cluster,
+                   const string& oc_cluster,
                    const std::vector<ops::NodeOut>& inputs,
                    const GraphDefBuilder::Options& opts) {
   if (opts.HaveError()) return nullptr;
-  NodeBuilder node_builder(opts.GetNameForOp("_XlaSendFromHost"),
+  string key =
+      strings::StrCat("host_compute_channel_", cluster, "_", oc_cluster);
+  string name = strings::StrCat("outside_compilation_", cluster, "_",
+                                oc_cluster, "_send");
+  NodeBuilder node_builder(opts.WithName(name).GetNameForOp("_XlaSendFromHost"),
                            "_XlaSendFromHost", opts.op_registry());
   node_builder.Input(inputs);
   node_builder.Input(std::move(key_input));
@@ -410,6 +422,8 @@ Node* SendFromHost(ops::NodeOut key_input, const string& key,
   return opts.WithAttr("Tinputs", dtypes)
       .WithAttr("key", key)
       .WithAttr("device_ordinal", 0)
+      .WithAttr("_encapsulate", cluster)
+      .WithAttr("_outside", oc_cluster)
       .FinalizeBuilder(&node_builder);
 }
 
@@ -856,14 +870,14 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) {
     GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
     Node* key_constant =
         KeyPlaceholderShape(shape.opts().WithName("KnownShape/_0"));
-    Node* recv =
-        RecvAtHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1",
-                   {DT_FLOAT, DT_FLOAT},
-                   shape.opts().WithName("outside_compilation_F1_O1_recv"));
+    Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
+                            {DT_FLOAT, DT_FLOAT}, shape.opts());
     Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
-                     shape.opts().WithName("E"));
-    SendFromHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1",
-                 {e}, shape.opts().WithName("outside_compilation_F1_O1_send"));
+                     shape.opts()
+                         .WithName("E")
+                         .WithAttr("_encapsulate", "F1")
+                         .WithAttr("_outside", "O1"));
+    SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape.opts());
     TF_EXPECT_OK(
         AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected));
   }
@@ -901,17 +915,16 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) {
 
     Node* key_constant =
         KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
-    Node* recv =
-        RecvAtHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1",
-                   {DT_FLOAT, DT_FLOAT},
-                   b2.opts().WithName("outside_compilation_F1_O1_recv"));
+    Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
+                            {DT_FLOAT, DT_FLOAT}, b2.opts());
     Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
-                     b2.opts().WithName("E").WithControlInputs({recv, b}));
-    Node* send = SendFromHost(ops::NodeOut(key_constant, 0),
-                              "host_compute_channel_F1_O1", {e},
-                              b2.opts()
-                                  .WithName("outside_compilation_F1_O1_send")
-                                  .WithControlInput(e));
+                     b2.opts()
+                         .WithName("E")
+                         .WithControlInputs({recv, b})
+                         .WithAttr("_encapsulate", "F1")
+                         .WithAttr("_outside", "O1"));
+    Node* send = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e},
+                              b2.opts().WithControlInput(e));
 
     Node* s = Sequencer(
         b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}),
@@ -976,14 +989,14 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
     GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
     Node* key_constant =
         KeyPlaceholderShape(shape1.opts().WithName("KnownShape/_0"));
-    Node* recv =
-        RecvAtHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1",
-                   {DT_FLOAT, DT_FLOAT},
-                   shape1.opts().WithName("outside_compilation_F1_O1_recv"));
+    Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
+                            {DT_FLOAT, DT_FLOAT}, shape1.opts());
     Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
-                     shape1.opts().WithName("E"));
-    SendFromHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1",
-                 {e}, shape1.opts().WithName("outside_compilation_F1_O1_send"));
+                     shape1.opts()
+                         .WithName("E")
+                         .WithAttr("_encapsulate", "F1")
+                         .WithAttr("_outside", "O1"));
+    SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape1.opts());
     TF_EXPECT_OK(
         AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected));
   }
@@ -992,19 +1005,21 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
     GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately);
     Node* key_constant =
         KeyPlaceholderShape(shape2.opts().WithName("KnownShape/_0"));
-    Node* recv1 =
-        RecvAtHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1",
-                   {DT_FLOAT, DT_FLOAT},
-                   shape2.opts().WithName("outside_compilation_F1_O1_recv"));
+    Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
+                             {DT_FLOAT, DT_FLOAT}, shape2.opts());
     Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
-                     shape2.opts().WithName("E"));
-    Node* recv2 =
-        RecvAtHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O2",
-                   {DT_FLOAT, DT_FLOAT},
-                   shape2.opts().WithName("outside_compilation_F1_O2_recv"));
-    Node* h = Binary(ops::NodeOut(recv2, 0), e, shape2.opts().WithName("H"));
-    SendFromHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O2",
-                 {h}, shape2.opts().WithName("outside_compilation_F1_O2_send"));
+                     shape2.opts()
+                         .WithName("E")
+                         .WithAttr("_encapsulate", "F1")
+                         .WithAttr("_outside", "O1"));
+    Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2",
+                             {DT_FLOAT, DT_FLOAT}, shape2.opts());
+    Node* h = Binary(ops::NodeOut(recv2, 0), e,
+                     shape2.opts()
+                         .WithName("H")
+                         .WithAttr("_encapsulate", "F1")
+                         .WithAttr("_outside", "O2"));
+    SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {h}, shape2.opts());
     TF_EXPECT_OK(
         AddGraphDefToFunctionLibrary(shape2, "F1_O2", &library_expected));
   }
@@ -1054,28 +1069,32 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
 
     Node* key_constant =
         KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
-    Node* recv1 =
-        RecvAtHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1",
-                   {DT_FLOAT, DT_FLOAT},
-                   b2.opts().WithName("outside_compilation_F1_O1_recv"));
+    Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
+                             {DT_FLOAT, DT_FLOAT}, b2.opts());
     Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
-                     b2.opts().WithName("E").WithControlInputs({recv1, b}));
-    Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0),
-                               "host_compute_channel_F1_O1", {e},
-                               b2.opts()
-                                   .WithName("outside_compilation_F1_O1_send")
-                                   .WithControlInput(e));
-
-    Node* recv2 =
-        RecvAtHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O2",
-                   {DT_FLOAT, DT_FLOAT},
-                   b2.opts().WithName("outside_compilation_F1_O2_recv"));
+                     b2.opts()
+                         .WithName("E")
+                         .WithControlInputs({recv1, b})
+                         .WithAttr("_encapsulate", "F1")
+                         .WithAttr("_outside", "O1"));
+    Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e},
+                               b2.opts().WithControlInput(e));
+
+    Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2",
+                             {DT_FLOAT, DT_FLOAT}, b2.opts());
     Node* g = Binary(e, ops::NodeOut(recv2, 1),
-                     b2.opts().WithName("G").WithControlInputs({recv2, e}));
-    Node* h = Binary(ops::NodeOut(recv2, 0), e, b2.opts().WithName("H"));
-    Node* send2 = SendFromHost(
-        ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O2", {h},
-        b2.opts().WithName("outside_compilation_F1_O2_send"));
+                     b2.opts()
+                         .WithName("G")
+                         .WithControlInputs({recv2, e})
+                         .WithAttr("_encapsulate", "F1")
+                         .WithAttr("_outside", "O2"));
+    Node* h = Binary(ops::NodeOut(recv2, 0), e,
+                     b2.opts()
+                         .WithName("H")
+                         .WithAttr("_encapsulate", "F1")
+                         .WithAttr("_outside", "O2"));
+    Node* send2 =
+        SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {h}, b2.opts());
 
     Node* s = Sequencer(b2.opts()
                             .WithName("F1_sequencer")
@@ -1139,14 +1158,14 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
     GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
     Node* key_constant =
         KeyPlaceholderShape(shape.opts().WithName("KnownShape/_0"));
-    Node* recv =
-        RecvAtHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1",
-                   {DT_FLOAT, DT_FLOAT},
-                   shape.opts().WithName("outside_compilation_F1_O1_recv"));
+    Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
+                            {DT_FLOAT, DT_FLOAT}, shape.opts());
     Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
-                     shape.opts().WithName("E"));
-    SendFromHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1",
-                 {e}, shape.opts().WithName("outside_compilation_F1_O1_send"));
+                     shape.opts()
+                         .WithName("E")
+                         .WithAttr("_encapsulate", "F1")
+                         .WithAttr("_outside", "O1"));
+    SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape.opts());
     TF_EXPECT_OK(
         AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected));
   }
@@ -1207,17 +1226,16 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
 
     Node* key_constant1 =
         KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
-    Node* recv1 =
-        RecvAtHost(ops::NodeOut(key_constant1, 0), "host_compute_channel_F1_O1",
-                   {DT_FLOAT, DT_FLOAT},
-                   b2.opts().WithName("outside_compilation_F1_O1_recv"));
+    Node* recv1 = RecvAtHost(ops::NodeOut(key_constant1, 0), "F1", "O1",
+                             {DT_FLOAT, DT_FLOAT}, b2.opts());
     Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
-                     b2.opts().WithName("E").WithControlInputs({recv1, b}));
-    Node* send1 = SendFromHost(ops::NodeOut(key_constant1, 0),
-                               "host_compute_channel_F1_O1", {e},
-                               b2.opts()
-                                   .WithName("outside_compilation_F1_O1_send")
-                                   .WithControlInput(e));
+                     b2.opts()
+                         .WithName("E")
+                         .WithControlInputs({recv1, b})
+                         .WithAttr("_encapsulate", "F1")
+                         .WithAttr("_outside", "O1"));
+    Node* send1 = SendFromHost(ops::NodeOut(key_constant1, 0), "F1", "O1", {e},
+                               b2.opts().WithControlInput(e));
     Node* s1 = Sequencer(
         b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}),
         "F1");
@@ -1229,13 +1247,15 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
 
     Node* key_constant2 =
         KeyPlaceholder("F2", b2.opts().WithName("F2_key_placeholder"));
-    Node* recv2 = RecvAtHost(
-        ops::NodeOut(key_constant2, 0), "host_compute_channel_F2_O1",
-        {DT_FLOAT}, b2.opts().WithName("outside_compilation_F2_O1_recv"));
-    Node* h = Binary(ops::NodeOut(call1, 1), recv2, b2.opts().WithName("H"));
-    Node* send2 = SendFromHost(
-        ops::NodeOut(key_constant2, 0), "host_compute_channel_F2_O1", {h},
-        b2.opts().WithName("outside_compilation_F2_O1_send"));
+    Node* recv2 = RecvAtHost(ops::NodeOut(key_constant2, 0), "F2", "O1",
+                             {DT_FLOAT}, b2.opts());
+    Node* h = Binary(ops::NodeOut(call1, 1), recv2,
+                     b2.opts()
+                         .WithName("H")
+                         .WithAttr("_encapsulate", "F2")
+                         .WithAttr("_outside", "O1"));
+    Node* send2 = SendFromHost(ops::NodeOut(key_constant2, 0), "F2", "O1", {h},
+                               b2.opts());
 
     Node* s2 = Sequencer(
         b2.opts().WithName("F2_sequencer").WithControlInputs({recv2, send2}),
@@ -1311,12 +1331,14 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) {
     Node* a = InputShaped(b2.opts().WithName("A"));
     Node* b = Input(b2.opts().WithName("B"));
 
-    Node* e = Unary(a, b2.opts().WithName("E"));
+    Node* e = Unary(a, b2.opts()
+                           .WithName("E")
+                           .WithAttr("_encapsulate", "F1")
+                           .WithAttr("_outside", "O1"));
     Node* key_constant =
         KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
-    Node* send1 = SendFromHost(
-        ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", {e},
-        b2.opts().WithName("outside_compilation_F1_O1_send"));
+    Node* send1 =
+        SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts());
     Node* s1 = Sequencer(
         b2.opts().WithName("F1_sequencer").WithControlInput(send1), "F1");
     NodeBuilder node_builder1("F1", "F1", lib_def.get());
@@ -1395,12 +1417,14 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) {
     Node* key_constant =
         KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
     Node* recv1 =
-        RecvAtHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1",
-                   {}, b2.opts().WithName("outside_compilation_F1_O1_recv"));
-    Node* e = Unary(a, b2.opts().WithName("E").WithControlInput(recv1));
-    Node* send1 = SendFromHost(
-        ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", {e},
-        b2.opts().WithName("outside_compilation_F1_O1_send"));
+        RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {}, b2.opts());
+    Node* e = Unary(a, b2.opts()
+                           .WithName("E")
+                           .WithControlInput(recv1)
+                           .WithAttr("_encapsulate", "F1")
+                           .WithAttr("_outside", "O1"));
+    Node* send1 =
+        SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts());
     Node* s1 = Sequencer(
         b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}),
         "F1");
@@ -1470,10 +1494,12 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) {
 
     Node* key_constant =
         KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
-    Node* recv1 = RecvAtHost(
-        ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", {DT_FLOAT},
-        b2.opts().WithName("outside_compilation_F1_O1_recv"));
-    Node* e = Unary(recv1, b2.opts().WithName("E"));
+    Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
+                             {DT_FLOAT}, b2.opts());
+    Node* e = Unary(recv1, b2.opts()
+                               .WithName("E")
+                               .WithAttr("_encapsulate", "F1")
+                               .WithAttr("_outside", "O1"));
     Node* s1 = Sequencer(
         b2.opts().WithName("F1_sequencer").WithControlInput(recv1), "F1");
     NodeBuilder node_builder1("F1", "F1", lib_def.get());
@@ -1547,15 +1573,14 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) {
 
     Node* key_constant =
         KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
-    Node* recv1 = RecvAtHost(
-        ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", {DT_FLOAT},
-        b2.opts().WithName("outside_compilation_F1_O1_recv"));
-    Node* e = Unary(recv1, b2.opts().WithName("E"));
-    Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0),
-                               "host_compute_channel_F1_O1", {},
-                               b2.opts()
-                                   .WithName("outside_compilation_F1_O1_send")
-                                   .WithControlInput(e));
+    Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
+                             {DT_FLOAT}, b2.opts());
+    Node* e = Unary(recv1, b2.opts()
+                               .WithName("E")
+                               .WithAttr("_encapsulate", "F1")
+                               .WithAttr("_outside", "O1"));
+    Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {},
+                               b2.opts().WithControlInput(e));
     Node* s1 = Sequencer(
         b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}),
         "F1");
@@ -1615,7 +1640,10 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) {
     Node* a = Input(b2.opts().WithName("A"));
     Node* b = Input(b2.opts().WithName("B"));
 
-    Node* e = Unary(a, b2.opts().WithName("E"));
+    Node* e = Unary(a, b2.opts()
+                           .WithName("E")
+                           .WithAttr("_encapsulate", "F1")
+                           .WithAttr("_outside", "O1"));
     NodeBuilder node_builder1("F1", "F1", lib_def.get());
     node_builder1.Input(a).Input(b);
     Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
@@ -1666,12 +1694,14 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) {
     Node* key_constant =
         KeyPlaceholderShape(shape.opts().WithName("KnownShape/_0"));
     Node* known = KnownShape({2}, shape.opts().WithName("KnownShape/_1"));
-    Node* recv = RecvAtHost(
-        ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", {DT_FLOAT},
-        shape.opts().WithName("outside_compilation_F1_O1_recv"));
-    Node* e = BinaryUnknownShape(known, recv, shape.opts().WithName("E"));
-    SendFromHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1",
-                 {e}, shape.opts().WithName("outside_compilation_F1_O1_send"));
+    Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
+                            {DT_FLOAT}, shape.opts());
+    Node* e = BinaryUnknownShape(known, recv,
+                                 shape.opts()
+                                     .WithName("E")
+                                     .WithAttr("_encapsulate", "F1")
+                                     .WithAttr("_outside", "O1"));
+    SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape.opts());
     TF_EXPECT_OK(
         AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected));
   }
@@ -1709,17 +1739,16 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) {
 
     Node* key_constant =
         KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
-    Node* recv = RecvAtHost(
-        ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", {DT_FLOAT},
-        b2.opts().WithName("outside_compilation_F1_O1_recv"));
-    Node* e = BinaryUnknownShape(
-        c, ops::NodeOut(recv, 0),
-        b2.opts().WithName("E").WithControlInputs({recv, b}));
-    Node* send = SendFromHost(ops::NodeOut(key_constant, 0),
-                              "host_compute_channel_F1_O1", {e},
-                              b2.opts()
-                                  .WithName("outside_compilation_F1_O1_send")
-                                  .WithControlInput(e));
+    Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
+                            {DT_FLOAT}, b2.opts());
+    Node* e = BinaryUnknownShape(c, ops::NodeOut(recv, 0),
+                                 b2.opts()
+                                     .WithName("E")
+                                     .WithControlInputs({recv, b})
+                                     .WithAttr("_encapsulate", "F1")
+                                     .WithAttr("_outside", "O1"));
+    Node* send = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e},
+                              b2.opts().WithControlInput(e));
 
     Node* s = Sequencer(
         b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}),
index cba71c6..3bdf7c2 100644 (file)
@@ -27,6 +27,7 @@ REGISTER_OP("TPUReplicateMetadata")
     .Attr("topology: string = \"\"")
     .Attr("device_assignment: list(int) = []")
     .Attr("computation_shape: list(int) = []")
+    .Attr("host_compute_core: list(string) = []")
     .SetShapeFn(shape_inference::UnknownShape);
 
 REGISTER_OP("TPUReplicatedInput")
@@ -68,6 +69,7 @@ REGISTER_OP("TPUReplicate")
     .Attr("num_replicas: int >= 1")
     .Attr("topology: string = \"\"")
     .Attr("device_assignment: list(int) = []")
+    .Attr("host_compute_core: list(string) = []")
     .Attr("computation_shape: list(int) = []")
     .Attr("Tinputs: list(type) >= 0")
     .Attr("Tbroadcast_inputs: list(type) >= 0")