From: Patrick Nguyen Date: Sat, 24 Feb 2018 00:04:38 +0000 (-0800) Subject: Add custom registered graph optimizers run by MetaOptimizer. X-Git-Tag: upstream/v1.7.0~31^2~396 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=beed05217cf8c3d90784a66cec7c97e042ff5258;p=platform%2Fupstream%2Ftensorflow.git Add custom registered graph optimizers run by MetaOptimizer. PiperOrigin-RevId: 186837828 --- diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index e839630..50ba48e 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -158,6 +158,18 @@ cc_library( ) cc_library( + name = "custom_graph_optimizer", + hdrs = [ + "custom_graph_optimizer.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":graph_optimizer", + "//tensorflow/core:lib", + ], +) + +cc_library( name = "arithmetic_optimizer", srcs = ["arithmetic_optimizer.cc"], hdrs = [ @@ -368,6 +380,8 @@ cc_library( ":arithmetic_optimizer", ":auto_parallel", ":constant_folding", + ":custom_graph_optimizer", + ":custom_graph_optimizer_registry", ":dependency_optimizer", ":graph_optimizer", ":layout_optimizer", @@ -382,6 +396,48 @@ cc_library( ], ) +tf_cc_test( + name = "meta_optimizer_test", + srcs = ["meta_optimizer_test.cc"], + deps = [ + ":custom_graph_optimizer", + ":custom_graph_optimizer_registry", + ":meta_optimizer", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder", + ], +) + +cc_library( + name = "custom_graph_optimizer_registry", + srcs = ["custom_graph_optimizer_registry.cc"], + hdrs = ["custom_graph_optimizer_registry.h"], + visibility = ["//visibility:public"], + deps = [ + ":custom_graph_optimizer", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "custom_graph_optimizer_registry_test", + size = "small", + srcs = ["custom_graph_optimizer_registry_test.cc"], + deps = [ + ":custom_graph_optimizer", + ":custom_graph_optimizer_registry", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + cc_library( name = "loop_optimizer", srcs = ["loop_optimizer.cc"], diff --git a/tensorflow/core/grappler/optimizers/custom_graph_optimizer.h b/tensorflow/core/grappler/optimizers/custom_graph_optimizer.h new file mode 100644 index 0000000..a80d46f --- /dev/null +++ b/tensorflow/core/grappler/optimizers/custom_graph_optimizer.h @@ -0,0 +1,35 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_GRAPPLER_OPTIMIZERS_CUSTOM_GRAPH_OPTIMIZER_H_ +#define TENSORFLOW_GRAPPLER_OPTIMIZERS_CUSTOM_GRAPH_OPTIMIZER_H_ + +#include "tensorflow/core/grappler/optimizers/graph_optimizer.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace grappler { + +// A custom optimizer that can be registered. +class CustomGraphOptimizer : public GraphOptimizer { + public: + virtual ~CustomGraphOptimizer() {} + virtual Status Init() = 0; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_GRAPPLER_OPTIMIZERS_CUSTOM_GRAPH_OPTIMIZER_H_ diff --git a/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc b/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc new file mode 100644 index 0000000..6eed43c --- /dev/null +++ b/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc @@ -0,0 +1,61 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" + +#include +#include + +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace grappler { + +namespace { +typedef std::unordered_map + RegistrationMap; +RegistrationMap* registered_optimizers = nullptr; +RegistrationMap* GetRegistrationMap() { + if (registered_optimizers == nullptr) + registered_optimizers = new RegistrationMap; + return registered_optimizers; +} +} // namespace + +std::unique_ptr +CustomGraphOptimizerRegistry::CreateByNameOrNull(const string& name) { + const auto it = GetRegistrationMap()->find(name); + if (it == GetRegistrationMap()->end()) return nullptr; + return std::unique_ptr(it->second()); +} + +std::vector CustomGraphOptimizerRegistry::GetRegisteredOptimizers() { + std::vector optimizer_names; + optimizer_names.reserve(GetRegistrationMap()->size()); + for (const auto& opt : *GetRegistrationMap()) + optimizer_names.emplace_back(opt.first); + return optimizer_names; +} + +void CustomGraphOptimizerRegistry::RegisterOptimizerOrDie( + const Creator& optimizer_creator, const string& name) { + const auto it = GetRegistrationMap()->find(name); + if (it != GetRegistrationMap()->end()) { + LOG(FATAL) << "CustomGraphOptimizer is registered twice: " << name; + } + GetRegistrationMap()->insert({name, optimizer_creator}); +} + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h b/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h new file mode 100644 index 0000000..796da91 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h @@ -0,0 +1,65 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_CUSTOM_GRAPH_OPTIMIZER_REGISTRY_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_CUSTOM_GRAPH_OPTIMIZER_REGISTRY_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" + +namespace tensorflow { +namespace grappler { + +class CustomGraphOptimizerRegistry { + public: + static std::unique_ptr CreateByNameOrNull( + const string& name); + + static std::vector GetRegisteredOptimizers(); + + typedef std::function Creator; + // Regsiter graph optimizer which can be called during program initialization. + // This class is not thread-safe. + static void RegisterOptimizerOrDie(const Creator& optimizer_creator, + const string& name); +}; + +class CustomGraphOptimizerRegistrar { + public: + explicit CustomGraphOptimizerRegistrar( + const CustomGraphOptimizerRegistry::Creator& creator, + const string& name) { + CustomGraphOptimizerRegistry::RegisterOptimizerOrDie(creator, name); + } +}; + +#define REGISTER_GRAPH_OPTIMIZER_AS(MyCustomGraphOptimizerClass, name) \ + namespace { \ + static CustomGraphOptimizerRegistrar \ + MyCustomGraphOptimizerClass##_registrar( \ + []() { return new MyCustomGraphOptimizerClass; }, (name)); \ + } // namespace + +#define REGISTER_GRAPH_OPTIMIZER(MyCustomGraphOptimizerClass) \ + REGISTER_GRAPH_OPTIMIZER_AS(MyCustomGraphOptimizerClass, \ + #MyCustomGraphOptimizerClass) + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_CUSTOM_GRAPH_OPTIMIZER_REGISTRY_H_ diff --git a/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry_test.cc b/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry_test.cc new file mode 100644 index 0000000..629f5e8 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry_test.cc @@ -0,0 +1,87 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" + +#include +#include +#include +#include + +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { + +static const char* kTestOptimizerName = "Test"; + +class TestGraphOptimizer : public CustomGraphOptimizer { + public: + Status Init() override { return Status::OK(); } + string name() const override { return kTestOptimizerName; } + Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) override { + return Status::OK(); + } + void Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimized_graph, double result) override {} +}; + +REGISTER_GRAPH_OPTIMIZER_AS(TestGraphOptimizer, "StaticRegister"); + +TEST(CustomGraphOptimizerRegistryTest, DynamicRegistration) { + std::vector optimizers = + CustomGraphOptimizerRegistry::GetRegisteredOptimizers(); + std::unique_ptr test_optimizer; + ASSERT_EQ( + 0, std::count(optimizers.begin(), optimizers.end(), "DynamicRegister")); + test_optimizer = + CustomGraphOptimizerRegistry::CreateByNameOrNull("DynamicRegister"); + EXPECT_EQ(nullptr, test_optimizer); + CustomGraphOptimizerRegistry::RegisterOptimizerOrDie( + []() { return new TestGraphOptimizer; }, "DynamicRegister"); + optimizers = CustomGraphOptimizerRegistry::GetRegisteredOptimizers(); + ASSERT_EQ( + 1, std::count(optimizers.begin(), optimizers.end(), "DynamicRegister")); + test_optimizer = + CustomGraphOptimizerRegistry::CreateByNameOrNull("DynamicRegister"); + ASSERT_NE(nullptr, test_optimizer); + EXPECT_EQ(kTestOptimizerName, test_optimizer->name()); +} + +TEST(CustomGraphOptimizerRegistryTest, StaticRegistration) { + const std::vector optimizers = + CustomGraphOptimizerRegistry::GetRegisteredOptimizers(); + EXPECT_EQ(1, + std::count(optimizers.begin(), optimizers.end(), "StaticRegister")); + std::unique_ptr test_optimizer = + CustomGraphOptimizerRegistry::CreateByNameOrNull("StaticRegister"); + ASSERT_NE(nullptr, test_optimizer); + EXPECT_EQ(kTestOptimizerName, test_optimizer->name()); +} + +TEST(GraphOptimizerRegistryTest, CrashesOnDuplicateRegistration) { + const auto creator = []() { return new TestGraphOptimizer; }; + EXPECT_DEATH(CustomGraphOptimizerRegistry::RegisterOptimizerOrDie( + creator, "StaticRegister"), + "twice"); +} + +} // namespace +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index e27b9df..7ae7720 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h" #include "tensorflow/core/grappler/optimizers/auto_parallel.h" #include "tensorflow/core/grappler/optimizers/constant_folding.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" #include "tensorflow/core/grappler/optimizers/dependency_optimizer.h" #include "tensorflow/core/grappler/optimizers/graph_optimizer.h" #include "tensorflow/core/grappler/optimizers/layout_optimizer.h" @@ -126,14 +127,26 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, new AutoParallel(cfg_.auto_parallel().num_replicas()))); } } else { - std::set available_optimizers = { + const std::set available_optimizers = { "pruning", "constfold", "layout", "memory", "autoparallel", "arithmetic", "dependency", "loop"}; - for (const auto& optimizer : cfg_.optimizers()) { - if (available_optimizers.find(optimizer) != available_optimizers.end()) { - optimizers.push_back(NewOptimizer(optimizer)); + std::vector custom_optimizer_names; + for (const auto& optimizer_name : cfg_.optimizers()) { + if (available_optimizers.find(optimizer_name) != + available_optimizers.end()) { + optimizers.push_back(NewOptimizer(optimizer_name)); + } else { + custom_optimizer_names.push_back(optimizer_name); } } + // Now run the custom optimizers. + for (const auto& optimizer_name : custom_optimizer_names) { + std::unique_ptr opt = + CustomGraphOptimizerRegistry::CreateByNameOrNull(optimizer_name); + if (opt == nullptr) continue; + TF_RETURN_IF_ERROR(opt->Init()); + optimizers.push_back(std::move(opt)); + } } if (optimizers.empty()) { diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc new file mode 100644 index 0000000..536347d --- /dev/null +++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc @@ -0,0 +1,77 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/meta_optimizer.h" + +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { + +class TestOptimizer : public CustomGraphOptimizer { + public: + static void SetOptimized(const bool flag_value) { optimized_ = flag_value; } + static bool IsOptimized() { return optimized_; } + + TestOptimizer() {} + string name() const override { return "test_optimizer"; } + + Status Init() override { return Status::OK(); } + + Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) override { + optimized_ = true; + *optimized_graph = item.graph; + return Status::OK(); + } + + void Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimized_graph, double result) override {} + + private: + static bool optimized_; +}; + +bool TestOptimizer::optimized_; + +REGISTER_GRAPH_OPTIMIZER(TestOptimizer); + +TEST(MetaOptimizerTest, RunsCustomOptimizer) { + TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"}); + GrapplerItem item; + CHECK(fake_input.NextItem(&item)); + + TestOptimizer::SetOptimized(false); + RewriterConfig rewriter_config; + rewriter_config.add_optimizers("TestOptimizer"); + + MetaOptimizer optimizer(nullptr, rewriter_config); + GraphDef output; + const Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + EXPECT_TRUE(TestOptimizer::IsOptimized()); +} + +} // namespace +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto index a61eeca..504ed5d 100644 --- a/tensorflow/core/protobuf/rewriter_config.proto +++ b/tensorflow/core/protobuf/rewriter_config.proto @@ -87,5 +87,8 @@ message RewriterConfig { // ("autoparallel"). Memory optimization passes ("memory") invoked here are // not configurable (in contrast to memory optimization passes through the // meta-optimizer) and act only on manual op annotations. + // + // Custom registered optimizers will be run after the base optimizers, in + // the order that they are specified. repeated string optimizers = 100; }