void ConnectSequencerToCallNode(Graph* graph_out);
Status AddShapeInferenceInfo(
+ const string& subgraph_name,
const string& outside_compilation_subgraph_name,
- const std::vector<TensorShapeProto>& shapes, GraphDef* inference_graph);
+ const std::vector<TensorShapeProto>& shapes, Graph* inference_graph,
+ FunctionLibraryDefinition* library);
Status ReplaceFunctionDef(FunctionLibraryDefinition* library);
const std::unordered_set<string>& recv_at_host_nodes, Node* send_node,
FunctionLibraryDefinition* library,
std::vector<TensorShapeProto>* static_shape_out,
- std::unique_ptr<GraphDef>* graphdef_out);
+ std::unique_ptr<Graph>* graph_out);
// Makes a copy of graph containing only nodes that are ancestors of at least
// one node in send_from_host_nodes and store it in pruned_graph. On exit
}
Status Encapsulator::Subgraph::AddShapeInferenceInfo(
+ const string& subgraph_name,
const string& outside_compilation_subgraph_name,
- const std::vector<TensorShapeProto>& shapes, GraphDef* inference_graph) {
+ const std::vector<TensorShapeProto>& shapes, Graph* inference_graph,
+ FunctionLibraryDefinition* library) {
OutsideCompilationSubgraph& oc_subgraph =
outside_compilation_subgraphs_.at(outside_compilation_subgraph_name);
host_compute->AddAttr("shape_inference_graph", "");
host_compute->AddAttr("shapes", shapes);
} else {
- string serialized_graph;
- if (!inference_graph->SerializeToString(&serialized_graph)) {
- return errors::Internal(
- "Failed to serialize graph for outside compilation subgraph ",
- oc_subgraph.host_compute_name);
- }
- host_compute->AddAttr("shape_inference_graph", serialized_graph);
+ string inference_graph_name =
+ strings::StrCat("_outside_compilation_shape_inference_", subgraph_name,
+ "_", outside_compilation_subgraph_name);
+ FunctionDef fdef;
+ TF_RETURN_IF_ERROR(
+ GraphToFunctionDef(*inference_graph, inference_graph_name, &fdef));
+ host_compute->AddAttr("shape_inference_graph", inference_graph_name);
host_compute->AddAttr("shapes", std::vector<TensorShapeProto>());
+ TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef));
}
return Status::OK();
}
const std::unordered_set<string>& recv_at_host_nodes, Node* send_node,
FunctionLibraryDefinition* library,
std::vector<TensorShapeProto>* static_shape_out,
- std::unique_ptr<GraphDef>* graphdef_out) {
+ std::unique_ptr<Graph>* graph_out) {
// Maps from nodes in graph_in to nodes in graph_out.
//
// When an edge has fully defined shape the source node in graph_in is
std::unordered_map<Node*, Node*> dummy_node_images;
std::unordered_map<Node*, Node*> copied_node_images;
- std::unique_ptr<Graph> graph_out(new Graph(graph_in.op_registry()));
- graph_out->set_versions(graph_in.versions());
+ graph_out->reset(new Graph(graph_in.op_registry()));
+ (*graph_out)->set_versions(graph_in.versions());
// The final input to the send node is the dynamic key, which we don't include
// in the static shapes.
static_shape_out->resize(send_node->num_inputs() - 1);
if (w.leave) {
TF_RETURN_IF_ERROR(CopyShapeInferenceNodeToGraph(
n, send_node, dummy_node_images, library, &copied_node_images,
- graph_out.get()));
+ graph_out->get()));
} else {
if (visited[n->id()]) continue;
visited[n->id()] = true;
context->ShapeHandleToProto(shape, &proto);
if (dummy_node_images.find(src_node) == dummy_node_images.end()) {
dummy_node_images[src_node] = AddDummyShapedNode(
- src_node->output_type(src_port), proto, graph_out.get());
+ src_node->output_type(src_port), proto, graph_out->get());
}
// The final input to the send node is the dynamic key, which we
// don't include in the static shapes.
// The shapes of all the inputs to send_node are statically known. We
// won't have to do any inference at compile time so return now: the
// shapes were stored in static_shape_out above.
- graphdef_out->reset();
+ graph_out->reset();
return Status::OK();
} else {
// Any shape that is being processed is either the original send node
}
}
- graphdef_out->reset(new GraphDef());
- graph_out->ToGraphDef(graphdef_out->get());
-
return Status::OK();
}
}
for (auto& subgraph_entry : subgraphs_) {
+ const string& subgraph_name = subgraph_entry.first;
Subgraph& subgraph = subgraph_entry.second;
// Find all the recv_at_host nodes in this subgraph.
std::vector<string> outside_compilation_names;
subgraph.GetOutsideCompilationSubgraphNames(&outside_compilation_names);
std::unordered_set<string> recv_at_host_names;
- for (const auto& name : outside_compilation_names) {
- Node* recv_node = subgraph.GetRecvAtHostNode(name);
+ for (const auto& oc_name : outside_compilation_names) {
+ Node* recv_node = subgraph.GetRecvAtHostNode(oc_name);
if (recv_node != nullptr) {
recv_at_host_names.insert(recv_node->name());
}
// without knowing the shape of the recv_at_host nodes, and store the
// result, along with enough information to complete the job at compile time
// once the recv_at_host shapes are known.
- for (const auto& name : outside_compilation_names) {
- Node* send_node = subgraph.GetSendFromHostNode(name);
+ for (const auto& oc_name : outside_compilation_names) {
+ Node* send_node = subgraph.GetSendFromHostNode(oc_name);
std::vector<TensorShapeProto> static_shape;
- std::unique_ptr<GraphDef> graphdef;
+ std::unique_ptr<Graph> graph;
if (send_node != nullptr) {
TF_RETURN_IF_ERROR(DoStaticShapeInferenceForOutsideCompilationSend(
*pruned_graph, shape_refiner, recv_at_host_names,
- node_images[send_node], library, &static_shape, &graphdef));
- if (graphdef == nullptr) {
+ node_images[send_node], library, &static_shape, &graph));
+ if (graph == nullptr) {
VLOG(2) << "Send node " << send_node->name() << " shapes";
for (int i = 0; i < static_shape.size(); ++i) {
VLOG(2) << static_shape[i].DebugString();
}
} else {
- VLOG(2) << "Send node " << send_node->name() << " graph\n"
- << graphdef->DebugString();
+ if (VLOG_IS_ON(2)) {
+ GraphDef graphdef;
+ graph->ToGraphDef(&graphdef);
+ VLOG(2) << "Send node " << send_node->name() << " graph\n"
+ << graphdef.DebugString();
+ }
}
}
- TF_RETURN_IF_ERROR(
- subgraph.AddShapeInferenceInfo(name, static_shape, graphdef.get()));
+ TF_RETURN_IF_ERROR(subgraph.AddShapeInferenceInfo(
+ subgraph_name, oc_name, static_shape, graph.get(), library));
}
if (!outside_compilation_names.empty()) {
TF_RETURN_IF_ERROR(subgraph.ReplaceFunctionDef(library));
limitations under the License.
==============================================================================*/
+#include <memory>
#include <utility>
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/compiler/jit/graph_to_functiondef.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
const char* const kXlaHostTransferSequencerAttr =
"_xla_host_transfer_sequencer";
+// Converts the graph held by `graphdef_builder` into a FunctionDef named
+// "_outside_compilation_shape_inference_<name_suffix>" and appends it to
+// `library`. The tests below use this to build the expected shape-inference
+// functions that the pass registers for outside_compilation clusters.
+//
+// Returns a non-OK status if the builder's graph cannot be converted to a
+// Graph or if the Graph cannot be converted to a FunctionDef.
+Status AddGraphDefToFunctionLibrary(const GraphDefBuilder& graphdef_builder,
+                                    const string& name_suffix,
+                                    FunctionDefLibrary* library) {
+  GraphDef graphdef;
+  TF_RETURN_IF_ERROR(graphdef_builder.ToGraphDef(&graphdef));
+  // Round-trip through a Graph because GraphToFunctionDef consumes a Graph,
+  // not a GraphDef. Direct-initialize the unique_ptr instead of assigning
+  // from a temporary.
+  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+  GraphConstructorOptions opts;
+  // NOTE(review): allow_internal_ops is set, presumably because the test
+  // graphs contain internal ops (e.g. _XlaRecvAtHost/_XlaSendFromHost) --
+  // confirm against the builders that call this helper.
+  opts.allow_internal_ops = true;
+  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, graphdef, graph.get()));
+  FunctionDef* fdef = library->add_function();
+  TF_RETURN_IF_ERROR(GraphToFunctionDef(
+      *graph,
+      strings::StrCat("_outside_compilation_shape_inference_", name_suffix),
+      fdef));
+  return Status::OK();
+}
+
template <class Tkey, class Tvalue>
bool EqualProtoMap(const ::tensorflow::protobuf::Map<Tkey, Tvalue>& a,
const ::tensorflow::protobuf::Map<Tkey, Tvalue>& b,
a.attr(), b.attr(), [](const string& s) { return s; },
[](const AttrValue& v) { return v.DebugString(); },
[](const string& key, const AttrValue& av, const AttrValue& bv) {
- if (key == "shape_inference_graph") {
- // Default serialization of GraphDef is unstable because maps don't
- // serialize deterministically. Rather than go through the hoops to
- // turn on deterministic serialization of this attr just for this
- // test, add logic here to compare determinstically.
- GraphDef ga;
- if (!ga.ParseFromString(av.s())) {
- return false;
- }
- GraphDef gb;
- if (!gb.ParseFromString(bv.s())) {
- return false;
- }
- return EqualGraphDef(ga, gb, nullptr);
- } else {
- return av.DebugString() == bv.DebugString();
- }
+ return av.DebugString() == bv.DebugString();
},
strings::StrCat(diff_preamble, " attr mismatch for node ", a.name()),
diff);
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
- string shape_string_expected;
{
GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
Node* key_constant =
shape.opts().WithName("E"));
SendFromHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1",
{e}, shape.opts().WithName("outside_compilation_F1_O1_send"));
- GraphDef shape_graph;
- TF_EXPECT_OK(shape.ToGraphDef(&shape_graph));
- EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected));
+ TF_EXPECT_OK(
+ AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected));
}
*library_expected.add_function() = test::function::XTimesTwo();
{{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
{"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"key", "host_compute_channel_F1_O1"},
- {"shape_inference_graph", shape_string_expected},
+ {"shape_inference_graph",
+ "_outside_compilation_shape_inference_F1_O1"},
{"shapes", gtl::ArraySlice<DataType>({})}},
{"c"}},
},
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
- string shape_string_expected_1;
{
GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
Node* key_constant =
shape1.opts().WithName("E"));
SendFromHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1",
{e}, shape1.opts().WithName("outside_compilation_F1_O1_send"));
- GraphDef shape1_graph;
- TF_EXPECT_OK(shape1.ToGraphDef(&shape1_graph));
- EXPECT_TRUE(shape1_graph.SerializeToString(&shape_string_expected_1));
+ TF_EXPECT_OK(
+ AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected));
}
- string shape_string_expected_2;
{
GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately);
Node* key_constant =
Node* h = Binary(ops::NodeOut(recv2, 0), e, shape2.opts().WithName("H"));
SendFromHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O2",
{h}, shape2.opts().WithName("outside_compilation_F1_O2_send"));
- GraphDef shape2_graph;
- TF_EXPECT_OK(shape2.ToGraphDef(&shape2_graph));
- EXPECT_TRUE(shape2_graph.SerializeToString(&shape_string_expected_2));
+ TF_EXPECT_OK(
+ AddGraphDefToFunctionLibrary(shape2, "F1_O2", &library_expected));
}
*library_expected.add_function() = FunctionDefHelper::Create(
{{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
{"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"key", "host_compute_channel_F1_O2"},
- {"shape_inference_graph", shape_string_expected_2},
+ {"shape_inference_graph",
+ "_outside_compilation_shape_inference_F1_O2"},
{"shapes", gtl::ArraySlice<DataType>({})}},
{"F"}},
{{"outside_compilation_O1_host_compute"},
{{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
{"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"key", "host_compute_channel_F1_O1"},
- {"shape_inference_graph", shape_string_expected_1},
+ {"shape_inference_graph",
+ "_outside_compilation_shape_inference_F1_O1"},
{"shapes", gtl::ArraySlice<DataType>({})}},
{"D"}},
},
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
- string shape_string_expected;
{
GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
Node* key_constant =
shape.opts().WithName("E"));
SendFromHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1",
{e}, shape.opts().WithName("outside_compilation_F1_O1_send"));
- GraphDef shape_graph;
- TF_EXPECT_OK(shape.ToGraphDef(&shape_graph));
- EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected));
+ TF_EXPECT_OK(
+ AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected));
}
TensorShapeProto shape_proto_expected;
{{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
{"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"key", "host_compute_channel_F1_O1"},
- {"shape_inference_graph", shape_string_expected},
+ {"shape_inference_graph",
+ "_outside_compilation_shape_inference_F1_O1"},
{"shapes", gtl::ArraySlice<DataType>({})}},
{"D"}},
},
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
- string shape_string_expected;
{
GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
Node* key_constant =
Node* e = BinaryUnknownShape(known, recv, shape.opts().WithName("E"));
SendFromHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1",
{e}, shape.opts().WithName("outside_compilation_F1_O1_send"));
- GraphDef shape_graph;
- TF_EXPECT_OK(shape.ToGraphDef(&shape_graph));
- EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected));
+ TF_EXPECT_OK(
+ AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected));
}
*library_expected.add_function() = test::function::XTimesTwo();
{{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"key", "host_compute_channel_F1_O1"},
- {"shape_inference_graph", shape_string_expected},
+ {"shape_inference_graph",
+ "_outside_compilation_shape_inference_F1_O1"},
{"shapes", gtl::ArraySlice<DataType>({})}},
{"c"}},
},