Code cleanup: rather than storing the outside_compilation shape inference graph as...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 22 Mar 2018 20:41:13 +0000 (13:41 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 22 Mar 2018 20:43:21 +0000 (13:43 -0700)
PiperOrigin-RevId: 190118116

tensorflow/compiler/jit/BUILD
tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc

index 0475cd9..8e505da 100644 (file)
@@ -348,6 +348,7 @@ tf_cc_test(
     deps = [
         ":common",
         ":compilation_passes",
+        ":graph_to_functiondef",
         "//tensorflow/cc:cc_ops",
         "//tensorflow/cc:cc_ops_internal",
         "//tensorflow/cc:function_ops",
index 0685036..7fc43fb 100644 (file)
@@ -334,8 +334,10 @@ class Encapsulator {
     void ConnectSequencerToCallNode(Graph* graph_out);
 
     Status AddShapeInferenceInfo(
+        const string& subgraph_name,
         const string& outside_compilation_subgraph_name,
-        const std::vector<TensorShapeProto>& shapes, GraphDef* inference_graph);
+        const std::vector<TensorShapeProto>& shapes, Graph* inference_graph,
+        FunctionLibraryDefinition* library);
 
     Status ReplaceFunctionDef(FunctionLibraryDefinition* library);
 
@@ -573,7 +575,7 @@ class Encapsulator {
       const std::unordered_set<string>& recv_at_host_nodes, Node* send_node,
       FunctionLibraryDefinition* library,
       std::vector<TensorShapeProto>* static_shape_out,
-      std::unique_ptr<GraphDef>* graphdef_out);
+      std::unique_ptr<Graph>* graph_out);
 
   // Makes a copy of graph containing only nodes that are ancestors of at least
   // one node in send_from_host_nodes and store it in pruned_graph. On exit
@@ -949,8 +951,10 @@ Status Encapsulator::Subgraph::BuildFunctionDef(
 }
 
 Status Encapsulator::Subgraph::AddShapeInferenceInfo(
+    const string& subgraph_name,
     const string& outside_compilation_subgraph_name,
-    const std::vector<TensorShapeProto>& shapes, GraphDef* inference_graph) {
+    const std::vector<TensorShapeProto>& shapes, Graph* inference_graph,
+    FunctionLibraryDefinition* library) {
   OutsideCompilationSubgraph& oc_subgraph =
       outside_compilation_subgraphs_.at(outside_compilation_subgraph_name);
 
@@ -972,14 +976,15 @@ Status Encapsulator::Subgraph::AddShapeInferenceInfo(
     host_compute->AddAttr("shape_inference_graph", "");
     host_compute->AddAttr("shapes", shapes);
   } else {
-    string serialized_graph;
-    if (!inference_graph->SerializeToString(&serialized_graph)) {
-      return errors::Internal(
-          "Failed to serialize graph for outside compilation subgraph ",
-          oc_subgraph.host_compute_name);
-    }
-    host_compute->AddAttr("shape_inference_graph", serialized_graph);
+    string inference_graph_name =
+        strings::StrCat("_outside_compilation_shape_inference_", subgraph_name,
+                        "_", outside_compilation_subgraph_name);
+    FunctionDef fdef;
+    TF_RETURN_IF_ERROR(
+        GraphToFunctionDef(*inference_graph, inference_graph_name, &fdef));
+    host_compute->AddAttr("shape_inference_graph", inference_graph_name);
     host_compute->AddAttr("shapes", std::vector<TensorShapeProto>());
+    TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef));
   }
   return Status::OK();
 }
@@ -1760,7 +1765,7 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend(
     const std::unordered_set<string>& recv_at_host_nodes, Node* send_node,
     FunctionLibraryDefinition* library,
     std::vector<TensorShapeProto>* static_shape_out,
-    std::unique_ptr<GraphDef>* graphdef_out) {
+    std::unique_ptr<Graph>* graph_out) {
   // Maps from nodes in graph_in to nodes in graph_out.
   //
   // When an edge has fully defined shape the source node in graph_in is
@@ -1777,8 +1782,8 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend(
   std::unordered_map<Node*, Node*> dummy_node_images;
   std::unordered_map<Node*, Node*> copied_node_images;
 
-  std::unique_ptr<Graph> graph_out(new Graph(graph_in.op_registry()));
-  graph_out->set_versions(graph_in.versions());
+  graph_out->reset(new Graph(graph_in.op_registry()));
+  (*graph_out)->set_versions(graph_in.versions());
   // The final input to the send node is the dynamic key, which we don't include
   // in the static shapes.
   static_shape_out->resize(send_node->num_inputs() - 1);
@@ -1800,7 +1805,7 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend(
     if (w.leave) {
       TF_RETURN_IF_ERROR(CopyShapeInferenceNodeToGraph(
           n, send_node, dummy_node_images, library, &copied_node_images,
-          graph_out.get()));
+          graph_out->get()));
     } else {
       if (visited[n->id()]) continue;
       visited[n->id()] = true;
@@ -1824,7 +1829,7 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend(
             context->ShapeHandleToProto(shape, &proto);
             if (dummy_node_images.find(src_node) == dummy_node_images.end()) {
               dummy_node_images[src_node] = AddDummyShapedNode(
-                  src_node->output_type(src_port), proto, graph_out.get());
+                  src_node->output_type(src_port), proto, graph_out->get());
             }
             // The final input to the send node is the dynamic key, which we
             // don't include in the static shapes.
@@ -1849,7 +1854,7 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend(
           // The shapes of all the inputs to send_node are statically known. We
           // won't have to do any inference at compile time so return now: the
           // shapes were stored in static_shape_out above.
-          graphdef_out->reset();
+          graph_out->reset();
           return Status::OK();
         } else {
           // Any shape that is being processed is either the original send node
@@ -1872,9 +1877,6 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend(
     }
   }
 
-  graphdef_out->reset(new GraphDef());
-  graph_out->ToGraphDef(graphdef_out->get());
-
   return Status::OK();
 }
 
@@ -1997,13 +1999,14 @@ Status Encapsulator::GetShapeInfoForOutsideCompilationSends(
   }
 
   for (auto& subgraph_entry : subgraphs_) {
+    const string& subgraph_name = subgraph_entry.first;
     Subgraph& subgraph = subgraph_entry.second;
     // Find all the recv_at_host nodes in this subgraph.
     std::vector<string> outside_compilation_names;
     subgraph.GetOutsideCompilationSubgraphNames(&outside_compilation_names);
     std::unordered_set<string> recv_at_host_names;
-    for (const auto& name : outside_compilation_names) {
-      Node* recv_node = subgraph.GetRecvAtHostNode(name);
+    for (const auto& oc_name : outside_compilation_names) {
+      Node* recv_node = subgraph.GetRecvAtHostNode(oc_name);
       if (recv_node != nullptr) {
         recv_at_host_names.insert(recv_node->name());
       }
@@ -2012,26 +2015,30 @@ Status Encapsulator::GetShapeInfoForOutsideCompilationSends(
     // without knowing the shape of the recv_at_host nodes, and store the
     // result, along with enough information to complete the job at compile time
     // once the recv_at_host shapes are known.
-    for (const auto& name : outside_compilation_names) {
-      Node* send_node = subgraph.GetSendFromHostNode(name);
+    for (const auto& oc_name : outside_compilation_names) {
+      Node* send_node = subgraph.GetSendFromHostNode(oc_name);
       std::vector<TensorShapeProto> static_shape;
-      std::unique_ptr<GraphDef> graphdef;
+      std::unique_ptr<Graph> graph;
       if (send_node != nullptr) {
         TF_RETURN_IF_ERROR(DoStaticShapeInferenceForOutsideCompilationSend(
             *pruned_graph, shape_refiner, recv_at_host_names,
-            node_images[send_node], library, &static_shape, &graphdef));
-        if (graphdef == nullptr) {
+            node_images[send_node], library, &static_shape, &graph));
+        if (graph == nullptr) {
           VLOG(2) << "Send node  " << send_node->name() << " shapes";
           for (int i = 0; i < static_shape.size(); ++i) {
             VLOG(2) << static_shape[i].DebugString();
           }
         } else {
-          VLOG(2) << "Send node " << send_node->name() << " graph\n"
-                  << graphdef->DebugString();
+          if (VLOG_IS_ON(2)) {
+            GraphDef graphdef;
+            graph->ToGraphDef(&graphdef);
+            VLOG(2) << "Send node " << send_node->name() << " graph\n"
+                    << graphdef.DebugString();
+          }
         }
       }
-      TF_RETURN_IF_ERROR(
-          subgraph.AddShapeInferenceInfo(name, static_shape, graphdef.get()));
+      TF_RETURN_IF_ERROR(subgraph.AddShapeInferenceInfo(
+          subgraph_name, oc_name, static_shape, graph.get(), library));
     }
     if (!outside_compilation_names.empty()) {
       TF_RETURN_IF_ERROR(subgraph.ReplaceFunctionDef(library));
index 711b142..94481a1 100644 (file)
@@ -13,12 +13,14 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include <memory>
 #include <utility>
 
 #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
 
 #include "tensorflow/cc/framework/ops.h"
 #include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/compiler/jit/graph_to_functiondef.h"
 #include "tensorflow/core/framework/function_testlib.h"
 #include "tensorflow/core/graph/graph_constructor.h"
 #include "tensorflow/core/graph/graph_def_builder.h"
@@ -32,6 +34,24 @@ namespace {
 const char* const kXlaHostTransferSequencerAttr =
     "_xla_host_transfer_sequencer";
 
+Status AddGraphDefToFunctionLibrary(const GraphDefBuilder& graphdef_builder,
+                                    const string& name_suffix,
+                                    FunctionDefLibrary* library) {
+  GraphDef graphdef;
+  TF_RETURN_IF_ERROR(graphdef_builder.ToGraphDef(&graphdef));
+  std::unique_ptr<Graph> graph =
+      std::unique_ptr<Graph>(new Graph(OpRegistry::Global()));
+  GraphConstructorOptions opts;
+  opts.allow_internal_ops = true;
+  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, graphdef, graph.get()));
+  FunctionDef* fdef = library->add_function();
+  TF_RETURN_IF_ERROR(GraphToFunctionDef(
+      *graph,
+      strings::StrCat("_outside_compilation_shape_inference_", name_suffix),
+      fdef));
+  return Status::OK();
+}
+
 template <class Tkey, class Tvalue>
 bool EqualProtoMap(const ::tensorflow::protobuf::Map<Tkey, Tvalue>& a,
                    const ::tensorflow::protobuf::Map<Tkey, Tvalue>& b,
@@ -115,23 +135,7 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b,
       a.attr(), b.attr(), [](const string& s) { return s; },
       [](const AttrValue& v) { return v.DebugString(); },
       [](const string& key, const AttrValue& av, const AttrValue& bv) {
-        if (key == "shape_inference_graph") {
-          // Default serialization of GraphDef is unstable because maps don't
-          // serialize deterministically. Rather than go through the hoops to
-          // turn on deterministic serialization of this attr just for this
-          // test, add logic here to compare determinstically.
-          GraphDef ga;
-          if (!ga.ParseFromString(av.s())) {
-            return false;
-          }
-          GraphDef gb;
-          if (!gb.ParseFromString(bv.s())) {
-            return false;
-          }
-          return EqualGraphDef(ga, gb, nullptr);
-        } else {
-          return av.DebugString() == bv.DebugString();
-        }
+        return av.DebugString() == bv.DebugString();
       },
       strings::StrCat(diff_preamble, " attr mismatch for node ", a.name()),
       diff);
@@ -848,7 +852,6 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) {
   FunctionDefLibrary library_expected;
   GraphDef graphdef_expected;
 
-  string shape_string_expected;
   {
     GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
     Node* key_constant =
@@ -861,9 +864,8 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) {
                      shape.opts().WithName("E"));
     SendFromHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1",
                  {e}, shape.opts().WithName("outside_compilation_F1_O1_send"));
-    GraphDef shape_graph;
-    TF_EXPECT_OK(shape.ToGraphDef(&shape_graph));
-    EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected));
+    TF_EXPECT_OK(
+        AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected));
   }
 
   *library_expected.add_function() = test::function::XTimesTwo();
@@ -883,7 +885,8 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) {
            {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
             {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
             {"key", "host_compute_channel_F1_O1"},
-            {"shape_inference_graph", shape_string_expected},
+            {"shape_inference_graph",
+             "_outside_compilation_shape_inference_F1_O1"},
             {"shapes", gtl::ArraySlice<DataType>({})}},
            {"c"}},
       },
@@ -969,7 +972,6 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
   FunctionDefLibrary library_expected;
   GraphDef graphdef_expected;
 
-  string shape_string_expected_1;
   {
     GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
     Node* key_constant =
@@ -982,12 +984,10 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
                      shape1.opts().WithName("E"));
     SendFromHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1",
                  {e}, shape1.opts().WithName("outside_compilation_F1_O1_send"));
-    GraphDef shape1_graph;
-    TF_EXPECT_OK(shape1.ToGraphDef(&shape1_graph));
-    EXPECT_TRUE(shape1_graph.SerializeToString(&shape_string_expected_1));
+    TF_EXPECT_OK(
+        AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected));
   }
 
-  string shape_string_expected_2;
   {
     GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately);
     Node* key_constant =
@@ -1005,9 +1005,8 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
     Node* h = Binary(ops::NodeOut(recv2, 0), e, shape2.opts().WithName("H"));
     SendFromHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O2",
                  {h}, shape2.opts().WithName("outside_compilation_F1_O2_send"));
-    GraphDef shape2_graph;
-    TF_EXPECT_OK(shape2.ToGraphDef(&shape2_graph));
-    EXPECT_TRUE(shape2_graph.SerializeToString(&shape_string_expected_2));
+    TF_EXPECT_OK(
+        AddGraphDefToFunctionLibrary(shape2, "F1_O2", &library_expected));
   }
 
   *library_expected.add_function() = FunctionDefHelper::Create(
@@ -1029,7 +1028,8 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
            {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
             {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
             {"key", "host_compute_channel_F1_O2"},
-            {"shape_inference_graph", shape_string_expected_2},
+            {"shape_inference_graph",
+             "_outside_compilation_shape_inference_F1_O2"},
             {"shapes", gtl::ArraySlice<DataType>({})}},
            {"F"}},
           {{"outside_compilation_O1_host_compute"},
@@ -1038,7 +1038,8 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
            {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
             {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
             {"key", "host_compute_channel_F1_O1"},
-            {"shape_inference_graph", shape_string_expected_1},
+            {"shape_inference_graph",
+             "_outside_compilation_shape_inference_F1_O1"},
             {"shapes", gtl::ArraySlice<DataType>({})}},
            {"D"}},
       },
@@ -1134,7 +1135,6 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
   FunctionDefLibrary library_expected;
   GraphDef graphdef_expected;
 
-  string shape_string_expected;
   {
     GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
     Node* key_constant =
@@ -1147,9 +1147,8 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
                      shape.opts().WithName("E"));
     SendFromHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1",
                  {e}, shape.opts().WithName("outside_compilation_F1_O1_send"));
-    GraphDef shape_graph;
-    TF_EXPECT_OK(shape.ToGraphDef(&shape_graph));
-    EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected));
+    TF_EXPECT_OK(
+        AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected));
   }
 
   TensorShapeProto shape_proto_expected;
@@ -1172,7 +1171,8 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
            {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
             {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
             {"key", "host_compute_channel_F1_O1"},
-            {"shape_inference_graph", shape_string_expected},
+            {"shape_inference_graph",
+             "_outside_compilation_shape_inference_F1_O1"},
             {"shapes", gtl::ArraySlice<DataType>({})}},
            {"D"}},
       },
@@ -1661,7 +1661,6 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) {
   FunctionDefLibrary library_expected;
   GraphDef graphdef_expected;
 
-  string shape_string_expected;
   {
     GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
     Node* key_constant =
@@ -1673,9 +1672,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) {
     Node* e = BinaryUnknownShape(known, recv, shape.opts().WithName("E"));
     SendFromHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1",
                  {e}, shape.opts().WithName("outside_compilation_F1_O1_send"));
-    GraphDef shape_graph;
-    TF_EXPECT_OK(shape.ToGraphDef(&shape_graph));
-    EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected));
+    TF_EXPECT_OK(
+        AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected));
   }
 
   *library_expected.add_function() = test::function::XTimesTwo();
@@ -1694,7 +1692,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) {
            {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
             {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
             {"key", "host_compute_channel_F1_O1"},
-            {"shape_inference_graph", shape_string_expected},
+            {"shape_inference_graph",
+             "_outside_compilation_shape_inference_F1_O1"},
             {"shapes", gtl::ArraySlice<DataType>({})}},
            {"c"}},
       },