Replace unknown dimensions of signature inputs when building grappler items.
author     Yuefeng Zhou <yuefengz@google.com>
Tue, 13 Mar 2018 18:27:46 +0000 (11:27 -0700)
committer  TensorFlower Gardener <gardener@tensorflow.org>
Tue, 13 Mar 2018 18:32:37 +0000 (11:32 -0700)
Fix the bug where the same feed or fetch nodes would be added more than once.

PiperOrigin-RevId: 188902101
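
For context, a minimal standalone sketch of the two ideas in this change (not the TensorFlow implementation; ReplaceUnknownDims and kPlaceholderUnknownDim are hypothetical stand-ins for the new ReplaceUnknownShapeDim helper and cfg.placeholder_unknown_output_shape_dim): unknown (-1) dimensions of a signature input shape are replaced by a configured value, and a set of already-seen node names keeps duplicate feeds and fetches from being added twice.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for cfg.placeholder_unknown_output_shape_dim.
constexpr int64_t kPlaceholderUnknownDim = 64;

// Simplified version of the shape-fixing idea: unknown (-1) dimensions
// become kPlaceholderUnknownDim when it is non-negative; known dimensions
// are kept, clamped to at least 1.
std::vector<int64_t> ReplaceUnknownDims(const std::vector<int64_t>& dims_in) {
  std::vector<int64_t> dims_out;
  for (int64_t d : dims_in) {
    if (kPlaceholderUnknownDim >= 0 && d == -1) {
      dims_out.push_back(kPlaceholderUnknownDim);
    } else {
      dims_out.push_back(std::max<int64_t>(1, d));
    }
  }
  return dims_out;
}

int main() {
  // A signature input of shape [-1, 128] becomes [64, 128].
  for (int64_t d : ReplaceUnknownDims({-1, 128})) std::cout << d << " ";
  std::cout << "\n";

  // Feed names shared by several signatures are only added once.
  const std::vector<std::string> names = {"x", "x", "z"};
  std::unordered_set<std::string> seen;
  std::vector<std::string> feeds;
  for (const std::string& name : names) {
    if (seen.insert(name).second) feeds.push_back(name);
  }
  std::cout << feeds.size() << "\n";  // 2
}

In the actual change, gtl::InsertIfNotPresent on signature_feed_nodes and signature_fetch_nodes plays the role of the seen.insert(name).second check above.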

tensorflow/core/grappler/grappler_item_builder.cc
tensorflow/core/grappler/grappler_item_builder_test.cc

diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc
index 04c7dae..d7b3003 100644
@@ -38,6 +38,7 @@ limitations under the License.
 #include "tensorflow/core/grappler/op_types.h"
 #include "tensorflow/core/grappler/optimizers/model_pruner.h"
 #include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/platform/protobuf_internal.h"
 #include "tensorflow/core/protobuf/meta_graph.pb.h"
@@ -152,6 +153,27 @@ Status PruneGraph(GrapplerItem* item) {
   return Status::OK();
 }
 
+// Replace any unknown dimensions in a shape with
+// cfg.placeholder_unknown_output_shape_dim when that value is non-negative.
+Status ReplaceUnknownShapeDim(const ItemConfig& cfg,
+                              const TensorShapeProto& shape_pb_in,
+                              TensorShapeProto* shape_pb_out,
+                              TensorShape* shape_out) {
+  std::vector<int32> dims;
+  for (const auto& dim_proto : shape_pb_in.dim()) {
+    if (cfg.placeholder_unknown_output_shape_dim >= 0 &&
+        dim_proto.size() == -1) {
+      dims.push_back(cfg.placeholder_unknown_output_shape_dim);
+      shape_pb_out->add_dim()->set_size(
+          cfg.placeholder_unknown_output_shape_dim);
+    } else {
+      dims.push_back(std::max<int32>(1, dim_proto.size()));
+      shape_pb_out->add_dim()->set_size(dim_proto.size());
+    }
+  }
+  return TensorShapeUtils::MakeShape(dims.data(), dims.size(), shape_out);
+}
+
 }  // namespace
 
 // static
@@ -181,48 +203,92 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
     }
   }
 
-  // Detect feed and fetch nodes from signature defs.
+  // Detect feed and fetch nodes from signature defs. Signatures may share the
+  // same inputs or outputs.
+  std::unordered_set<string> signature_feed_nodes;
+  std::unordered_set<string> signature_fetch_nodes;
   for (const auto& name_and_signature : meta_graph.signature_def()) {
     for (const auto& name_and_input : name_and_signature.second.inputs()) {
       const TensorInfo& input = name_and_input.second;
       if (input.has_coo_sparse()) {
         // Define the shapes following the comment of CooSparse.
-        PartialTensorShape partial_shape_1d({-1});
-        PartialTensorShape partial_shape_2d({-1, -1});
-        TensorShape shape_1d;
-        TensorShape shape_2d;
-        if (!partial_shape_1d.AsTensorShape(&shape_1d) ||
-            !partial_shape_2d.AsTensorShape(&shape_2d)) {
-          LOG(ERROR) << "Internal error when constructing tensor shapes.";
-          return nullptr;
+        // TODO(yuefengz): we probably want to use different dim values for the
+        // three tensors of a SparseTensor.
+        int64 dim = std::max(1, cfg.placeholder_unknown_output_shape_dim);
+        TensorShape shape_1d({dim});
+        TensorShape shape_2d({dim, dim});
+
+        if (gtl::InsertIfNotPresent(
+                &signature_feed_nodes,
+                NodeName(input.coo_sparse().values_tensor_name()))) {
+          Tensor value_tensor(input.dtype(), shape_1d);
+          InitializeTensor(input.dtype(), &value_tensor);
+          new_item->feed.emplace_back(
+              NodeName(input.coo_sparse().values_tensor_name()), value_tensor);
+        }
+        if (gtl::InsertIfNotPresent(
+                &signature_feed_nodes,
+                NodeName(input.coo_sparse().indices_tensor_name()))) {
+          Tensor indices_tensor(DT_INT64, shape_2d);
+          InitializeTensor(DT_INT64, &indices_tensor);
+          new_item->feed.emplace_back(
+              NodeName(input.coo_sparse().indices_tensor_name()),
+              indices_tensor);
+        }
+        if (gtl::InsertIfNotPresent(
+                &signature_feed_nodes,
+                NodeName(input.coo_sparse().dense_shape_tensor_name()))) {
+          Tensor dense_shape_tensor(DT_INT64, shape_1d);
+          InitializeTensor(DT_INT64, &dense_shape_tensor);
+          new_item->feed.emplace_back(
+              NodeName(input.coo_sparse().dense_shape_tensor_name()),
+              dense_shape_tensor);
         }
-
-        new_item->feed.emplace_back(
-            NodeName(input.coo_sparse().values_tensor_name()),
-            Tensor(input.dtype(), shape_1d));
-        new_item->feed.emplace_back(
-            NodeName(input.coo_sparse().indices_tensor_name()),
-            Tensor(DT_INT64, shape_2d));
-        new_item->feed.emplace_back(
-            NodeName(input.coo_sparse().dense_shape_tensor_name()),
-            Tensor(DT_INT64, shape_1d));
       } else {
-        new_item->feed.emplace_back(
-            NodeName(input.name()),
-            Tensor(input.dtype(), input.tensor_shape()));
+        if (gtl::InsertIfNotPresent(&signature_feed_nodes,
+                                    NodeName(input.name()))) {
+          TensorShape shape;
+          TensorShapeProto shape_proto;
+          Status s = ReplaceUnknownShapeDim(cfg, input.tensor_shape(),
+                                            &shape_proto, &shape);
+          if (!s.ok()) {
+            LOG(ERROR) << "Invalid shape for signature input " << input.name()
+                       << ": " << s << ", skipping this input";
+            return nullptr;
+          }
+
+          Tensor fake_input(input.dtype(), shape);
+          InitializeTensor(input.dtype(), &fake_input);
+          new_item->feed.emplace_back(NodeName(input.name()), fake_input);
+        }
       }
     }
     for (const auto& name_and_output : name_and_signature.second.outputs()) {
       const TensorInfo& output = name_and_output.second;
       if (output.has_coo_sparse()) {
-        new_item->fetch.push_back(
-            NodeName(output.coo_sparse().values_tensor_name()));
-        new_item->fetch.push_back(
-            NodeName(output.coo_sparse().indices_tensor_name()));
-        new_item->fetch.push_back(
-            NodeName(output.coo_sparse().dense_shape_tensor_name()));
+        if (gtl::InsertIfNotPresent(
+                &signature_fetch_nodes,
+                NodeName(output.coo_sparse().values_tensor_name()))) {
+          new_item->fetch.push_back(
+              NodeName(output.coo_sparse().values_tensor_name()));
+        }
+        if (gtl::InsertIfNotPresent(
+                &signature_fetch_nodes,
+                NodeName(output.coo_sparse().indices_tensor_name()))) {
+          new_item->fetch.push_back(
+              NodeName(output.coo_sparse().indices_tensor_name()));
+        }
+        if (gtl::InsertIfNotPresent(
+                &signature_fetch_nodes,
+                NodeName(output.coo_sparse().dense_shape_tensor_name()))) {
+          new_item->fetch.push_back(
+              NodeName(output.coo_sparse().dense_shape_tensor_name()));
+        }
       } else {
-        new_item->fetch.push_back(NodeName(output.name()));
+        if (gtl::InsertIfNotPresent(&signature_fetch_nodes,
+                                    NodeName(output.name()))) {
+          new_item->fetch.push_back(NodeName(output.name()));
+        }
       }
     }
   }
@@ -377,20 +443,8 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
       // shape is not empty if the shape is partially defined.
       TensorShape shape;
       TensorShapeProto shape_proto;
-      std::vector<int32> dims;
-      for (const auto& dim_proto : node.attr().at("shape").shape().dim()) {
-        if (cfg.placeholder_unknown_output_shape_dim >= 0 &&
-            dim_proto.size() == -1) {
-          dims.push_back(cfg.placeholder_unknown_output_shape_dim);
-          shape_proto.add_dim()->set_size(
-              cfg.placeholder_unknown_output_shape_dim);
-        } else {
-          dims.push_back(std::max<int32>(1, dim_proto.size()));
-          shape_proto.add_dim()->set_size(dim_proto.size());
-        }
-      }
-      Status make_shape_status =
-          TensorShapeUtils::MakeShape(dims.data(), dims.size(), &shape);
+      Status make_shape_status = ReplaceUnknownShapeDim(
+          cfg, node.attr().at("shape").shape(), &shape_proto, &shape);
       if (!make_shape_status.ok()) {
         LOG(ERROR) << "Invalid shape for placeholder " << node.name() << ": "
                    << make_shape_status << ", skipping this input";
@@ -430,7 +484,9 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
 
       if (cfg.feed_nodes.empty()) {
         // No specific feed nodes were given. Assume all placeholders are fed.
-        new_item->feed.emplace_back(node.name(), fake_input);
+        if (signature_feed_nodes.count(node.name()) == 0) {
+          new_item->feed.emplace_back(node.name(), fake_input);
+        }
       } else if (cfg.feed_nodes.count(node.name()) > 0) {
         // If specific feed nodes were given, only update their tensors.
         auto it = find_if(new_item->feed.begin(), new_item->feed.end(),
diff --git a/tensorflow/core/grappler/grappler_item_builder_test.cc b/tensorflow/core/grappler/grappler_item_builder_test.cc
index ada9092..29488e4 100644
@@ -319,10 +319,22 @@ TEST_F(GrapplerItemBuilderTest, FromGraphWithSignatureDef) {
   (*serving_signature.mutable_outputs())["output"] = output;
   (*meta_graph.mutable_signature_def())["serving"] = serving_signature;
 
+  // The builder should dedup inputs and outputs that share the same names.
+  TensorInfo input2, output2;
+  input2.set_name("x");
+  input2.set_dtype(DT_FLOAT);
+  output2.set_name("z");
+  SignatureDef serving_signature2;
+  (*serving_signature2.mutable_inputs())["input2"] = input2;
+  (*serving_signature2.mutable_outputs())["output2"] = output2;
+  (*meta_graph.mutable_signature_def())["serving2"] = serving_signature2;
+
   std::unique_ptr<GrapplerItem> item =
       GrapplerItemFromMetaGraphDef("0", meta_graph, ItemConfig());
   ASSERT_TRUE(item != nullptr);
 
+  EXPECT_EQ(item->feed.size(), 1);
+  EXPECT_EQ(item->fetch.size(), 1);
   EXPECT_EQ(item->feed[0].first, "x");
   EXPECT_EQ(item->fetch[0], "z");
 }
@@ -354,6 +366,45 @@ TEST_F(GrapplerItemBuilderTest, FromGraphWithIncompleteSignatureDef) {
   ASSERT_TRUE(item == nullptr);
 }
 
+TEST_F(GrapplerItemBuilderTest, FromGraphWithUnknownDimInSignatureInput) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  auto shape_1d = PartialTensorShape({-1});
+  auto x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
+                            ops::Placeholder::Shape(shape_1d));
+  auto y = ops::Const(s.WithOpName("y"), static_cast<float>(1.0));
+  auto z = ops::Add(s.WithOpName("z"), x, y);
+
+  MetaGraphDef meta_graph;
+  TF_CHECK_OK(s.ToGraphDef(meta_graph.mutable_graph_def()));
+
+  TensorInfo input, output;
+  input.set_name("x");
+  input.set_dtype(DT_FLOAT);
+  shape_1d.AsProto(input.mutable_tensor_shape());
+  output.set_name("z");
+
+  SignatureDef serving_signature;
+  (*serving_signature.mutable_inputs())["input"] = input;
+  (*serving_signature.mutable_outputs())["output"] = output;
+  (*meta_graph.mutable_signature_def())["serving"] = serving_signature;
+
+  ItemConfig cfg;
+  cfg.placeholder_unknown_output_shape_dim = 64;
+  std::unique_ptr<GrapplerItem> item1 =
+      GrapplerItemFromMetaGraphDef("0", meta_graph, cfg);
+  ASSERT_TRUE(item1 != nullptr);
+
+  ASSERT_EQ(item1->feed.size(), 1);
+  EXPECT_EQ(item1->feed[0].second.NumElements(), 64);
+
+  std::unique_ptr<GrapplerItem> item2 =
+      GrapplerItemFromMetaGraphDef("0", meta_graph, ItemConfig());
+  ASSERT_TRUE(item2 != nullptr);
+
+  ASSERT_EQ(item2->feed.size(), 1);
+  EXPECT_EQ(item2->feed[0].second.NumElements(), 1);
+}
+
 }  // namespace
 }  // namespace grappler
 }  // namespace tensorflow