From b05a6b5c4cb685b19b8c09693d40d4743af79dea Mon Sep 17 00:00:00 2001
From: Jiri Simsa
Date: Mon, 28 May 2018 09:33:49 -0700
Subject: [PATCH] Adding tf.data optimization for rewriting
 `map(...).batch(...)` to `map_and_batch(...)`.

PiperOrigin-RevId: 198310806
---
 tensorflow/core/grappler/optimizers/data/BUILD    |  42 +++++
 .../optimizers/data/map_and_batch_fusion.cc       | 133 +++++++++++++++
 .../optimizers/data/map_and_batch_fusion.h        |  46 ++++++
 .../optimizers/data/map_and_batch_fusion_test.cc  | 184 +++++++++++++++++++++
 4 files changed, 405 insertions(+)
 create mode 100644 tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
 create mode 100644 tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.h
 create mode 100644 tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc

diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD
index 29ebb9a..d3fe7df 100644
--- a/tensorflow/core/grappler/optimizers/data/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/BUILD
@@ -32,3 +32,45 @@ tf_cc_test(
         "//tensorflow/core:test_main",
     ],
 )
+
+cc_library(
+    name = "map_and_batch_fusion",
+    srcs = ["map_and_batch_fusion.cc"],
+    hdrs = [
+        "map_and_batch_fusion.h",
+    ],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":graph_utils",
+        "//tensorflow/core:lib",
+        "//tensorflow/core/grappler:graph_view",
+        "//tensorflow/core/grappler:grappler_item",
+        "//tensorflow/core/grappler:op_types",
+        "//tensorflow/core/grappler:utils",
+        "//tensorflow/core/grappler/clusters:cluster",
+        "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
+        "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+    ] + tf_protos_all(),
+)
+
+tf_cc_test(
+    name = "map_and_batch_fusion_test",
+    srcs = ["map_and_batch_fusion_test.cc"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":graph_utils",
+        ":map_and_batch_fusion",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core/grappler:grappler_item",
+    ],
+)
+
+cc_library(
+    name = "data",
+    visibility = ["//visibility:public"],
+    deps = [
+        ":map_and_batch_fusion",
+    ],
+    alwayslink = 1,
+)
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
new file mode 100644
index 0000000..5b8df61
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
@@ -0,0 +1,133 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.h"
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/graph_view.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+namespace grappler {
+
+Status MapAndBatchFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
+                                   GraphDef* output) {
+  *output = item.graph;
+  GraphView graph(output);
+  std::set<string> nodes_to_delete;
+  for (const NodeDef& node : item.graph.node()) {
+    if (node.op() != "BatchDataset") {
+      continue;
+    }
+
+    // Use a more descriptive variable name now that we know the node type.
+    NodeDef batch_node(node);
+    GraphView::InputPort input_port = graph.GetInputPort(batch_node.name(), 0);
+    NodeDef* node2 = graph.GetRegularFanin(input_port).node;
+    if (node2->op() != "MapDataset" && node2->op() != "ParallelMapDataset") {
+      continue;
+    }
+
+    // Use a more descriptive variable name now that we know the node type.
+    NodeDef* map_node = node2;
+    NodeDef* new_node = output->mutable_node()->Add();
+    new_node->set_op("MapAndBatchDatasetV2");
+    new_node->set_name(
+        strings::StrCat("MapAndBatchDatasetV2/_", output->node_size()));
+
+    // Set the `input` input argument.
+    new_node->add_input(map_node->input(0));
+
+    // Set the `other_arguments` input arguments.
+    int num_other_args;
+    if (map_node->op() == "ParallelMapDataset") {
+      num_other_args = map_node->input_size() - 2;
+    } else {
+      num_other_args = map_node->input_size() - 1;
+    }
+    for (int i = 0; i < num_other_args; i++) {
+      new_node->add_input(map_node->input(i + 1));
+    }
+
+    // Set the `batch_size` input argument.
+    new_node->add_input(batch_node.input(1));
+
+    // Set the `num_parallel_calls` input argument.
+    if (map_node->op() == "ParallelMapDataset") {
+      // The type of the `num_parallel_calls` argument in ParallelMapDataset
+      // and MapAndBatchDataset is different (int32 and int64 respectively)
+      // so we cannot reuse the same Const node and thus create a new one.
+      NodeDef* v = graph.GetNode(map_node->input(map_node->input_size() - 1));
+      NodeDef* tmp;
+      TF_RETURN_IF_ERROR(graph_utils::AddScalarConstNode<int64>(
+          v->attr().at("value").tensor().int_val(0), output, &tmp));
+      new_node->add_input(tmp->name());
+    } else {
+      NodeDef* tmp;
+      TF_RETURN_IF_ERROR(
+          graph_utils::AddScalarConstNode<int64>(1, output, &tmp));
+      new_node->add_input(tmp->name());
+    }
+
+    // Set the `drop_remainder` input argument.
+    {
+      NodeDef* tmp;
+      TF_RETURN_IF_ERROR(
+          graph_utils::AddScalarConstNode<bool>(false, output, &tmp));
+      new_node->add_input(tmp->name());
+    }
+
+    // Set `f` and `Targuments` attributes.
+    new_node->mutable_attr()->insert(map_node->attr().begin(),
+                                     map_node->attr().end());
+    // Set `output_types` and `output_shapes` attributes.
+    new_node->mutable_attr()->insert(batch_node.attr().begin(),
+                                     batch_node.attr().end());
+
+    // Mark the `Map` and `Batch` nodes for removal.
+    nodes_to_delete.insert(map_node->name());
+    nodes_to_delete.insert(batch_node.name());
+
+    // Update the input of the outputs of the `Batch` node to use
+    // `MapAndBatch`.
+    GraphView::OutputPort output_port =
+        graph.GetOutputPort(batch_node.name(), 0);
+    auto fanout = graph.GetFanout(output_port);
+    for (auto it = fanout.begin(); it != fanout.end(); ++it) {
+      NodeDef* node = it->node;
+      node->set_input(0, new_node->name());
+    }
+  }
+  TF_RETURN_IF_ERROR(graph_utils::DeleteNodes(nodes_to_delete, output));
+  return Status::OK();
+}
+
+void MapAndBatchFusion::Feedback(Cluster* cluster, const GrapplerItem& item,
+                                 const GraphDef& optimize_output,
+                                 double result) {
+  // no-op
+}
+
+REGISTER_GRAPH_OPTIMIZER_AS(MapAndBatchFusion, "map_and_batch_fusion");
+
+}  // end namespace grappler
+}  // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.h b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.h
new file mode 100644
index 0000000..a5a4d91
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.h
@@ -0,0 +1,46 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_AND_BATCH_FUSION_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_AND_BATCH_FUSION_H_
+
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+
+namespace tensorflow {
+namespace grappler {
+
+class MapAndBatchFusion : public CustomGraphOptimizer {
+ public:
+  MapAndBatchFusion() {}
+  ~MapAndBatchFusion() override {}
+
+  string name() const override { return "map_and_batch_fusion"; }
+
+  Status Init(const tensorflow::RewriterConfig_CustomGraphOptimizer* config =
+                  nullptr) override {
+    return Status::OK();
+  }
+
+  Status Optimize(Cluster* cluster, const GrapplerItem& item,
+                  GraphDef* output) override;
+
+  void Feedback(Cluster* cluster, const GrapplerItem& item,
+                const GraphDef& optimize_output, double result) override;
+};
+
+}  // end namespace grappler
+}  // end namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_AND_BATCH_FUSION_H_
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc
new file mode 100644
index 0000000..51e7f37
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc
@@ -0,0 +1,184 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.h"
+
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+TEST(MapAndBatchFusionTest, FuseMapAndBatchNodesIntoOne) {
+  std::vector<std::pair<string, AttrValue>> empty_attributes;
+
+  GrapplerItem item;
+  GraphDef *graph = &item.graph;
+  NodeDef *start_node;
+  TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(0, graph, &start_node));
+  NodeDef *stop_node;
+  TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(10, graph, &stop_node));
+  NodeDef *step_node;
+  TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(1, graph, &step_node));
+
+  std::vector<string> range_inputs(3);
+  range_inputs[0] = start_node->name();
+  range_inputs[1] = stop_node->name();
+  range_inputs[2] = step_node->name();
+  NodeDef *range_node;
+  TF_ASSERT_OK(graph_utils::AddNode("", "RangeDataset", range_inputs,
+                                    empty_attributes, graph, &range_node));
+  NodeDef *captured_input_node;
+  TF_ASSERT_OK(graph_utils::AddScalarConstNode<StringPiece>(
+      "hello", graph, &captured_input_node));
+
+  std::vector<string> map_inputs(2);
+  map_inputs[0] = range_node->name();
+  map_inputs[1] = captured_input_node->name();
+  NodeDef *map_node;
+  TF_ASSERT_OK(graph_utils::AddNode("", "MapDataset", map_inputs,
+                                    empty_attributes, graph, &map_node));
+
+  NodeDef *batch_size_node;
+  TF_ASSERT_OK(
+      graph_utils::AddScalarConstNode<int64>(5, graph, &batch_size_node));
+  std::vector<string> batch_inputs(2);
+  batch_inputs[0] = map_node->name();
+  batch_inputs[1] = batch_size_node->name();
+  NodeDef *batch_node;
+  TF_ASSERT_OK(graph_utils::AddNode("", "BatchDataset", batch_inputs,
+                                    empty_attributes, graph, &batch_node));
+
+  MapAndBatchFusion optimizer;
+  GraphDef output;
+  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+  EXPECT_FALSE(graph_utils::ContainsNodeWithName(map_node->name(), output));
+  EXPECT_FALSE(graph_utils::ContainsNodeWithName(batch_node->name(), output));
+  EXPECT_TRUE(graph_utils::ContainsNodeWithOp("MapAndBatchDatasetV2", output));
+  NodeDef map_and_batch_node =
+      output.node(graph_utils::FindNodeWithOp("MapAndBatchDatasetV2", output));
+  EXPECT_EQ(map_and_batch_node.input_size(), 5);
+  EXPECT_EQ(map_and_batch_node.input(0), map_node->input(0));
+  EXPECT_EQ(map_and_batch_node.input(1), map_node->input(1));
+  EXPECT_EQ(map_and_batch_node.input(2), batch_node->input(1));
+  NodeDef num_parallel_calls_node = output.node(
+      graph_utils::FindNodeWithName(map_and_batch_node.input(3), output));
+  EXPECT_EQ(num_parallel_calls_node.attr().at("value").tensor().int64_val(0),
+            1);
+  NodeDef drop_remainder_node = output.node(
+      graph_utils::FindNodeWithName(map_and_batch_node.input(4), output));
+  EXPECT_EQ(drop_remainder_node.attr().at("value").tensor().bool_val(0), false);
+}
+
+TEST(MapAndBatchFusionTest, FuseParallelMapAndBatchNodesIntoOne) {
+  std::vector<std::pair<string, AttrValue>> empty_attributes;
+
+  GrapplerItem item;
+  GraphDef *graph = &item.graph;
+  NodeDef *start_node;
+  TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(0, graph, &start_node));
+  NodeDef *stop_node;
+  TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(10, graph, &stop_node));
+  NodeDef *step_node;
+  TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(1, graph, &step_node));
+
+  std::vector<string> range_inputs(3);
+  range_inputs[0] = start_node->name();
+  range_inputs[1] = stop_node->name();
+  range_inputs[2] = step_node->name();
+  NodeDef *range_node;
+  TF_ASSERT_OK(graph_utils::AddNode("", "RangeDataset", range_inputs,
+                                    empty_attributes, graph, &range_node));
+  NodeDef *captured_input_node;
+  TF_ASSERT_OK(graph_utils::AddScalarConstNode<StringPiece>(
+      "hello", graph, &captured_input_node));
+  NodeDef *num_parallel_calls_node;
+  TF_ASSERT_OK(graph_utils::AddScalarConstNode<int>(
+      2, graph, &num_parallel_calls_node));
+
+  std::vector<string> map_inputs(3);
+  map_inputs[0] = range_node->name();
+  map_inputs[1] = captured_input_node->name();
+  map_inputs[2] = num_parallel_calls_node->name();
+  NodeDef *map_node;
+  TF_ASSERT_OK(graph_utils::AddNode("", "ParallelMapDataset", map_inputs,
+                                    empty_attributes, graph, &map_node));
+
+  NodeDef *batch_size_node;
+  TF_ASSERT_OK(
+      graph_utils::AddScalarConstNode<int64>(5, graph, &batch_size_node));
+  std::vector<string> batch_inputs(2);
+  batch_inputs[0] = map_node->name();
+  batch_inputs[1] = batch_size_node->name();
+  NodeDef *batch_node;
+  TF_ASSERT_OK(graph_utils::AddNode("", "BatchDataset", batch_inputs,
+                                    empty_attributes, graph, &batch_node));
+
+  MapAndBatchFusion optimizer;
+  GraphDef output;
+  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+  EXPECT_FALSE(graph_utils::ContainsNodeWithName(map_node->name(), output));
+  EXPECT_FALSE(graph_utils::ContainsNodeWithName(batch_node->name(), output));
+  EXPECT_TRUE(graph_utils::ContainsNodeWithOp("MapAndBatchDatasetV2", output));
+  NodeDef map_and_batch_node =
+      output.node(graph_utils::FindNodeWithOp("MapAndBatchDatasetV2", output));
+  EXPECT_EQ(map_and_batch_node.input_size(), 5);
+  EXPECT_EQ(map_and_batch_node.input(0), map_node->input(0));
+  EXPECT_EQ(map_and_batch_node.input(1), map_node->input(1));
+  EXPECT_EQ(map_and_batch_node.input(2), batch_node->input(1));
+  NodeDef num_parallel_calls_node2 = output.node(
+      graph_utils::FindNodeWithName(map_and_batch_node.input(3), output));
+  EXPECT_EQ(num_parallel_calls_node2.attr().at("value").tensor().int64_val(0),
+            2);
+  NodeDef drop_remainder_node = output.node(
+      graph_utils::FindNodeWithName(map_and_batch_node.input(4), output));
+  EXPECT_EQ(drop_remainder_node.attr().at("value").tensor().bool_val(0), false);
+}
+
+TEST(MapAndBatchFusionTest, NoChange) {
+  std::vector<std::pair<string, AttrValue>> empty_attributes;
+
+  GrapplerItem item;
+  GraphDef *graph = &item.graph;
+  NodeDef *start_node;
+  TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(0, graph, &start_node));
+  NodeDef *stop_node;
+  TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(10, graph, &stop_node));
+  NodeDef *step_node;
+  TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(1, graph, &step_node));
+
+  std::vector<string> range_inputs(3);
+  range_inputs[0] = start_node->name();
+  range_inputs[1] = stop_node->name();
+  range_inputs[2] = step_node->name();
+  NodeDef *range_node;
+  TF_ASSERT_OK(graph_utils::AddNode("", "RangeDataset", range_inputs,
+                                    empty_attributes, graph, &range_node));
+
+  MapAndBatchFusion optimizer;
+  GraphDef output;
+  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+  EXPECT_TRUE(graph_utils::Compare(*graph, output));
+}
+
+}  // namespace
+}  // namespace grappler
+}  // namespace tensorflow
-- 
2.7.4
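
For context on the user-facing pattern this pass rewrites, here is a minimal Python sketch, assuming the TF 1.8-era tf.data API; `_parse_fn`, the constants, and the `tf.contrib.data.map_and_batch` line are illustrative assumptions and not part of this change.

import tensorflow as tf  # assumes the TF 1.8-era API


def _parse_fn(x):
  # Hypothetical per-element transformation.
  return x * 2


# Before the rewrite: this pipeline serializes to a (Parallel)MapDataset node
# feeding a BatchDataset node, which is the pattern MapAndBatchFusion matches.
dataset = tf.data.Dataset.range(10)
dataset = dataset.map(_parse_fn, num_parallel_calls=2)
dataset = dataset.batch(5)

# After the "map_and_batch_fusion" rewrite, those two nodes are replaced by a
# single MapAndBatchDatasetV2 node, roughly what the contrib-era transformation
# below expresses directly.
fused = tf.data.Dataset.range(10).apply(
    tf.contrib.data.map_and_batch(_parse_fn, batch_size=5))

Because the rewrite happens on the serialized dataset graph, existing `map(...).batch(...)` pipelines pick up the fused MapAndBatchDatasetV2 op without any change to user code.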