#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/protobuf_internal.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
return Status::OK();
}
+// Replace any unknown dimensions in a shape with
+// cfg.placeholder_unknown_output_shape_dim if it is no less than 0.
+Status ReplaceUnknownShapeDim(const ItemConfig& cfg,
+ const TensorShapeProto& shape_pb_in,
+ TensorShapeProto* shape_pb_out,
+ TensorShape* shape_out) {
+ std::vector<int32> dims;
+ for (const auto& dim_proto : shape_pb_in.dim()) {
+ if (cfg.placeholder_unknown_output_shape_dim >= 0 &&
+ dim_proto.size() == -1) {
+ dims.push_back(cfg.placeholder_unknown_output_shape_dim);
+ shape_pb_out->add_dim()->set_size(
+ cfg.placeholder_unknown_output_shape_dim);
+ } else {
+ dims.push_back(std::max<int32>(1, dim_proto.size()));
+ shape_pb_out->add_dim()->set_size(dim_proto.size());
+ }
+ }
+ return TensorShapeUtils::MakeShape(dims.data(), dims.size(), shape_out);
+}
+
} // namespace
// static
}
}
- // Detect feed and fetch nodes from signature defs.
+ // Detect feed and fetch nodes from signature defs. Signatures may share same
+ // inputs or outputs.
+ std::unordered_set<string> signature_feed_nodes;
+ std::unordered_set<string> signature_fetch_nodes;
for (const auto& name_and_signature : meta_graph.signature_def()) {
for (const auto& name_and_input : name_and_signature.second.inputs()) {
const TensorInfo& input = name_and_input.second;
if (input.has_coo_sparse()) {
// Define the shapes following the comment of CooSparse.
- PartialTensorShape partial_shape_1d({-1});
- PartialTensorShape partial_shape_2d({-1, -1});
- TensorShape shape_1d;
- TensorShape shape_2d;
- if (!partial_shape_1d.AsTensorShape(&shape_1d) ||
- !partial_shape_2d.AsTensorShape(&shape_2d)) {
- LOG(ERROR) << "Internal error when constructing tensor shapes.";
- return nullptr;
+ // TODO(yuefengz): we probably want to use different dim values for the
+ // three tensors of a SparseTensor.
+ int64 dim = std::max(1, cfg.placeholder_unknown_output_shape_dim);
+ TensorShape shape_1d({dim});
+ TensorShape shape_2d({dim, dim});
+
+ if (gtl::InsertIfNotPresent(
+ &signature_feed_nodes,
+ NodeName(input.coo_sparse().values_tensor_name()))) {
+ Tensor value_tensor(input.dtype(), shape_1d);
+ InitializeTensor(input.dtype(), &value_tensor);
+ new_item->feed.emplace_back(
+ NodeName(input.coo_sparse().values_tensor_name()), value_tensor);
+ }
+ if (gtl::InsertIfNotPresent(
+ &signature_feed_nodes,
+ NodeName(input.coo_sparse().indices_tensor_name()))) {
+ Tensor indices_tensor(DT_INT64, shape_2d);
+ InitializeTensor(input.dtype(), &indices_tensor);
+ new_item->feed.emplace_back(
+ NodeName(input.coo_sparse().indices_tensor_name()),
+ indices_tensor);
+ }
+ if (gtl::InsertIfNotPresent(
+ &signature_feed_nodes,
+ NodeName(input.coo_sparse().dense_shape_tensor_name()))) {
+ Tensor dense_shape_tensor(DT_INT64, shape_1d);
+ InitializeTensor(input.dtype(), &dense_shape_tensor);
+ new_item->feed.emplace_back(
+ NodeName(input.coo_sparse().dense_shape_tensor_name()),
+ dense_shape_tensor);
}
-
- new_item->feed.emplace_back(
- NodeName(input.coo_sparse().values_tensor_name()),
- Tensor(input.dtype(), shape_1d));
- new_item->feed.emplace_back(
- NodeName(input.coo_sparse().indices_tensor_name()),
- Tensor(DT_INT64, shape_2d));
- new_item->feed.emplace_back(
- NodeName(input.coo_sparse().dense_shape_tensor_name()),
- Tensor(DT_INT64, shape_1d));
} else {
- new_item->feed.emplace_back(
- NodeName(input.name()),
- Tensor(input.dtype(), input.tensor_shape()));
+ if (gtl::InsertIfNotPresent(&signature_feed_nodes,
+ NodeName(input.name()))) {
+ TensorShape shape;
+ TensorShapeProto shape_proto;
+ Status s = ReplaceUnknownShapeDim(cfg, input.tensor_shape(),
+ &shape_proto, &shape);
+ if (!s.ok()) {
+ LOG(ERROR) << "Invalid shape for signature input " << input.name()
+ << ": " << s << ", skipping this input";
+ return nullptr;
+ }
+
+ Tensor fake_input(input.dtype(), shape);
+ InitializeTensor(input.dtype(), &fake_input);
+ new_item->feed.emplace_back(NodeName(input.name()), fake_input);
+ }
}
}
for (const auto& name_and_output : name_and_signature.second.outputs()) {
const TensorInfo& output = name_and_output.second;
if (output.has_coo_sparse()) {
- new_item->fetch.push_back(
- NodeName(output.coo_sparse().values_tensor_name()));
- new_item->fetch.push_back(
- NodeName(output.coo_sparse().indices_tensor_name()));
- new_item->fetch.push_back(
- NodeName(output.coo_sparse().dense_shape_tensor_name()));
+ if (gtl::InsertIfNotPresent(
+ &signature_fetch_nodes,
+ NodeName(output.coo_sparse().values_tensor_name()))) {
+ new_item->fetch.push_back(
+ NodeName(output.coo_sparse().values_tensor_name()));
+ }
+ if (gtl::InsertIfNotPresent(
+ &signature_fetch_nodes,
+ NodeName(output.coo_sparse().indices_tensor_name()))) {
+ new_item->fetch.push_back(
+ NodeName(output.coo_sparse().indices_tensor_name()));
+ }
+ if (gtl::InsertIfNotPresent(
+ &signature_fetch_nodes,
+ NodeName(output.coo_sparse().dense_shape_tensor_name()))) {
+ new_item->fetch.push_back(
+ NodeName(output.coo_sparse().dense_shape_tensor_name()));
+ }
} else {
- new_item->fetch.push_back(NodeName(output.name()));
+ if (gtl::InsertIfNotPresent(&signature_fetch_nodes,
+ NodeName(output.name()))) {
+ new_item->fetch.push_back(NodeName(output.name()));
+ }
}
}
}
// shape is not empty if the shape is partially defined.
TensorShape shape;
TensorShapeProto shape_proto;
- std::vector<int32> dims;
- for (const auto& dim_proto : node.attr().at("shape").shape().dim()) {
- if (cfg.placeholder_unknown_output_shape_dim >= 0 &&
- dim_proto.size() == -1) {
- dims.push_back(cfg.placeholder_unknown_output_shape_dim);
- shape_proto.add_dim()->set_size(
- cfg.placeholder_unknown_output_shape_dim);
- } else {
- dims.push_back(std::max<int32>(1, dim_proto.size()));
- shape_proto.add_dim()->set_size(dim_proto.size());
- }
- }
- Status make_shape_status =
- TensorShapeUtils::MakeShape(dims.data(), dims.size(), &shape);
+ Status make_shape_status = ReplaceUnknownShapeDim(
+ cfg, node.attr().at("shape").shape(), &shape_proto, &shape);
if (!make_shape_status.ok()) {
LOG(ERROR) << "Invalid shape for placeholder " << node.name() << ": "
<< make_shape_status << ", skipping this input";
if (cfg.feed_nodes.empty()) {
// No specific feed nodes were given. Assume all placeholders are fed.
- new_item->feed.emplace_back(node.name(), fake_input);
+ if (signature_feed_nodes.count(node.name()) == 0) {
+ new_item->feed.emplace_back(node.name(), fake_input);
+ }
} else if (cfg.feed_nodes.count(node.name()) > 0) {
// If specific feed nodes were given, only update their tensors.
auto it = find_if(new_item->feed.begin(), new_item->feed.end(),
(*serving_signature.mutable_outputs())["output"] = output;
(*meta_graph.mutable_signature_def())["serving"] = serving_signature;
+ // It should be able to dedup the input and output with same names.
+ TensorInfo input2, output2;
+ input.set_name("x");
+ input.set_dtype(DT_FLOAT);
+ output.set_name("z");
+ SignatureDef serving_signature2;
+ (*serving_signature.mutable_inputs())["input2"] = input2;
+ (*serving_signature.mutable_outputs())["output2"] = output2;
+ (*meta_graph.mutable_signature_def())["serving2"] = serving_signature2;
+
std::unique_ptr<GrapplerItem> item =
GrapplerItemFromMetaGraphDef("0", meta_graph, ItemConfig());
ASSERT_TRUE(item != nullptr);
+ EXPECT_EQ(item->feed.size(), 1);
+ EXPECT_EQ(item->fetch.size(), 1);
EXPECT_EQ(item->feed[0].first, "x");
EXPECT_EQ(item->fetch[0], "z");
}
ASSERT_TRUE(item == nullptr);
}
+TEST_F(GrapplerItemBuilderTest, FromGraphWithUnknownDimInSignatureInput) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto shape_1d = PartialTensorShape({-1});
+ auto x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
+ ops::Placeholder::Shape(shape_1d));
+ auto y = ops::Const(s.WithOpName("y"), static_cast<float>(1.0));
+ auto z = ops::Add(s.WithOpName("z"), x, y);
+
+ MetaGraphDef meta_graph;
+ TF_CHECK_OK(s.ToGraphDef(meta_graph.mutable_graph_def()));
+
+ TensorInfo input, output;
+ input.set_name("x");
+ input.set_dtype(DT_FLOAT);
+ shape_1d.AsProto(input.mutable_tensor_shape());
+ output.set_name("z");
+
+ SignatureDef serving_signature;
+ (*serving_signature.mutable_inputs())["input"] = input;
+ (*serving_signature.mutable_outputs())["output"] = output;
+ (*meta_graph.mutable_signature_def())["serving"] = serving_signature;
+
+ ItemConfig cfg;
+ cfg.placeholder_unknown_output_shape_dim = 64;
+ std::unique_ptr<GrapplerItem> item1 =
+ GrapplerItemFromMetaGraphDef("0", meta_graph, cfg);
+ ASSERT_TRUE(item1 != nullptr);
+
+ ASSERT_EQ(item1->feed.size(), 1);
+ EXPECT_EQ(item1->feed[0].second.NumElements(), 64);
+
+ std::unique_ptr<GrapplerItem> item2 =
+ GrapplerItemFromMetaGraphDef("0", meta_graph, ItemConfig());
+ ASSERT_TRUE(item2 != nullptr);
+
+ ASSERT_EQ(item2->feed.size(), 1);
+ EXPECT_EQ(item2->feed[0].second.NumElements(), 1);
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow