// Adds _RecvAtHost and _SendFromHost nodes, where needed, to graph_out.
Status AddOutsideCompilationHostIONodes(
- const string& subgraph_name,
+ const string& group_attribute, const string& subgraph_name,
+ const string& outside_compilation_attribute,
const std::unordered_map<const Node*, Node*>& node_images,
Graph* graph_out);
// Builds a _RecvAtHost node producing all the inputs of an
// outside_compilation subgraph and stores it in oc_subgraph.recv_at_host.
- Status AddRecvAtHostNode(const string& subgraph_name,
+ Status AddRecvAtHostNode(const string& group_attribute,
+ const string& subgraph_name,
+ const string& outside_compilation_attribute,
const string& oc_subgraph_name,
OutsideCompilationSubgraph* oc_subgraph,
Graph* graph_out);
// outside_compilation subgraph and stores it in oc_subgraph.send_from_host.
Status AddSendFromHostNode(
const std::unordered_map<const Node*, Node*>& node_images,
- const string& subgraph_name, const string& oc_subgraph_name,
- OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out);
+ const string& group_attribute, const string& subgraph_name,
+ const string& outside_compilation_attribute,
+ const string& oc_subgraph_name, OutsideCompilationSubgraph* oc_subgraph,
+ Graph* graph_out);
// The subgraph extracted from the input graph, suitable for being turned
// into a FunctionDef. Inputs are fed by _Arg nodes, and outputs are
}
Status Encapsulator::Subgraph::AddRecvAtHostNode(
- const string& subgraph_name, const string& oc_subgraph_name,
+ const string& group_attribute, const string& subgraph_name,
+ const string& outside_compilation_attribute, const string& oc_subgraph_name,
OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out) {
if (host_compute_key_placeholder_ == nullptr) {
TF_RETURN_IF_ERROR(AddHostComputeKeyPlaceholder(oc_subgraph, graph_out));
NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name,
"_", oc_subgraph_name, "_recv"),
kRecvAtHostOp);
- // TODO(misard) When we add replication the device placement will have to be
- // redone.
builder.Device(device_);
builder.Attr("Toutputs", dtypes);
- // TODO(misard) For now we only support TPU device 0.
+ // The correct device_ordinal will be inserted during replication in a
+ // subsequent rewrite.
builder.Attr("device_ordinal", 0);
builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name,
"_", oc_subgraph_name));
+ builder.Attr(group_attribute, subgraph_name);
+ builder.Attr(outside_compilation_attribute, oc_subgraph_name);
builder.Input(host_compute_key_placeholder_->name(), 0, DT_STRING);
Status s = builder.Finalize(&recv_def);
if (!s.ok()) return s;
Status Encapsulator::Subgraph::AddSendFromHostNode(
const std::unordered_map<const Node*, Node*>& node_images,
- const string& subgraph_name, const string& oc_subgraph_name,
+ const string& group_attribute, const string& subgraph_name,
+ const string& outside_compilation_attribute, const string& oc_subgraph_name,
OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out) {
if (host_compute_key_placeholder_ == nullptr) {
TF_RETURN_IF_ERROR(AddHostComputeKeyPlaceholder(oc_subgraph, graph_out));
NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name,
"_", oc_subgraph_name, "_send"),
kSendFromHostOp);
- // TODO(misard) When we add replication the device placement will have to be
- // redone.
builder.Device(device_);
builder.Attr("Tinputs", dtypes);
builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name,
"_", oc_subgraph_name));
- // TODO(misard) For now we only support TPU device 0.
+ // The correct device_ordinal will be inserted during replication in a
+ // subsequent rewrite.
builder.Attr("device_ordinal", 0);
+ builder.Attr(group_attribute, subgraph_name);
+ builder.Attr(outside_compilation_attribute, oc_subgraph_name);
builder.Input(inputs);
builder.Input(host_compute_key_placeholder_->name(), 0, DT_STRING);
Status s = builder.Finalize(&send_def);
}
Status Encapsulator::Subgraph::AddOutsideCompilationHostIONodes(
- const string& subgraph_name,
+ const string& group_attribute, const string& subgraph_name,
+ const string& outside_compilation_attribute,
const std::unordered_map<const Node*, Node*>& node_images,
Graph* graph_out) {
for (auto& outside_compilation_subgraph_entry :
outside_compilation_subgraph_entry.second;
if (!oc_subgraph.inputs.empty() || !oc_subgraph.control_inputs.empty()) {
- TF_RETURN_IF_ERROR(
- AddRecvAtHostNode(subgraph_name, oc_name, &oc_subgraph, graph_out));
+ TF_RETURN_IF_ERROR(AddRecvAtHostNode(group_attribute, subgraph_name,
+ outside_compilation_attribute,
+ oc_name, &oc_subgraph, graph_out));
}
if (!oc_subgraph.outputs_by_src.empty() ||
!oc_subgraph.control_outputs.empty()) {
- TF_RETURN_IF_ERROR(AddSendFromHostNode(node_images, subgraph_name,
- oc_name, &oc_subgraph, graph_out));
+ TF_RETURN_IF_ERROR(AddSendFromHostNode(
+ node_images, group_attribute, subgraph_name,
+ outside_compilation_attribute, oc_name, &oc_subgraph, graph_out));
}
}
return Status::OK();
"Parallel checking is not supported when outside_compilation "
"clusters are present.");
}
- image->ClearAttr(group_attribute_);
- image->ClearAttr(outside_compilation_attribute_);
}
(*node_images)[node] = image;
}
const string& subgraph_name = subgraph_entry.first;
Subgraph& subgraph = subgraph_entry.second;
TF_RETURN_IF_ERROR(subgraph.AddOutsideCompilationHostIONodes(
- subgraph_name, node_images, graph_out));
+ group_attribute_, subgraph_name, outside_compilation_attribute_,
+ node_images, graph_out));
}
return Status::OK();
}
.FinalizeBuilder(&node_builder);
}
-Node* RecvAtHost(ops::NodeOut key_input, const string& key,
+Node* RecvAtHost(ops::NodeOut key_input, const string& cluster,
+ const string& oc_cluster,
const gtl::ArraySlice<DataType>& dtypes,
const GraphDefBuilder::Options& opts) {
if (opts.HaveError()) return nullptr;
- NodeBuilder node_builder(opts.GetNameForOp("_XlaRecvAtHost"),
+ string key =
+ strings::StrCat("host_compute_channel_", cluster, "_", oc_cluster);
+ string name = strings::StrCat("outside_compilation_", cluster, "_",
+ oc_cluster, "_recv");
+ NodeBuilder node_builder(opts.WithName(name).GetNameForOp("_XlaRecvAtHost"),
"_XlaRecvAtHost", opts.op_registry());
node_builder.Input(std::move(key_input));
return opts.WithAttr("Toutputs", dtypes)
.WithAttr("key", key)
.WithAttr("device_ordinal", 0)
+ .WithAttr("_encapsulate", cluster)
+ .WithAttr("_outside", oc_cluster)
.FinalizeBuilder(&node_builder);
}
-Node* SendFromHost(ops::NodeOut key_input, const string& key,
+Node* SendFromHost(ops::NodeOut key_input, const string& cluster,
+ const string& oc_cluster,
const std::vector<ops::NodeOut>& inputs,
const GraphDefBuilder::Options& opts) {
if (opts.HaveError()) return nullptr;
- NodeBuilder node_builder(opts.GetNameForOp("_XlaSendFromHost"),
+ string key =
+ strings::StrCat("host_compute_channel_", cluster, "_", oc_cluster);
+ string name = strings::StrCat("outside_compilation_", cluster, "_",
+ oc_cluster, "_send");
+ NodeBuilder node_builder(opts.WithName(name).GetNameForOp("_XlaSendFromHost"),
"_XlaSendFromHost", opts.op_registry());
node_builder.Input(inputs);
node_builder.Input(std::move(key_input));
return opts.WithAttr("Tinputs", dtypes)
.WithAttr("key", key)
.WithAttr("device_ordinal", 0)
+ .WithAttr("_encapsulate", cluster)
+ .WithAttr("_outside", oc_cluster)
.FinalizeBuilder(&node_builder);
}
GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
Node* key_constant =
KeyPlaceholderShape(shape.opts().WithName("KnownShape/_0"));
- Node* recv =
- RecvAtHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1",
- {DT_FLOAT, DT_FLOAT},
- shape.opts().WithName("outside_compilation_F1_O1_recv"));
+ Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
+ {DT_FLOAT, DT_FLOAT}, shape.opts());
Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
- shape.opts().WithName("E"));
- SendFromHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1",
- {e}, shape.opts().WithName("outside_compilation_F1_O1_send"));
+ shape.opts()
+ .WithName("E")
+ .WithAttr("_encapsulate", "F1")
+ .WithAttr("_outside", "O1"));
+ SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape.opts());
TF_EXPECT_OK(
AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected));
}
Node* key_constant =
KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
- Node* recv =
- RecvAtHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1",
- {DT_FLOAT, DT_FLOAT},
- b2.opts().WithName("outside_compilation_F1_O1_recv"));
+ Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
+ {DT_FLOAT, DT_FLOAT}, b2.opts());
Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
- b2.opts().WithName("E").WithControlInputs({recv, b}));
- Node* send = SendFromHost(ops::NodeOut(key_constant, 0),
- "host_compute_channel_F1_O1", {e},
- b2.opts()
- .WithName("outside_compilation_F1_O1_send")
- .WithControlInput(e));
+ b2.opts()
+ .WithName("E")
+ .WithControlInputs({recv, b})
+ .WithAttr("_encapsulate", "F1")
+ .WithAttr("_outside", "O1"));
+ Node* send = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e},
+ b2.opts().WithControlInput(e));
Node* s = Sequencer(
b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}),
GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
Node* key_constant =
KeyPlaceholderShape(shape1.opts().WithName("KnownShape/_0"));
- Node* recv =
- RecvAtHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1",
- {DT_FLOAT, DT_FLOAT},
- shape1.opts().WithName("outside_compilation_F1_O1_recv"));
+ Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
+ {DT_FLOAT, DT_FLOAT}, shape1.opts());
Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
- shape1.opts().WithName("E"));
- SendFromHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1",
- {e}, shape1.opts().WithName("outside_compilation_F1_O1_send"));
+ shape1.opts()
+ .WithName("E")
+ .WithAttr("_encapsulate", "F1")
+ .WithAttr("_outside", "O1"));
+ SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape1.opts());
TF_EXPECT_OK(
AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected));
}
GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately);
Node* key_constant =
KeyPlaceholderShape(shape2.opts().WithName("KnownShape/_0"));
- Node* recv1 =
- RecvAtHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1",
- {DT_FLOAT, DT_FLOAT},
- shape2.opts().WithName("outside_compilation_F1_O1_recv"));
+ Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
+ {DT_FLOAT, DT_FLOAT}, shape2.opts());
Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
- shape2.opts().WithName("E"));
- Node* recv2 =
- RecvAtHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O2",
- {DT_FLOAT, DT_FLOAT},
- shape2.opts().WithName("outside_compilation_F1_O2_recv"));
- Node* h = Binary(ops::NodeOut(recv2, 0), e, shape2.opts().WithName("H"));
- SendFromHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O2",
- {h}, shape2.opts().WithName("outside_compilation_F1_O2_send"));
+ shape2.opts()
+ .WithName("E")
+ .WithAttr("_encapsulate", "F1")
+ .WithAttr("_outside", "O1"));
+ Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2",
+ {DT_FLOAT, DT_FLOAT}, shape2.opts());
+ Node* h = Binary(ops::NodeOut(recv2, 0), e,
+ shape2.opts()
+ .WithName("H")
+ .WithAttr("_encapsulate", "F1")
+ .WithAttr("_outside", "O2"));
+ SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {h}, shape2.opts());
TF_EXPECT_OK(
AddGraphDefToFunctionLibrary(shape2, "F1_O2", &library_expected));
}
Node* key_constant =
KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
- Node* recv1 =
- RecvAtHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1",
- {DT_FLOAT, DT_FLOAT},
- b2.opts().WithName("outside_compilation_F1_O1_recv"));
+ Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
+ {DT_FLOAT, DT_FLOAT}, b2.opts());
Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
- b2.opts().WithName("E").WithControlInputs({recv1, b}));
- Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0),
- "host_compute_channel_F1_O1", {e},
- b2.opts()
- .WithName("outside_compilation_F1_O1_send")
- .WithControlInput(e));
-
- Node* recv2 =
- RecvAtHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O2",
- {DT_FLOAT, DT_FLOAT},
- b2.opts().WithName("outside_compilation_F1_O2_recv"));
+ b2.opts()
+ .WithName("E")
+ .WithControlInputs({recv1, b})
+ .WithAttr("_encapsulate", "F1")
+ .WithAttr("_outside", "O1"));
+ Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e},
+ b2.opts().WithControlInput(e));
+
+ Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2",
+ {DT_FLOAT, DT_FLOAT}, b2.opts());
Node* g = Binary(e, ops::NodeOut(recv2, 1),
- b2.opts().WithName("G").WithControlInputs({recv2, e}));
- Node* h = Binary(ops::NodeOut(recv2, 0), e, b2.opts().WithName("H"));
- Node* send2 = SendFromHost(
- ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O2", {h},
- b2.opts().WithName("outside_compilation_F1_O2_send"));
+ b2.opts()
+ .WithName("G")
+ .WithControlInputs({recv2, e})
+ .WithAttr("_encapsulate", "F1")
+ .WithAttr("_outside", "O2"));
+ Node* h = Binary(ops::NodeOut(recv2, 0), e,
+ b2.opts()
+ .WithName("H")
+ .WithAttr("_encapsulate", "F1")
+ .WithAttr("_outside", "O2"));
+ Node* send2 =
+ SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {h}, b2.opts());
Node* s = Sequencer(b2.opts()
.WithName("F1_sequencer")
GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
Node* key_constant =
KeyPlaceholderShape(shape.opts().WithName("KnownShape/_0"));
- Node* recv =
- RecvAtHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1",
- {DT_FLOAT, DT_FLOAT},
- shape.opts().WithName("outside_compilation_F1_O1_recv"));
+ Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
+ {DT_FLOAT, DT_FLOAT}, shape.opts());
Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
- shape.opts().WithName("E"));
- SendFromHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1",
- {e}, shape.opts().WithName("outside_compilation_F1_O1_send"));
+ shape.opts()
+ .WithName("E")
+ .WithAttr("_encapsulate", "F1")
+ .WithAttr("_outside", "O1"));
+ SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape.opts());
TF_EXPECT_OK(
AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected));
}
Node* key_constant1 =
KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
- Node* recv1 =
- RecvAtHost(ops::NodeOut(key_constant1, 0), "host_compute_channel_F1_O1",
- {DT_FLOAT, DT_FLOAT},
- b2.opts().WithName("outside_compilation_F1_O1_recv"));
+ Node* recv1 = RecvAtHost(ops::NodeOut(key_constant1, 0), "F1", "O1",
+ {DT_FLOAT, DT_FLOAT}, b2.opts());
Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
- b2.opts().WithName("E").WithControlInputs({recv1, b}));
- Node* send1 = SendFromHost(ops::NodeOut(key_constant1, 0),
- "host_compute_channel_F1_O1", {e},
- b2.opts()
- .WithName("outside_compilation_F1_O1_send")
- .WithControlInput(e));
+ b2.opts()
+ .WithName("E")
+ .WithControlInputs({recv1, b})
+ .WithAttr("_encapsulate", "F1")
+ .WithAttr("_outside", "O1"));
+ Node* send1 = SendFromHost(ops::NodeOut(key_constant1, 0), "F1", "O1", {e},
+ b2.opts().WithControlInput(e));
Node* s1 = Sequencer(
b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}),
"F1");
Node* key_constant2 =
KeyPlaceholder("F2", b2.opts().WithName("F2_key_placeholder"));
- Node* recv2 = RecvAtHost(
- ops::NodeOut(key_constant2, 0), "host_compute_channel_F2_O1",
- {DT_FLOAT}, b2.opts().WithName("outside_compilation_F2_O1_recv"));
- Node* h = Binary(ops::NodeOut(call1, 1), recv2, b2.opts().WithName("H"));
- Node* send2 = SendFromHost(
- ops::NodeOut(key_constant2, 0), "host_compute_channel_F2_O1", {h},
- b2.opts().WithName("outside_compilation_F2_O1_send"));
+ Node* recv2 = RecvAtHost(ops::NodeOut(key_constant2, 0), "F2", "O1",
+ {DT_FLOAT}, b2.opts());
+ Node* h = Binary(ops::NodeOut(call1, 1), recv2,
+ b2.opts()
+ .WithName("H")
+ .WithAttr("_encapsulate", "F2")
+ .WithAttr("_outside", "O1"));
+ Node* send2 = SendFromHost(ops::NodeOut(key_constant2, 0), "F2", "O1", {h},
+ b2.opts());
Node* s2 = Sequencer(
b2.opts().WithName("F2_sequencer").WithControlInputs({recv2, send2}),
Node* a = InputShaped(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
- Node* e = Unary(a, b2.opts().WithName("E"));
+ Node* e = Unary(a, b2.opts()
+ .WithName("E")
+ .WithAttr("_encapsulate", "F1")
+ .WithAttr("_outside", "O1"));
Node* key_constant =
KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
- Node* send1 = SendFromHost(
- ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", {e},
- b2.opts().WithName("outside_compilation_F1_O1_send"));
+ Node* send1 =
+ SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts());
Node* s1 = Sequencer(
b2.opts().WithName("F1_sequencer").WithControlInput(send1), "F1");
NodeBuilder node_builder1("F1", "F1", lib_def.get());
Node* key_constant =
KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
Node* recv1 =
- RecvAtHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1",
- {}, b2.opts().WithName("outside_compilation_F1_O1_recv"));
- Node* e = Unary(a, b2.opts().WithName("E").WithControlInput(recv1));
- Node* send1 = SendFromHost(
- ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", {e},
- b2.opts().WithName("outside_compilation_F1_O1_send"));
+ RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {}, b2.opts());
+ Node* e = Unary(a, b2.opts()
+ .WithName("E")
+ .WithControlInput(recv1)
+ .WithAttr("_encapsulate", "F1")
+ .WithAttr("_outside", "O1"));
+ Node* send1 =
+ SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts());
Node* s1 = Sequencer(
b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}),
"F1");
Node* key_constant =
KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
- Node* recv1 = RecvAtHost(
- ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", {DT_FLOAT},
- b2.opts().WithName("outside_compilation_F1_O1_recv"));
- Node* e = Unary(recv1, b2.opts().WithName("E"));
+ Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
+ {DT_FLOAT}, b2.opts());
+ Node* e = Unary(recv1, b2.opts()
+ .WithName("E")
+ .WithAttr("_encapsulate", "F1")
+ .WithAttr("_outside", "O1"));
Node* s1 = Sequencer(
b2.opts().WithName("F1_sequencer").WithControlInput(recv1), "F1");
NodeBuilder node_builder1("F1", "F1", lib_def.get());
Node* key_constant =
KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
- Node* recv1 = RecvAtHost(
- ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", {DT_FLOAT},
- b2.opts().WithName("outside_compilation_F1_O1_recv"));
- Node* e = Unary(recv1, b2.opts().WithName("E"));
- Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0),
- "host_compute_channel_F1_O1", {},
- b2.opts()
- .WithName("outside_compilation_F1_O1_send")
- .WithControlInput(e));
+ Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
+ {DT_FLOAT}, b2.opts());
+ Node* e = Unary(recv1, b2.opts()
+ .WithName("E")
+ .WithAttr("_encapsulate", "F1")
+ .WithAttr("_outside", "O1"));
+ Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {},
+ b2.opts().WithControlInput(e));
Node* s1 = Sequencer(
b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}),
"F1");
Node* a = Input(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
- Node* e = Unary(a, b2.opts().WithName("E"));
+ Node* e = Unary(a, b2.opts()
+ .WithName("E")
+ .WithAttr("_encapsulate", "F1")
+ .WithAttr("_outside", "O1"));
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b);
Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
Node* key_constant =
KeyPlaceholderShape(shape.opts().WithName("KnownShape/_0"));
Node* known = KnownShape({2}, shape.opts().WithName("KnownShape/_1"));
- Node* recv = RecvAtHost(
- ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", {DT_FLOAT},
- shape.opts().WithName("outside_compilation_F1_O1_recv"));
- Node* e = BinaryUnknownShape(known, recv, shape.opts().WithName("E"));
- SendFromHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1",
- {e}, shape.opts().WithName("outside_compilation_F1_O1_send"));
+ Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
+ {DT_FLOAT}, shape.opts());
+ Node* e = BinaryUnknownShape(known, recv,
+ shape.opts()
+ .WithName("E")
+ .WithAttr("_encapsulate", "F1")
+ .WithAttr("_outside", "O1"));
+ SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape.opts());
TF_EXPECT_OK(
AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected));
}
Node* key_constant =
KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
- Node* recv = RecvAtHost(
- ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", {DT_FLOAT},
- b2.opts().WithName("outside_compilation_F1_O1_recv"));
- Node* e = BinaryUnknownShape(
- c, ops::NodeOut(recv, 0),
- b2.opts().WithName("E").WithControlInputs({recv, b}));
- Node* send = SendFromHost(ops::NodeOut(key_constant, 0),
- "host_compute_channel_F1_O1", {e},
- b2.opts()
- .WithName("outside_compilation_F1_O1_send")
- .WithControlInput(e));
+ Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
+ {DT_FLOAT}, b2.opts());
+ Node* e = BinaryUnknownShape(c, ops::NodeOut(recv, 0),
+ b2.opts()
+ .WithName("E")
+ .WithControlInputs({recv, b})
+ .WithAttr("_encapsulate", "F1")
+ .WithAttr("_outside", "O1"));
+ Node* send = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e},
+ b2.opts().WithControlInput(e));
Node* s = Sequencer(
b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}),