Adding a tf.data optimization for rewriting `map(...).batch(...)` to `map_and_batch(...)`
author Jiri Simsa <jsimsa@google.com>
Mon, 28 May 2018 16:33:49 +0000 (09:33 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Mon, 28 May 2018 16:36:27 +0000 (09:36 -0700)
PiperOrigin-RevId: 198310806
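
In effect, the new pass looks for a BatchDataset node whose input is produced by a
MapDataset or ParallelMapDataset node and replaces the pair with a single
MapAndBatchDatasetV2 node. A schematic of the node-level rewrite (node names are
illustrative, not taken from a real graph):

    before:  input_dataset -> MapDataset(f, other_args) -> BatchDataset(batch_size)
    after:   input_dataset -> MapAndBatchDatasetV2(f, other_args, batch_size,
                                                   num_parallel_calls, drop_remainder)

The fused node inherits `f`/`Targuments` from the map node and
`output_types`/`output_shapes` from the batch node; `num_parallel_calls` is carried
over from ParallelMapDataset (or set to 1 for a plain MapDataset), and
`drop_remainder` is set to false to preserve BatchDataset semantics.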

tensorflow/core/grappler/optimizers/data/BUILD
tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc [new file with mode: 0644]
tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.h [new file with mode: 0644]
tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc [new file with mode: 0644]

diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD
index 29ebb9a..d3fe7df 100644 (file)
@@ -32,3 +32,45 @@ tf_cc_test(
         "//tensorflow/core:test_main",
     ],
 )
+
+cc_library(
+    name = "map_and_batch_fusion",
+    srcs = ["map_and_batch_fusion.cc"],
+    hdrs = [
+        "map_and_batch_fusion.h",
+    ],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":graph_utils",
+        "//tensorflow/core:lib",
+        "//tensorflow/core/grappler:graph_view",
+        "//tensorflow/core/grappler:grappler_item",
+        "//tensorflow/core/grappler:op_types",
+        "//tensorflow/core/grappler:utils",
+        "//tensorflow/core/grappler/clusters:cluster",
+        "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
+        "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+    ] + tf_protos_all(),
+)
+
+tf_cc_test(
+    name = "map_and_batch_fusion_test",
+    srcs = ["map_and_batch_fusion_test.cc"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":graph_utils",
+        ":map_and_batch_fusion",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core/grappler:grappler_item",
+    ],
+)
+
+cc_library(
+    name = "data",
+    visibility = ["//visibility:public"],
+    deps = [
+        ":map_and_batch_fusion",
+    ],
+    alwayslink = 1,
+)
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
new file mode 100644 (file)
index 0000000..5b8df61
--- /dev/null
@@ -0,0 +1,133 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.h"
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/graph_view.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+namespace grappler {
+
+Status MapAndBatchFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
+                                   GraphDef* output) {
+  *output = item.graph;
+  GraphView graph(output);
+  std::set<string> nodes_to_delete;
+  for (const NodeDef& node : item.graph.node()) {
+    if (node.op() != "BatchDataset") {
+      continue;
+    }
+
+    // Use a more descriptive variable name now that we know the node type.
+    NodeDef batch_node(node);
+    GraphView::InputPort input_port = graph.GetInputPort(batch_node.name(), 0);
+    NodeDef* node2 = graph.GetRegularFanin(input_port).node;
+    if (node2->op() != "MapDataset" && node2->op() != "ParallelMapDataset") {
+      continue;
+    }
+
+    // Use a more descriptive variable name now that we know the node type.
+    NodeDef* map_node = node2;
+    NodeDef* new_node = output->mutable_node()->Add();
+    new_node->set_op("MapAndBatchDatasetV2");
+    new_node->set_name(
+        strings::StrCat("MapAndBatchDatasetV2/_", output->node_size()));
+
+    // Set the `input` input argument.
+    new_node->add_input(map_node->input(0));
+
+    // Set the `other_arguments` input arguments.
+    int num_other_args;
+    if (map_node->op() == "ParallelMapDataset") {
+      num_other_args = map_node->input_size() - 2;
+    } else {
+      num_other_args = map_node->input_size() - 1;
+    }
+    for (int i = 0; i < num_other_args; i++) {
+      new_node->add_input(map_node->input(i + 1));
+    }
+
+    // Set the `batch_size` input argument.
+    new_node->add_input(batch_node.input(1));
+
+    // Set the `num_parallel_calls` input argument.
+    if (map_node->op() == "ParallelMapDataset") {
+      // The `num_parallel_calls` argument has a different type in
+      // ParallelMapDataset (int32) and MapAndBatchDataset (int64), so the
+      // existing Const node cannot be reused and a new one is created.
+      NodeDef* v = graph.GetNode(map_node->input(map_node->input_size() - 1));
+      NodeDef* tmp;
+      TF_RETURN_IF_ERROR(graph_utils::AddScalarConstNode<int64>(
+          v->attr().at("value").tensor().int_val(0), output, &tmp));
+      new_node->add_input(tmp->name());
+    } else {
+      NodeDef* tmp;
+      TF_RETURN_IF_ERROR(
+          graph_utils::AddScalarConstNode<int64>(1, output, &tmp));
+      new_node->add_input(tmp->name());
+    }
+
+    // Set the `drop_remainder` input argument.
+    {
+      NodeDef* tmp;
+      TF_RETURN_IF_ERROR(
+          graph_utils::AddScalarConstNode<bool>(false, output, &tmp));
+      new_node->add_input(tmp->name());
+    }
+
+    // Set `f` and `Targuments` attributes.
+    new_node->mutable_attr()->insert(map_node->attr().begin(),
+                                     map_node->attr().end());
+    // Set `output_types` and `output_shapes` attributes.
+    new_node->mutable_attr()->insert(batch_node.attr().begin(),
+                                     batch_node.attr().end());
+
+    // Mark the `Map` and `Batch` nodes for removal.
+    nodes_to_delete.insert(map_node->name());
+    nodes_to_delete.insert(batch_node.name());
+
+    // Rewire the outputs of the `Batch` node to read from the new
+    // `MapAndBatch` node instead.
+    GraphView::OutputPort output_port =
+        graph.GetOutputPort(batch_node.name(), 0);
+    auto fanout = graph.GetFanout(output_port);
+    for (auto it = fanout.begin(); it != fanout.end(); ++it) {
+      NodeDef* node = it->node;
+      node->set_input(0, new_node->name());
+    }
+  }
+  TF_RETURN_IF_ERROR(graph_utils::DeleteNodes(nodes_to_delete, output));
+  return Status::OK();
+}
+
+void MapAndBatchFusion::Feedback(Cluster* cluster, const GrapplerItem& item,
+                                 const GraphDef& optimize_output,
+                                 double result) {
+  // no-op
+}
+
+REGISTER_GRAPH_OPTIMIZER_AS(MapAndBatchFusion, "map_and_batch_fusion");
+
+}  // end namespace grappler
+}  // end namespace tensorflow
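
The optimizer does not use cluster information, so it can be exercised directly on a
GrapplerItem. A minimal sketch of a standalone invocation, mirroring the unit tests
added below (the wrapper function is illustrative, not part of this change):

    #include "tensorflow/core/framework/graph.pb.h"
    #include "tensorflow/core/grappler/grappler_item.h"
    #include "tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.h"

    // Runs the fusion pass over `item.graph` and writes the rewritten graph
    // into `output`. A null cluster is sufficient, as in the tests below.
    tensorflow::Status RunMapAndBatchFusion(
        const tensorflow::grappler::GrapplerItem& item,
        tensorflow::GraphDef* output) {
      tensorflow::grappler::MapAndBatchFusion optimizer;
      return optimizer.Optimize(/*cluster=*/nullptr, item, output);
    }
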
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.h b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.h
new file mode 100644 (file)
index 0000000..a5a4d91
--- /dev/null
@@ -0,0 +1,46 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_AND_BATCH_FUSION_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_AND_BATCH_FUSION_H_
+
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+
+namespace tensorflow {
+namespace grappler {
+
+class MapAndBatchFusion : public CustomGraphOptimizer {
+ public:
+  MapAndBatchFusion() {}
+  ~MapAndBatchFusion() override {}
+
+  string name() const override { return "map_and_batch_fusion"; }
+
+  Status Init(const tensorflow::RewriterConfig_CustomGraphOptimizer* config =
+                  nullptr) override {
+    return Status::OK();
+  }
+
+  Status Optimize(Cluster* cluster, const GrapplerItem& item,
+                  GraphDef* output) override;
+
+  void Feedback(Cluster* cluster, const GrapplerItem& item,
+                const GraphDef& optimize_output, double result) override;
+};
+
+}  // end namespace grappler
+}  // end namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_AND_BATCH_FUSION_H_
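
Because the class is registered as "map_and_batch_fusion" (via
REGISTER_GRAPH_OPTIMIZER_AS in the .cc above), it can also be requested by name as a
custom graph optimizer. A minimal sketch, assuming the `custom_optimizers` field of
RewriterConfig (the proto whose CustomGraphOptimizer message already appears in the
Init() signature above):

    #include "tensorflow/core/protobuf/rewriter_config.pb.h"

    // Builds a RewriterConfig that asks Grappler to run the pass registered
    // under the name "map_and_batch_fusion".
    tensorflow::RewriterConfig MakeMapAndBatchRewriterConfig() {
      tensorflow::RewriterConfig config;
      config.add_custom_optimizers()->set_name("map_and_batch_fusion");
      return config;
    }
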
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc
new file mode 100644 (file)
index 0000000..51e7f37
--- /dev/null
@@ -0,0 +1,184 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.h"
+
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+TEST(MapAndBatchFusionTest, FuseMapAndBatchNodesIntoOne) {
+  std::vector<std::pair<string, AttrValue>> empty_attributes;
+
+  GrapplerItem item;
+  GraphDef *graph = &item.graph;
+  NodeDef *start_node;
+  TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(0, graph, &start_node));
+  NodeDef *stop_node;
+  TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(10, graph, &stop_node));
+  NodeDef *step_node;
+  TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(1, graph, &step_node));
+
+  std::vector<string> range_inputs(3);
+  range_inputs[0] = start_node->name();
+  range_inputs[1] = stop_node->name();
+  range_inputs[2] = step_node->name();
+  NodeDef *range_node;
+  TF_ASSERT_OK(graph_utils::AddNode("", "RangeDataset", range_inputs,
+                                    empty_attributes, graph, &range_node));
+  NodeDef *captured_input_node;
+  TF_ASSERT_OK(graph_utils::AddScalarConstNode<StringPiece>(
+      "hello", graph, &captured_input_node));
+
+  std::vector<string> map_inputs(2);
+  map_inputs[0] = range_node->name();
+  map_inputs[1] = captured_input_node->name();
+  NodeDef *map_node;
+  TF_ASSERT_OK(graph_utils::AddNode("", "MapDataset", map_inputs,
+                                    empty_attributes, graph, &map_node));
+
+  NodeDef *batch_size_node;
+  TF_ASSERT_OK(
+      graph_utils::AddScalarConstNode<int64>(5, graph, &batch_size_node));
+  std::vector<string> batch_inputs(2);
+  batch_inputs[0] = map_node->name();
+  batch_inputs[1] = batch_size_node->name();
+  NodeDef *batch_node;
+  TF_ASSERT_OK(graph_utils::AddNode("", "BatchDataset", batch_inputs,
+                                    empty_attributes, graph, &batch_node));
+
+  MapAndBatchFusion optimizer;
+  GraphDef output;
+  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+  EXPECT_FALSE(graph_utils::ContainsNodeWithName(map_node->name(), output));
+  EXPECT_FALSE(graph_utils::ContainsNodeWithName(batch_node->name(), output));
+  EXPECT_TRUE(graph_utils::ContainsNodeWithOp("MapAndBatchDatasetV2", output));
+  NodeDef map_and_batch_node =
+      output.node(graph_utils::FindNodeWithOp("MapAndBatchDatasetV2", output));
+  EXPECT_EQ(map_and_batch_node.input_size(), 5);
+  EXPECT_EQ(map_and_batch_node.input(0), map_node->input(0));
+  EXPECT_EQ(map_and_batch_node.input(1), map_node->input(1));
+  EXPECT_EQ(map_and_batch_node.input(2), batch_node->input(1));
+  NodeDef num_parallel_calls_node = output.node(
+      graph_utils::FindNodeWithName(map_and_batch_node.input(3), output));
+  EXPECT_EQ(num_parallel_calls_node.attr().at("value").tensor().int64_val(0),
+            1);
+  NodeDef drop_remainder_node = output.node(
+      graph_utils::FindNodeWithName(map_and_batch_node.input(4), output));
+  EXPECT_EQ(drop_remainder_node.attr().at("value").tensor().bool_val(0), false);
+}
+
+TEST(MapAndBatchFusionTest, FuseParallelMapAndBatchNodesIntoOne) {
+  std::vector<std::pair<string, AttrValue>> empty_attributes;
+
+  GrapplerItem item;
+  GraphDef *graph = &item.graph;
+  NodeDef *start_node;
+  TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(0, graph, &start_node));
+  NodeDef *stop_node;
+  TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(10, graph, &stop_node));
+  NodeDef *step_node;
+  TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(1, graph, &step_node));
+
+  std::vector<string> range_inputs(3);
+  range_inputs[0] = start_node->name();
+  range_inputs[1] = stop_node->name();
+  range_inputs[2] = step_node->name();
+  NodeDef *range_node;
+  TF_ASSERT_OK(graph_utils::AddNode("", "RangeDataset", range_inputs,
+                                    empty_attributes, graph, &range_node));
+  NodeDef *captured_input_node;
+  TF_ASSERT_OK(graph_utils::AddScalarConstNode<StringPiece>(
+      "hello", graph, &captured_input_node));
+  NodeDef *num_parallel_calls_node;
+  TF_ASSERT_OK(
+      graph_utils::AddScalarConstNode<int>(2, graph, &num_parallel_calls_node));
+
+  std::vector<string> map_inputs(3);
+  map_inputs[0] = range_node->name();
+  map_inputs[1] = captured_input_node->name();
+  map_inputs[2] = num_parallel_calls_node->name();
+  NodeDef *map_node;
+  TF_ASSERT_OK(graph_utils::AddNode("", "ParallelMapDataset", map_inputs,
+                                    empty_attributes, graph, &map_node));
+
+  NodeDef *batch_size_node;
+  TF_ASSERT_OK(
+      graph_utils::AddScalarConstNode<int64>(5, graph, &batch_size_node));
+  std::vector<string> batch_inputs(2);
+  batch_inputs[0] = map_node->name();
+  batch_inputs[1] = batch_size_node->name();
+  NodeDef *batch_node;
+  TF_ASSERT_OK(graph_utils::AddNode("", "BatchDataset", batch_inputs,
+                                    empty_attributes, graph, &batch_node));
+
+  MapAndBatchFusion optimizer;
+  GraphDef output;
+  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+  EXPECT_FALSE(graph_utils::ContainsNodeWithName(map_node->name(), output));
+  EXPECT_FALSE(graph_utils::ContainsNodeWithName(batch_node->name(), output));
+  EXPECT_TRUE(graph_utils::ContainsNodeWithOp("MapAndBatchDatasetV2", output));
+  NodeDef map_and_batch_node =
+      output.node(graph_utils::FindNodeWithOp("MapAndBatchDatasetV2", output));
+  EXPECT_EQ(map_and_batch_node.input_size(), 5);
+  EXPECT_EQ(map_and_batch_node.input(0), map_node->input(0));
+  EXPECT_EQ(map_and_batch_node.input(1), map_node->input(1));
+  EXPECT_EQ(map_and_batch_node.input(2), batch_node->input(1));
+  NodeDef num_parallel_calls_node2 = output.node(
+      graph_utils::FindNodeWithName(map_and_batch_node.input(3), output));
+  EXPECT_EQ(num_parallel_calls_node2.attr().at("value").tensor().int64_val(0),
+            2);
+  NodeDef drop_remainder_node = output.node(
+      graph_utils::FindNodeWithName(map_and_batch_node.input(4), output));
+  EXPECT_EQ(drop_remainder_node.attr().at("value").tensor().bool_val(0), false);
+}
+
+TEST(MapAndBatchFusionTest, NoChange) {
+  std::vector<std::pair<string, AttrValue>> empty_attributes;
+
+  GrapplerItem item;
+  GraphDef *graph = &item.graph;
+  NodeDef *start_node;
+  TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(0, graph, &start_node));
+  NodeDef *stop_node;
+  TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(10, graph, &stop_node));
+  NodeDef *step_node;
+  TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(1, graph, &step_node));
+
+  std::vector<string> range_inputs(3);
+  range_inputs[0] = start_node->name();
+  range_inputs[1] = stop_node->name();
+  range_inputs[2] = step_node->name();
+  NodeDef *range_node;
+  TF_ASSERT_OK(graph_utils::AddNode("", "RangeDataset", range_inputs,
+                                    empty_attributes, graph, &range_node));
+
+  MapAndBatchFusion optimizer;
+  GraphDef output;
+  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+  EXPECT_TRUE(graph_utils::Compare(*graph, output));
+}
+
+}  // namespace
+}  // namespace grappler
+}  // namespace tensorflow