From 7a777d86cd3c66460dbe93279171092af8d868af Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/On-Device=20Lab=28SR=29/Staff?= =?utf8?q?=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Mon, 15 Jul 2019 09:56:04 +0900 Subject: [PATCH] [moco/tf] Introduce GraphBuilderSource interface (#4235) This commit extracts GraphBuilderSource from GraphBuilderRegistry which supports only lookup method. Signed-off-by: Jonghyun Park --- contrib/moco-tf/src/GraphBuilderRegistry.h | 14 ++++++++++++-- contrib/moco-tf/src/Importer.cpp | 9 ++++++--- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/contrib/moco-tf/src/GraphBuilderRegistry.h b/contrib/moco-tf/src/GraphBuilderRegistry.h index 3b44104..3ab5b16 100644 --- a/contrib/moco-tf/src/GraphBuilderRegistry.h +++ b/contrib/moco-tf/src/GraphBuilderRegistry.h @@ -28,17 +28,27 @@ namespace moco namespace tf { +struct GraphBuilderSource +{ + virtual ~GraphBuilderSource() = default; + + /** + * @brief Returns registered GraphBuilder pointer for operator (nullptr if not present) + */ + virtual const GraphBuilder *lookup(const std::string &op) const = 0; +}; + /** * @brief Class to return graph builder for TF nodes */ -class GraphBuilderRegistry +class GraphBuilderRegistry final : public GraphBuilderSource { public: /** * @brief Returns registered GraphBuilder pointer for operator or * nullptr if not registered */ - const GraphBuilder *lookup(const std::string &op) const + const GraphBuilder *lookup(const std::string &op) const final { if (_builder_map.find(op) == _builder_map.end()) return nullptr; diff --git a/contrib/moco-tf/src/Importer.cpp b/contrib/moco-tf/src/Importer.cpp index 7b93255..3a4d148 100644 --- a/contrib/moco-tf/src/Importer.cpp +++ b/contrib/moco-tf/src/Importer.cpp @@ -37,7 +37,8 @@ namespace { -void convert_graph(const moco::tf::ModelSignature &signature, tensorflow::GraphDef &tf_graph_def, +void convert_graph(const moco::tf::GraphBuilderSource &source, + const moco::tf::ModelSignature &signature, tensorflow::GraphDef &tf_graph_def, loco::Graph *graph) { auto nodedef = stdex::make_unique(); @@ -89,7 +90,7 @@ void convert_graph(const moco::tf::ModelSignature &signature, tensorflow::GraphD */ for (const auto &n : tf_graph_def.node()) { - if (const auto *graph_builder = moco::tf::GraphBuilderRegistry::get().lookup(n.op())) + if (const auto *graph_builder = source.lookup(n.op())) { if (!graph_builder->validate(n)) { @@ -267,7 +268,9 @@ std::unique_ptr Importer::import(const ModelSignature &signature, { auto graph = loco::make_graph(); - convert_graph(signature, tf_graph_def, graph.get()); + const GraphBuilderSource *source_ptr = &moco::tf::GraphBuilderRegistry::get(); + + convert_graph(*source_ptr, signature, tf_graph_def, graph.get()); transform_graph(graph.get()); -- 2.7.4