Adding utility class for manipulating a GraphDef.
authorJiri Simsa <jsimsa@google.com>
Wed, 23 May 2018 20:17:39 +0000 (13:17 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 23 May 2018 20:19:46 +0000 (13:19 -0700)
PiperOrigin-RevId: 197777416

tensorflow/core/grappler/optimizers/data/BUILD [new file with mode: 0644]
tensorflow/core/grappler/optimizers/data/graph_utils.cc [new file with mode: 0644]
tensorflow/core/grappler/optimizers/data/graph_utils.h [new file with mode: 0644]
tensorflow/core/grappler/optimizers/data/graph_utils_test.cc [new file with mode: 0644]

diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD
new file mode 100644 (file)
index 0000000..29ebb9a
--- /dev/null
@@ -0,0 +1,34 @@
+licenses(["notice"])  # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load("//tensorflow/core:platform/default/build_config.bzl", "tf_protos_all")
+
+cc_library(
+    name = "graph_utils",
+    srcs = ["graph_utils.cc"],
+    hdrs = [
+        "graph_utils.h",
+    ],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core/grappler:graph_view",
+        "//tensorflow/core/grappler:grappler_item",
+        "//tensorflow/core/grappler:grappler_item_builder",
+        "//tensorflow/core/grappler:utils",
+        "//tensorflow/core/grappler/clusters:virtual_cluster",
+        "//tensorflow/core/grappler/optimizers:meta_optimizer",
+    ] + tf_protos_all(),
+)
+
+tf_cc_test(
+    name = "graph_utils_test",
+    srcs = ["graph_utils_test.cc"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":graph_utils",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+    ],
+)
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
new file mode 100644 (file)
index 0000000..df12de3
--- /dev/null
@@ -0,0 +1,217 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
+#include "tensorflow/core/grappler/graph_view.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/grappler_item_builder.h"
+#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
+#include "tensorflow/core/util/ptr_util.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_utils {
+namespace {
+
+int FindNodeWithPredicate(const std::function<bool(const NodeDef&)>& predicate,
+                          const GraphDef& graph) {
+  for (int i = 0; i < graph.node_size(); ++i) {
+    if (predicate(graph.node(i))) {
+      return i;
+    }
+  }
+  return -1;
+}
+
+std::vector<int> CreateNameIndex(const GraphDef& graph) {
+  std::map<string, int> names;
+  for (int i = 0; i < graph.node_size(); ++i) {
+    names[graph.node(i).name()] = i;
+  }
+  std::vector<int> index(graph.node_size());
+  int i = 0;
+  for (const auto& pair : names) {
+    index[i++] = pair.second;
+  }
+  return index;
+}
+
+std::vector<int> CreateInputIndex(const NodeDef& node) {
+  std::map<string, int> inputs;
+  for (int i = 0; i < node.input_size(); ++i) {
+    inputs[node.input(i)] = i;
+  }
+  std::vector<int> index(node.input_size());
+  int i = 0;
+  for (const auto& pair : inputs) {
+    index[i++] = pair.second;
+  }
+  return index;
+}
+
+Status AddScalarConstNodeHelper(
+    DataType dtype, const std::function<void(TensorProto*)>& add_value,
+    GraphDef* graph, NodeDef** result) {
+  NodeDef* node = graph->add_node();
+  const string& name = strings::StrCat("Const/_", graph->node_size());
+  node->set_name(name);
+  node->set_op("Const");
+  (*node->mutable_attr())["dtype"].set_type(dtype);
+  std::unique_ptr<tensorflow::TensorProto> tensor =
+      tensorflow::MakeUnique<tensorflow::TensorProto>();
+  std::unique_ptr<tensorflow::TensorShapeProto> tensor_shape =
+      tensorflow::MakeUnique<tensorflow::TensorShapeProto>();
+  tensor->set_allocated_tensor_shape(tensor_shape.release());
+  tensor->set_dtype(dtype);
+  add_value(tensor.get());
+  (*node->mutable_attr())["value"].set_allocated_tensor(tensor.release());
+  *result = node;
+  return Status::OK();
+}
+
+}  // namespace
+
+Status AddNode(const string& name, const string& op,
+               const std::vector<string>& inputs,
+               const std::vector<std::pair<string, AttrValue>>& attributes,
+               GraphDef* graph, NodeDef** result) {
+  NodeDef* node = graph->add_node();
+  if (!name.empty()) {
+    node->set_name(name);
+  } else {
+    node->set_name(strings::StrCat(op, "/_", graph->node_size()));
+  }
+  node->set_op(op);
+  for (const string& input : inputs) {
+    node->add_input(input);
+  }
+  for (auto attr : attributes) {
+    (*node->mutable_attr())[attr.first] = attr.second;
+  }
+  *result = node;
+  return Status::OK();
+}
+
+template <>
+Status AddScalarConstNode(bool v, GraphDef* graph, NodeDef** result) {
+  return AddScalarConstNodeHelper(
+      DT_BOOL, [v](TensorProto* proto) { proto->add_bool_val(v); }, graph,
+      result);
+}
+
+template <>
+Status AddScalarConstNode(double v, GraphDef* graph, NodeDef** result) {
+  return AddScalarConstNodeHelper(
+      DT_DOUBLE, [v](TensorProto* proto) { proto->add_double_val(v); }, graph,
+      result);
+}
+
+template <>
+Status AddScalarConstNode(float v, GraphDef* graph, NodeDef** result) {
+  return AddScalarConstNodeHelper(
+      DT_FLOAT, [v](TensorProto* proto) { proto->add_float_val(v); }, graph,
+      result);
+}
+
+template <>
+Status AddScalarConstNode(int v, GraphDef* graph, NodeDef** result) {
+  return AddScalarConstNodeHelper(
+      DT_INT32, [v](TensorProto* proto) { proto->add_int_val(v); }, graph,
+      result);
+}
+
+template <>
+Status AddScalarConstNode(int64 v, GraphDef* graph, NodeDef** result) {
+  return AddScalarConstNodeHelper(
+      DT_INT64, [v](TensorProto* proto) { proto->add_int64_val(v); }, graph,
+      result);
+}
+
+template <>
+Status AddScalarConstNode(StringPiece v, GraphDef* graph, NodeDef** result) {
+  return AddScalarConstNodeHelper(
+      DT_STRING,
+      [v](TensorProto* proto) { proto->add_string_val(v.data(), v.size()); },
+      graph, result);
+}
+
+bool Compare(const GraphDef& g1, const GraphDef& g2) {
+  if (g1.node_size() != g2.node_size()) {
+    return false;
+  }
+  std::vector<int> name_index1 = CreateNameIndex(g1);
+  std::vector<int> name_index2 = CreateNameIndex(g2);
+  for (int i = 0; i < g1.node_size(); ++i) {
+    int idx1 = name_index1[i];
+    int idx2 = name_index2[i];
+    if (g1.node(idx1).op() != g2.node(idx2).op()) {
+      return false;
+    }
+    if (g1.node(idx1).name() != g2.node(idx2).name()) {
+      return false;
+    }
+    if (g1.node(idx1).input_size() != g2.node(idx2).input_size()) {
+      return false;
+    }
+    std::vector<int> input_index1 = CreateInputIndex(g1.node(idx1));
+    std::vector<int> input_index2 = CreateInputIndex(g2.node(idx2));
+    for (int j = 0; j < g1.node(idx1).input_size(); ++j) {
+      if (!IsSameInput(g1.node(idx1).input(input_index1[j]),
+                       g2.node(idx2).input(input_index2[j]))) {
+        return false;
+      }
+    }
+  }
+  return true;
+}
+
+bool ContainsNodeWithName(const string& name, const GraphDef& graph) {
+  return FindNodeWithName(name, graph) != -1;
+}
+
+bool ContainsNodeWithOp(const string& op, const GraphDef& graph) {
+  return FindNodeWithOp(op, graph) != -1;
+}
+
+Status DeleteNodes(const std::set<string>& nodes_to_delete, GraphDef* graph) {
+  int last = graph->node_size() - 1;
+  for (int i = graph->node_size() - 1; i >= 0; --i) {
+    const NodeDef& node = graph->node(i);
+    if (nodes_to_delete.find(node.name()) != nodes_to_delete.end()) {
+      graph->mutable_node()->SwapElements(i, last);
+      last--;
+    }
+  }
+  graph->mutable_node()->DeleteSubrange(last + 1,
+                                        graph->node_size() - last - 1);
+  return Status::OK();
+}
+
+int FindNodeWithName(const string& name, const GraphDef& graph) {
+  return FindNodeWithPredicate(
+      [name](const NodeDef& node) { return node.name() == name; }, graph);
+}
+
+int FindNodeWithOp(const string& op, const GraphDef& graph) {
+  return FindNodeWithPredicate(
+      [op](const NodeDef& node) { return node.op() == op; }, graph);
+}
+
+}  // end namespace graph_utils
+}  // end namespace grappler
+}  // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.h b/tensorflow/core/grappler/optimizers/data/graph_utils.h
new file mode 100644 (file)
index 0000000..b40ca44
--- /dev/null
@@ -0,0 +1,81 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_UTILS_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_UTILS_H_
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_utils {
+
+// Adds a node to the graph.
+Status AddNode(const string& name, const string& op,
+               const std::vector<string>& inputs,
+               const std::vector<std::pair<string, AttrValue>>& attributes,
+               GraphDef* graph, NodeDef** result);
+
+// Adds a Const node with the given value to the graph.
+template <typename T>
+Status AddScalarConstNode(T v, GraphDef* graph, NodeDef** result) {
+  return errors::Unimplemented("Type %s is not supported.",
+                               DataTypeToEnum<T>::value);
+}
+template <>
+Status AddScalarConstNode(bool v, GraphDef* graph, NodeDef** result);
+template <>
+Status AddScalarConstNode(double v, GraphDef* graph, NodeDef** result);
+template <>
+Status AddScalarConstNode(float v, GraphDef* graph, NodeDef** result);
+template <>
+Status AddScalarConstNode(int v, GraphDef* graph, NodeDef** result);
+template <>
+Status AddScalarConstNode(int64 v, GraphDef* graph, NodeDef** result);
+template <>
+Status AddScalarConstNode(StringPiece v, GraphDef* graph, NodeDef** result);
+
+// Checks whether the two graphs are the same.
+bool Compare(const GraphDef& g1, const GraphDef& g2);
+
+// Checks whether the graph contains a node with the given name.
+bool ContainsNodeWithName(const string& name, const GraphDef& graph);
+
+// Checks whether the graph contains a node with the given op.
+bool ContainsNodeWithOp(const string& op, const GraphDef& graph);
+
+// Deletes nodes from the graph.
+Status DeleteNodes(const std::set<string>& nodes_to_delete, GraphDef* graph);
+
+// Returns the index of the node with the given name or -1 if the node does
+// not exist.
+int FindNodeWithName(const string& name, const GraphDef& graph);
+
+// Returns the index of a node with the given op or -1 if no such  node
+// exists.
+int FindNodeWithOp(const string& op, const GraphDef& graph);
+
+}  // end namespace graph_utils
+}  // end namespace grappler
+}  // end namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_UTILS_H_
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
new file mode 100644 (file)
index 0000000..b347260
--- /dev/null
@@ -0,0 +1,142 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_utils {
+namespace {
+
+class GraphUtilsTest : public ::testing::Test {};
+
+TEST_F(GraphUtilsTest, AddScalarConstNodeBool) {
+  GraphDef graph;
+  NodeDef* bool_node;
+  TF_EXPECT_OK(AddScalarConstNode<bool>(true, &graph, &bool_node));
+  EXPECT_TRUE(ContainsNodeWithName(bool_node->name(), graph));
+  EXPECT_EQ(bool_node->attr().at("value").tensor().bool_val(0), true);
+}
+
+TEST_F(GraphUtilsTest, AddScalarConstNodeDouble) {
+  GraphDef graph;
+  NodeDef* double_node;
+  TF_EXPECT_OK(AddScalarConstNode<double>(3.14, &graph, &double_node));
+  EXPECT_TRUE(ContainsNodeWithName(double_node->name(), graph));
+  EXPECT_FLOAT_EQ(double_node->attr().at("value").tensor().double_val(0), 3.14);
+}
+
+TEST_F(GraphUtilsTest, AddScalarConstNodeFloat) {
+  GraphDef graph;
+  NodeDef* float_node;
+  TF_EXPECT_OK(AddScalarConstNode<float>(3.14, &graph, &float_node));
+  EXPECT_TRUE(ContainsNodeWithName(float_node->name(), graph));
+  EXPECT_FLOAT_EQ(float_node->attr().at("value").tensor().float_val(0), 3.14);
+}
+
+TEST_F(GraphUtilsTest, AddScalarConstNodeInt) {
+  GraphDef graph;
+  NodeDef* int_node;
+  TF_EXPECT_OK(AddScalarConstNode<int>(42, &graph, &int_node));
+  EXPECT_TRUE(ContainsNodeWithName(int_node->name(), graph));
+  EXPECT_EQ(int_node->attr().at("value").tensor().int_val(0), 42);
+}
+
+TEST_F(GraphUtilsTest, AddScalarConstNodeInt64) {
+  GraphDef graph;
+  NodeDef* int64_node;
+  TF_EXPECT_OK(AddScalarConstNode<int64>(42, &graph, &int64_node));
+  EXPECT_TRUE(ContainsNodeWithName(int64_node->name(), graph));
+  EXPECT_EQ(int64_node->attr().at("value").tensor().int64_val(0), 42);
+}
+
+TEST_F(GraphUtilsTest, AddScalarConstNodeString) {
+  GraphDef graph;
+  NodeDef* string_node;
+  TF_EXPECT_OK(AddScalarConstNode<StringPiece>("hello", &graph, &string_node));
+  EXPECT_TRUE(ContainsNodeWithName(string_node->name(), graph));
+  EXPECT_EQ(string_node->attr().at("value").tensor().string_val(0), "hello");
+}
+
+TEST_F(GraphUtilsTest, Compare) {
+  GraphDef graphA;
+  GraphDef graphB;
+  EXPECT_TRUE(Compare(graphA, graphB));
+
+  NodeDef* nodeA;
+  TF_EXPECT_OK(AddNode("A", "OpA", {}, {}, &graphA, &nodeA));
+  NodeDef* nodeB;
+  TF_EXPECT_OK(AddNode("B", "OpB", {"A"}, {}, &graphA, &nodeB));
+  EXPECT_FALSE(Compare(graphA, graphB));
+
+  graphB.mutable_node()->CopyFrom(graphA.node());
+  EXPECT_TRUE(Compare(graphA, graphB));
+}
+
+TEST_F(GraphUtilsTest, ContainsNodeWithName) {
+  GraphDef graph;
+  EXPECT_TRUE(!ContainsNodeWithName("A", graph));
+
+  NodeDef* node;
+  TF_EXPECT_OK(AddNode("A", "OpA", {}, {}, &graph, &node));
+  EXPECT_TRUE(ContainsNodeWithName("A", graph));
+
+  TF_EXPECT_OK(DeleteNodes({"A"}, &graph));
+  EXPECT_TRUE(!ContainsNodeWithName("A", graph));
+}
+
+TEST_F(GraphUtilsTest, ContainsNodeWithOp) {
+  GraphDef graph;
+  EXPECT_TRUE(!ContainsNodeWithOp("OpA", graph));
+
+  NodeDef* node;
+  TF_EXPECT_OK(AddNode("A", "OpA", {}, {}, &graph, &node));
+  EXPECT_TRUE(ContainsNodeWithOp("OpA", graph));
+
+  TF_EXPECT_OK(DeleteNodes({"A"}, &graph));
+  EXPECT_TRUE(!ContainsNodeWithOp("OpA", graph));
+}
+
+TEST_F(GraphUtilsTest, FindNodeWithName) {
+  GraphDef graph;
+  EXPECT_EQ(FindNodeWithName("A", graph), -1);
+
+  NodeDef* node;
+  TF_EXPECT_OK(AddNode("A", "OpA", {}, {}, &graph, &node));
+  EXPECT_NE(FindNodeWithName("A", graph), -1);
+
+  TF_EXPECT_OK(DeleteNodes({"A"}, &graph));
+  EXPECT_EQ(FindNodeWithName("A", graph), -1);
+}
+
+TEST_F(GraphUtilsTest, FindNodeWithOp) {
+  GraphDef graph;
+  EXPECT_EQ(FindNodeWithOp("OpA", graph), -1);
+
+  NodeDef* node;
+  TF_EXPECT_OK(AddNode("A", "OpA", {}, {}, &graph, &node));
+  EXPECT_NE(FindNodeWithOp("OpA", graph), -1);
+
+  TF_EXPECT_OK(DeleteNodes({"A"}, &graph));
+  EXPECT_EQ(FindNodeWithOp("OpA", graph), -1);
+}
+
+}  // namespace
+}  // namespace graph_utils
+}  // namespace grappler
+}  // namespace tensorflow