Add custom registered graph optimizers run by MetaOptimizer.
authorPatrick Nguyen <drpng@google.com>
Sat, 24 Feb 2018 00:04:38 +0000 (16:04 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sat, 24 Feb 2018 00:08:55 +0000 (16:08 -0800)
PiperOrigin-RevId: 186837828

tensorflow/core/grappler/optimizers/BUILD
tensorflow/core/grappler/optimizers/custom_graph_optimizer.h [new file with mode: 0644]
tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc [new file with mode: 0644]
tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h [new file with mode: 0644]
tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry_test.cc [new file with mode: 0644]
tensorflow/core/grappler/optimizers/meta_optimizer.cc
tensorflow/core/grappler/optimizers/meta_optimizer_test.cc [new file with mode: 0644]
tensorflow/core/protobuf/rewriter_config.proto

index e839630..50ba48e 100644 (file)
@@ -158,6 +158,18 @@ cc_library(
 )
 
 cc_library(
+    name = "custom_graph_optimizer",
+    hdrs = [
+        "custom_graph_optimizer.h",
+    ],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":graph_optimizer",
+        "//tensorflow/core:lib",
+    ],
+)
+
+cc_library(
     name = "arithmetic_optimizer",
     srcs = ["arithmetic_optimizer.cc"],
     hdrs = [
@@ -368,6 +380,8 @@ cc_library(
         ":arithmetic_optimizer",
         ":auto_parallel",
         ":constant_folding",
+        ":custom_graph_optimizer",
+        ":custom_graph_optimizer_registry",
         ":dependency_optimizer",
         ":graph_optimizer",
         ":layout_optimizer",
@@ -382,6 +396,48 @@ cc_library(
     ],
 )
 
+tf_cc_test(
+    name = "meta_optimizer_test",
+    srcs = ["meta_optimizer_test.cc"],
+    deps = [
+        ":custom_graph_optimizer",
+        ":custom_graph_optimizer_registry",
+        ":meta_optimizer",
+        "//tensorflow/cc:cc_ops",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:tensorflow",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core/grappler:grappler_item",
+        "//tensorflow/core/grappler:utils",
+        "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
+    ],
+)
+
+cc_library(
+    name = "custom_graph_optimizer_registry",
+    srcs = ["custom_graph_optimizer_registry.cc"],
+    hdrs = ["custom_graph_optimizer_registry.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":custom_graph_optimizer",
+        "//tensorflow/core:lib",
+    ],
+)
+
+tf_cc_test(
+    name = "custom_graph_optimizer_registry_test",
+    size = "small",
+    srcs = ["custom_graph_optimizer_registry_test.cc"],
+    deps = [
+        ":custom_graph_optimizer",
+        ":custom_graph_optimizer_registry",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+    ],
+)
+
 cc_library(
     name = "loop_optimizer",
     srcs = ["loop_optimizer.cc"],
diff --git a/tensorflow/core/grappler/optimizers/custom_graph_optimizer.h b/tensorflow/core/grappler/optimizers/custom_graph_optimizer.h
new file mode 100644 (file)
index 0000000..a80d46f
--- /dev/null
@@ -0,0 +1,35 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_GRAPPLER_OPTIMIZERS_CUSTOM_GRAPH_OPTIMIZER_H_
+#define TENSORFLOW_GRAPPLER_OPTIMIZERS_CUSTOM_GRAPH_OPTIMIZER_H_
+
+#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// A custom optimizer that can be registered.
+class CustomGraphOptimizer : public GraphOptimizer {
+ public:
+  virtual ~CustomGraphOptimizer() {}
+  virtual Status Init() = 0;
+};
+
+}  // end namespace grappler
+}  // end namespace tensorflow
+
+#endif  // TENSORFLOW_GRAPPLER_OPTIMIZERS_CUSTOM_GRAPH_OPTIMIZER_H_
diff --git a/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc b/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc
new file mode 100644 (file)
index 0000000..6eed43c
--- /dev/null
@@ -0,0 +1,61 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+
+#include <string>
+#include <unordered_map>
+
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+namespace grappler {
+
+namespace {
+typedef std::unordered_map<string, CustomGraphOptimizerRegistry::Creator>
+    RegistrationMap;
+RegistrationMap* registered_optimizers = nullptr;
+RegistrationMap* GetRegistrationMap() {
+  if (registered_optimizers == nullptr)
+    registered_optimizers = new RegistrationMap;
+  return registered_optimizers;
+}
+}  // namespace
+
+std::unique_ptr<CustomGraphOptimizer>
+CustomGraphOptimizerRegistry::CreateByNameOrNull(const string& name) {
+  const auto it = GetRegistrationMap()->find(name);
+  if (it == GetRegistrationMap()->end()) return nullptr;
+  return std::unique_ptr<CustomGraphOptimizer>(it->second());
+}
+
+std::vector<string> CustomGraphOptimizerRegistry::GetRegisteredOptimizers() {
+  std::vector<string> optimizer_names;
+  optimizer_names.reserve(GetRegistrationMap()->size());
+  for (const auto& opt : *GetRegistrationMap())
+    optimizer_names.emplace_back(opt.first);
+  return optimizer_names;
+}
+
+void CustomGraphOptimizerRegistry::RegisterOptimizerOrDie(
+    const Creator& optimizer_creator, const string& name) {
+  const auto it = GetRegistrationMap()->find(name);
+  if (it != GetRegistrationMap()->end()) {
+    LOG(FATAL) << "CustomGraphOptimizer is registered twice: " << name;
+  }
+  GetRegistrationMap()->insert({name, optimizer_creator});
+}
+
+}  // end namespace grappler
+}  // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h b/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h
new file mode 100644 (file)
index 0000000..796da91
--- /dev/null
@@ -0,0 +1,65 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_CUSTOM_GRAPH_OPTIMIZER_REGISTRY_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_CUSTOM_GRAPH_OPTIMIZER_REGISTRY_H_
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+
+namespace tensorflow {
+namespace grappler {
+
+class CustomGraphOptimizerRegistry {
+ public:
+  static std::unique_ptr<CustomGraphOptimizer> CreateByNameOrNull(
+      const string& name);
+
+  static std::vector<string> GetRegisteredOptimizers();
+
+  typedef std::function<CustomGraphOptimizer*()> Creator;
+  // Regsiter graph optimizer which can be called during program initialization.
+  // This class is not thread-safe.
+  static void RegisterOptimizerOrDie(const Creator& optimizer_creator,
+                                     const string& name);
+};
+
+class CustomGraphOptimizerRegistrar {
+ public:
+  explicit CustomGraphOptimizerRegistrar(
+      const CustomGraphOptimizerRegistry::Creator& creator,
+      const string& name) {
+    CustomGraphOptimizerRegistry::RegisterOptimizerOrDie(creator, name);
+  }
+};
+
+#define REGISTER_GRAPH_OPTIMIZER_AS(MyCustomGraphOptimizerClass, name) \
+  namespace {                                                          \
+  static CustomGraphOptimizerRegistrar                                 \
+      MyCustomGraphOptimizerClass##_registrar(                         \
+          []() { return new MyCustomGraphOptimizerClass; }, (name));   \
+  }  // namespace
+
+#define REGISTER_GRAPH_OPTIMIZER(MyCustomGraphOptimizerClass) \
+  REGISTER_GRAPH_OPTIMIZER_AS(MyCustomGraphOptimizerClass,    \
+                              #MyCustomGraphOptimizerClass)
+
+}  // end namespace grappler
+}  // end namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_CUSTOM_GRAPH_OPTIMIZER_REGISTRY_H_
diff --git a/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry_test.cc b/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry_test.cc
new file mode 100644 (file)
index 0000000..629f5e8
--- /dev/null
@@ -0,0 +1,87 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+static const char* kTestOptimizerName = "Test";
+
+class TestGraphOptimizer : public CustomGraphOptimizer {
+ public:
+  Status Init() override { return Status::OK(); }
+  string name() const override { return kTestOptimizerName; }
+  Status Optimize(Cluster* cluster, const GrapplerItem& item,
+                  GraphDef* optimized_graph) override {
+    return Status::OK();
+  }
+  void Feedback(Cluster* cluster, const GrapplerItem& item,
+                const GraphDef& optimized_graph, double result) override {}
+};
+
+REGISTER_GRAPH_OPTIMIZER_AS(TestGraphOptimizer, "StaticRegister");
+
+TEST(CustomGraphOptimizerRegistryTest, DynamicRegistration) {
+  std::vector<string> optimizers =
+      CustomGraphOptimizerRegistry::GetRegisteredOptimizers();
+  std::unique_ptr<const CustomGraphOptimizer> test_optimizer;
+  ASSERT_EQ(
+      0, std::count(optimizers.begin(), optimizers.end(), "DynamicRegister"));
+  test_optimizer =
+      CustomGraphOptimizerRegistry::CreateByNameOrNull("DynamicRegister");
+  EXPECT_EQ(nullptr, test_optimizer);
+  CustomGraphOptimizerRegistry::RegisterOptimizerOrDie(
+      []() { return new TestGraphOptimizer; }, "DynamicRegister");
+  optimizers = CustomGraphOptimizerRegistry::GetRegisteredOptimizers();
+  ASSERT_EQ(
+      1, std::count(optimizers.begin(), optimizers.end(), "DynamicRegister"));
+  test_optimizer =
+      CustomGraphOptimizerRegistry::CreateByNameOrNull("DynamicRegister");
+  ASSERT_NE(nullptr, test_optimizer);
+  EXPECT_EQ(kTestOptimizerName, test_optimizer->name());
+}
+
+TEST(CustomGraphOptimizerRegistryTest, StaticRegistration) {
+  const std::vector<string> optimizers =
+      CustomGraphOptimizerRegistry::GetRegisteredOptimizers();
+  EXPECT_EQ(1,
+            std::count(optimizers.begin(), optimizers.end(), "StaticRegister"));
+  std::unique_ptr<const CustomGraphOptimizer> test_optimizer =
+      CustomGraphOptimizerRegistry::CreateByNameOrNull("StaticRegister");
+  ASSERT_NE(nullptr, test_optimizer);
+  EXPECT_EQ(kTestOptimizerName, test_optimizer->name());
+}
+
+TEST(GraphOptimizerRegistryTest, CrashesOnDuplicateRegistration) {
+  const auto creator = []() { return new TestGraphOptimizer; };
+  EXPECT_DEATH(CustomGraphOptimizerRegistry::RegisterOptimizerOrDie(
+                   creator, "StaticRegister"),
+               "twice");
+}
+
+}  // namespace
+}  // namespace grappler
+}  // namespace tensorflow
index e27b9df..7ae7720 100644 (file)
@@ -19,6 +19,7 @@ limitations under the License.
 #include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
 #include "tensorflow/core/grappler/optimizers/auto_parallel.h"
 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
 #include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
 #include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
 #include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
@@ -126,14 +127,26 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
           new AutoParallel(cfg_.auto_parallel().num_replicas())));
     }
   } else {
-    std::set<string> available_optimizers = {
+    const std::set<string> available_optimizers = {
         "pruning",      "constfold",  "layout",     "memory",
         "autoparallel", "arithmetic", "dependency", "loop"};
-    for (const auto& optimizer : cfg_.optimizers()) {
-      if (available_optimizers.find(optimizer) != available_optimizers.end()) {
-        optimizers.push_back(NewOptimizer(optimizer));
+    std::vector<string> custom_optimizer_names;
+    for (const auto& optimizer_name : cfg_.optimizers()) {
+      if (available_optimizers.find(optimizer_name) !=
+          available_optimizers.end()) {
+        optimizers.push_back(NewOptimizer(optimizer_name));
+      } else {
+        custom_optimizer_names.push_back(optimizer_name);
       }
     }
+    // Now run the custom optimizers.
+    for (const auto& optimizer_name : custom_optimizer_names) {
+      std::unique_ptr<CustomGraphOptimizer> opt =
+          CustomGraphOptimizerRegistry::CreateByNameOrNull(optimizer_name);
+      if (opt == nullptr) continue;
+      TF_RETURN_IF_ERROR(opt->Init());
+      optimizers.push_back(std::move(opt));
+    }
   }
 
   if (optimizers.empty()) {
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
new file mode 100644 (file)
index 0000000..536347d
--- /dev/null
@@ -0,0 +1,77 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
+
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+class TestOptimizer : public CustomGraphOptimizer {
+ public:
+  static void SetOptimized(const bool flag_value) { optimized_ = flag_value; }
+  static bool IsOptimized() { return optimized_; }
+
+  TestOptimizer() {}
+  string name() const override { return "test_optimizer"; }
+
+  Status Init() override { return Status::OK(); }
+
+  Status Optimize(Cluster* cluster, const GrapplerItem& item,
+                  GraphDef* optimized_graph) override {
+    optimized_ = true;
+    *optimized_graph = item.graph;
+    return Status::OK();
+  }
+
+  void Feedback(Cluster* cluster, const GrapplerItem& item,
+                const GraphDef& optimized_graph, double result) override {}
+
+ private:
+  static bool optimized_;
+};
+
+bool TestOptimizer::optimized_;
+
+REGISTER_GRAPH_OPTIMIZER(TestOptimizer);
+
+TEST(MetaOptimizerTest, RunsCustomOptimizer) {
+  TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
+  GrapplerItem item;
+  CHECK(fake_input.NextItem(&item));
+
+  TestOptimizer::SetOptimized(false);
+  RewriterConfig rewriter_config;
+  rewriter_config.add_optimizers("TestOptimizer");
+
+  MetaOptimizer optimizer(nullptr, rewriter_config);
+  GraphDef output;
+  const Status status = optimizer.Optimize(nullptr, item, &output);
+  TF_EXPECT_OK(status);
+  EXPECT_TRUE(TestOptimizer::IsOptimized());
+}
+
+}  // namespace
+}  // namespace grappler
+}  // namespace tensorflow
index a61eeca..504ed5d 100644 (file)
@@ -87,5 +87,8 @@ message RewriterConfig {
   // ("autoparallel"). Memory optimization passes ("memory") invoked here are
   // not configurable (in contrast to memory optimization passes through the
   // meta-optimizer) and act only on manual op annotations.
+  //
+  // Custom registered optimizers will be run after the base optimizers, in
+  // the order that they are specified.
   repeated string optimizers = 100;
 }