From 85404e8f113c79dbeec5685166a4e797abffd505 Mon Sep 17 00:00:00 2001 From: Jiri Simsa Date: Wed, 23 May 2018 13:17:39 -0700 Subject: [PATCH] Adding utility class for manipulating a GraphDef. PiperOrigin-RevId: 197777416 --- tensorflow/core/grappler/optimizers/data/BUILD | 34 ++++ .../core/grappler/optimizers/data/graph_utils.cc | 217 +++++++++++++++++++++ .../core/grappler/optimizers/data/graph_utils.h | 81 ++++++++ .../grappler/optimizers/data/graph_utils_test.cc | 142 ++++++++++++++ 4 files changed, 474 insertions(+) create mode 100644 tensorflow/core/grappler/optimizers/data/BUILD create mode 100644 tensorflow/core/grappler/optimizers/data/graph_utils.cc create mode 100644 tensorflow/core/grappler/optimizers/data/graph_utils.h create mode 100644 tensorflow/core/grappler/optimizers/data/graph_utils_test.cc diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD new file mode 100644 index 0000000..29ebb9a --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/BUILD @@ -0,0 +1,34 @@ +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow/core:platform/default/build_config.bzl", "tf_protos_all") + +cc_library( + name = "graph_utils", + srcs = ["graph_utils.cc"], + hdrs = [ + "graph_utils.h", + ], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core/grappler:graph_view", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:grappler_item_builder", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/clusters:virtual_cluster", + "//tensorflow/core/grappler/optimizers:meta_optimizer", + ] + tf_protos_all(), +) + +tf_cc_test( + name = "graph_utils_test", + srcs = ["graph_utils_test.cc"], + visibility = ["//visibility:public"], + deps = [ + ":graph_utils", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc new file mode 100644 index 0000000..df12de3 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc @@ -0,0 +1,217 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" + +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/grappler/clusters/virtual_cluster.h" +#include "tensorflow/core/grappler/graph_view.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/grappler_item_builder.h" +#include "tensorflow/core/grappler/optimizers/meta_optimizer.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace tensorflow { +namespace grappler { +namespace graph_utils { +namespace { + +int FindNodeWithPredicate(const std::function& predicate, + const GraphDef& graph) { + for (int i = 0; i < graph.node_size(); ++i) { + if (predicate(graph.node(i))) { + return i; + } + } + return -1; +} + +std::vector CreateNameIndex(const GraphDef& graph) { + std::map names; + for (int i = 0; i < graph.node_size(); ++i) { + names[graph.node(i).name()] = i; + } + std::vector index(graph.node_size()); + int i = 0; + for (const auto& pair : names) { + index[i++] = pair.second; + } + return index; +} + +std::vector CreateInputIndex(const NodeDef& node) { + std::map inputs; + for (int i = 0; i < node.input_size(); ++i) { + inputs[node.input(i)] = i; + } + std::vector index(node.input_size()); + int i = 0; + for (const auto& pair : inputs) { + index[i++] = pair.second; + } + return index; +} + +Status AddScalarConstNodeHelper( + DataType dtype, const std::function& add_value, + GraphDef* graph, NodeDef** result) { + NodeDef* node = graph->add_node(); + const string& name = strings::StrCat("Const/_", graph->node_size()); + node->set_name(name); + node->set_op("Const"); + (*node->mutable_attr())["dtype"].set_type(dtype); + std::unique_ptr tensor = + tensorflow::MakeUnique(); + std::unique_ptr tensor_shape = + tensorflow::MakeUnique(); + tensor->set_allocated_tensor_shape(tensor_shape.release()); + tensor->set_dtype(dtype); + add_value(tensor.get()); + (*node->mutable_attr())["value"].set_allocated_tensor(tensor.release()); + *result = node; + return Status::OK(); +} + +} // namespace + +Status AddNode(const string& name, const string& op, + const std::vector& inputs, + const std::vector>& attributes, + GraphDef* graph, NodeDef** result) { + NodeDef* node = graph->add_node(); + if (!name.empty()) { + node->set_name(name); + } else { + node->set_name(strings::StrCat(op, "/_", graph->node_size())); + } + node->set_op(op); + for (const string& input : inputs) { + node->add_input(input); + } + for (auto attr : attributes) { + (*node->mutable_attr())[attr.first] = attr.second; + } + *result = node; + return Status::OK(); +} + +template <> +Status AddScalarConstNode(bool v, GraphDef* graph, NodeDef** result) { + return AddScalarConstNodeHelper( + DT_BOOL, [v](TensorProto* proto) { proto->add_bool_val(v); }, graph, + result); +} + +template <> +Status AddScalarConstNode(double v, GraphDef* graph, NodeDef** result) { + return AddScalarConstNodeHelper( + DT_DOUBLE, [v](TensorProto* proto) { proto->add_double_val(v); }, graph, + result); +} + +template <> +Status AddScalarConstNode(float v, GraphDef* graph, NodeDef** result) { + return AddScalarConstNodeHelper( + DT_FLOAT, [v](TensorProto* proto) { proto->add_float_val(v); }, graph, + result); +} + +template <> +Status AddScalarConstNode(int v, GraphDef* graph, NodeDef** result) { + return AddScalarConstNodeHelper( + DT_INT32, [v](TensorProto* proto) { proto->add_int_val(v); }, graph, + result); +} + +template <> +Status AddScalarConstNode(int64 v, GraphDef* graph, NodeDef** result) { + return AddScalarConstNodeHelper( + DT_INT64, [v](TensorProto* proto) { proto->add_int64_val(v); }, graph, + result); +} + +template <> +Status AddScalarConstNode(StringPiece v, GraphDef* graph, NodeDef** result) { + return AddScalarConstNodeHelper( + DT_STRING, + [v](TensorProto* proto) { proto->add_string_val(v.data(), v.size()); }, + graph, result); +} + +bool Compare(const GraphDef& g1, const GraphDef& g2) { + if (g1.node_size() != g2.node_size()) { + return false; + } + std::vector name_index1 = CreateNameIndex(g1); + std::vector name_index2 = CreateNameIndex(g2); + for (int i = 0; i < g1.node_size(); ++i) { + int idx1 = name_index1[i]; + int idx2 = name_index2[i]; + if (g1.node(idx1).op() != g2.node(idx2).op()) { + return false; + } + if (g1.node(idx1).name() != g2.node(idx2).name()) { + return false; + } + if (g1.node(idx1).input_size() != g2.node(idx2).input_size()) { + return false; + } + std::vector input_index1 = CreateInputIndex(g1.node(idx1)); + std::vector input_index2 = CreateInputIndex(g2.node(idx2)); + for (int j = 0; j < g1.node(idx1).input_size(); ++j) { + if (!IsSameInput(g1.node(idx1).input(input_index1[j]), + g2.node(idx2).input(input_index2[j]))) { + return false; + } + } + } + return true; +} + +bool ContainsNodeWithName(const string& name, const GraphDef& graph) { + return FindNodeWithName(name, graph) != -1; +} + +bool ContainsNodeWithOp(const string& op, const GraphDef& graph) { + return FindNodeWithOp(op, graph) != -1; +} + +Status DeleteNodes(const std::set& nodes_to_delete, GraphDef* graph) { + int last = graph->node_size() - 1; + for (int i = graph->node_size() - 1; i >= 0; --i) { + const NodeDef& node = graph->node(i); + if (nodes_to_delete.find(node.name()) != nodes_to_delete.end()) { + graph->mutable_node()->SwapElements(i, last); + last--; + } + } + graph->mutable_node()->DeleteSubrange(last + 1, + graph->node_size() - last - 1); + return Status::OK(); +} + +int FindNodeWithName(const string& name, const GraphDef& graph) { + return FindNodeWithPredicate( + [name](const NodeDef& node) { return node.name() == name; }, graph); +} + +int FindNodeWithOp(const string& op, const GraphDef& graph) { + return FindNodeWithPredicate( + [op](const NodeDef& node) { return node.op() == op; }, graph); +} + +} // end namespace graph_utils +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.h b/tensorflow/core/grappler/optimizers/data/graph_utils.h new file mode 100644 index 0000000..b40ca44 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/graph_utils.h @@ -0,0 +1,81 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_UTILS_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_UTILS_H_ + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace grappler { +namespace graph_utils { + +// Adds a node to the graph. +Status AddNode(const string& name, const string& op, + const std::vector& inputs, + const std::vector>& attributes, + GraphDef* graph, NodeDef** result); + +// Adds a Const node with the given value to the graph. +template +Status AddScalarConstNode(T v, GraphDef* graph, NodeDef** result) { + return errors::Unimplemented("Type %s is not supported.", + DataTypeToEnum::value); +} +template <> +Status AddScalarConstNode(bool v, GraphDef* graph, NodeDef** result); +template <> +Status AddScalarConstNode(double v, GraphDef* graph, NodeDef** result); +template <> +Status AddScalarConstNode(float v, GraphDef* graph, NodeDef** result); +template <> +Status AddScalarConstNode(int v, GraphDef* graph, NodeDef** result); +template <> +Status AddScalarConstNode(int64 v, GraphDef* graph, NodeDef** result); +template <> +Status AddScalarConstNode(StringPiece v, GraphDef* graph, NodeDef** result); + +// Checks whether the two graphs are the same. +bool Compare(const GraphDef& g1, const GraphDef& g2); + +// Checks whether the graph contains a node with the given name. +bool ContainsNodeWithName(const string& name, const GraphDef& graph); + +// Checks whether the graph contains a node with the given op. +bool ContainsNodeWithOp(const string& op, const GraphDef& graph); + +// Deletes nodes from the graph. +Status DeleteNodes(const std::set& nodes_to_delete, GraphDef* graph); + +// Returns the index of the node with the given name or -1 if the node does +// not exist. +int FindNodeWithName(const string& name, const GraphDef& graph); + +// Returns the index of a node with the given op or -1 if no such node +// exists. +int FindNodeWithOp(const string& op, const GraphDef& graph); + +} // end namespace graph_utils +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_UTILS_H_ diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc new file mode 100644 index 0000000..b347260 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc @@ -0,0 +1,142 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" + +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace graph_utils { +namespace { + +class GraphUtilsTest : public ::testing::Test {}; + +TEST_F(GraphUtilsTest, AddScalarConstNodeBool) { + GraphDef graph; + NodeDef* bool_node; + TF_EXPECT_OK(AddScalarConstNode(true, &graph, &bool_node)); + EXPECT_TRUE(ContainsNodeWithName(bool_node->name(), graph)); + EXPECT_EQ(bool_node->attr().at("value").tensor().bool_val(0), true); +} + +TEST_F(GraphUtilsTest, AddScalarConstNodeDouble) { + GraphDef graph; + NodeDef* double_node; + TF_EXPECT_OK(AddScalarConstNode(3.14, &graph, &double_node)); + EXPECT_TRUE(ContainsNodeWithName(double_node->name(), graph)); + EXPECT_FLOAT_EQ(double_node->attr().at("value").tensor().double_val(0), 3.14); +} + +TEST_F(GraphUtilsTest, AddScalarConstNodeFloat) { + GraphDef graph; + NodeDef* float_node; + TF_EXPECT_OK(AddScalarConstNode(3.14, &graph, &float_node)); + EXPECT_TRUE(ContainsNodeWithName(float_node->name(), graph)); + EXPECT_FLOAT_EQ(float_node->attr().at("value").tensor().float_val(0), 3.14); +} + +TEST_F(GraphUtilsTest, AddScalarConstNodeInt) { + GraphDef graph; + NodeDef* int_node; + TF_EXPECT_OK(AddScalarConstNode(42, &graph, &int_node)); + EXPECT_TRUE(ContainsNodeWithName(int_node->name(), graph)); + EXPECT_EQ(int_node->attr().at("value").tensor().int_val(0), 42); +} + +TEST_F(GraphUtilsTest, AddScalarConstNodeInt64) { + GraphDef graph; + NodeDef* int64_node; + TF_EXPECT_OK(AddScalarConstNode(42, &graph, &int64_node)); + EXPECT_TRUE(ContainsNodeWithName(int64_node->name(), graph)); + EXPECT_EQ(int64_node->attr().at("value").tensor().int64_val(0), 42); +} + +TEST_F(GraphUtilsTest, AddScalarConstNodeString) { + GraphDef graph; + NodeDef* string_node; + TF_EXPECT_OK(AddScalarConstNode("hello", &graph, &string_node)); + EXPECT_TRUE(ContainsNodeWithName(string_node->name(), graph)); + EXPECT_EQ(string_node->attr().at("value").tensor().string_val(0), "hello"); +} + +TEST_F(GraphUtilsTest, Compare) { + GraphDef graphA; + GraphDef graphB; + EXPECT_TRUE(Compare(graphA, graphB)); + + NodeDef* nodeA; + TF_EXPECT_OK(AddNode("A", "OpA", {}, {}, &graphA, &nodeA)); + NodeDef* nodeB; + TF_EXPECT_OK(AddNode("B", "OpB", {"A"}, {}, &graphA, &nodeB)); + EXPECT_FALSE(Compare(graphA, graphB)); + + graphB.mutable_node()->CopyFrom(graphA.node()); + EXPECT_TRUE(Compare(graphA, graphB)); +} + +TEST_F(GraphUtilsTest, ContainsNodeWithName) { + GraphDef graph; + EXPECT_TRUE(!ContainsNodeWithName("A", graph)); + + NodeDef* node; + TF_EXPECT_OK(AddNode("A", "OpA", {}, {}, &graph, &node)); + EXPECT_TRUE(ContainsNodeWithName("A", graph)); + + TF_EXPECT_OK(DeleteNodes({"A"}, &graph)); + EXPECT_TRUE(!ContainsNodeWithName("A", graph)); +} + +TEST_F(GraphUtilsTest, ContainsNodeWithOp) { + GraphDef graph; + EXPECT_TRUE(!ContainsNodeWithOp("OpA", graph)); + + NodeDef* node; + TF_EXPECT_OK(AddNode("A", "OpA", {}, {}, &graph, &node)); + EXPECT_TRUE(ContainsNodeWithOp("OpA", graph)); + + TF_EXPECT_OK(DeleteNodes({"A"}, &graph)); + EXPECT_TRUE(!ContainsNodeWithOp("OpA", graph)); +} + +TEST_F(GraphUtilsTest, FindNodeWithName) { + GraphDef graph; + EXPECT_EQ(FindNodeWithName("A", graph), -1); + + NodeDef* node; + TF_EXPECT_OK(AddNode("A", "OpA", {}, {}, &graph, &node)); + EXPECT_NE(FindNodeWithName("A", graph), -1); + + TF_EXPECT_OK(DeleteNodes({"A"}, &graph)); + EXPECT_EQ(FindNodeWithName("A", graph), -1); +} + +TEST_F(GraphUtilsTest, FindNodeWithOp) { + GraphDef graph; + EXPECT_EQ(FindNodeWithOp("OpA", graph), -1); + + NodeDef* node; + TF_EXPECT_OK(AddNode("A", "OpA", {}, {}, &graph, &node)); + EXPECT_NE(FindNodeWithOp("OpA", graph), -1); + + TF_EXPECT_OK(DeleteNodes({"A"}, &graph)); + EXPECT_EQ(FindNodeWithOp("OpA", graph), -1); +} + +} // namespace +} // namespace graph_utils +} // namespace grappler +} // namespace tensorflow -- 2.7.4