From 248f486eccbce087450edf476105801a3052bac3 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/On-Device=20Lab=28SR=29/Staff?= =?utf8?q?=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Wed, 17 Jul 2019 13:06:48 +0900 Subject: [PATCH] [moco/tf] Remove Knob in Conv2D importer test (#4303) This commit removes Knob in Conv2D importer test. Each unittest run now tests both loco.TF and loco.Canonical importers at once. Signed-off-by: Jonghyun Park --- contrib/moco-tf/src/Op/Conv2D.cpp | 29 +++++++++-------- contrib/moco-tf/src/Op/Conv2D.h | 52 +++++++++++++++++++++++++++++++ contrib/moco-tf/src/Op/Conv2D.test.cpp | 57 +++++++++++++++++++++++++++++----- 3 files changed, 117 insertions(+), 21 deletions(-) create mode 100644 contrib/moco-tf/src/Op/Conv2D.h diff --git a/contrib/moco-tf/src/Op/Conv2D.cpp b/contrib/moco-tf/src/Op/Conv2D.cpp index df47a5a..c6d81e4 100644 --- a/contrib/moco-tf/src/Op/Conv2D.cpp +++ b/contrib/moco-tf/src/Op/Conv2D.cpp @@ -14,6 +14,8 @@ * limitations under the License. */ +#include "Conv2D.h" + #include "Convert.h" #include "GraphBuilder.h" #include "GraphBuilderContext.h" @@ -116,18 +118,13 @@ namespace tf /** * @brief GraphBuilder for Conv2D node */ -class Conv2DGraphBuilder final : public GraphBuilder +class Conv2DGraphBuilder final : public Conv2DGraphBuilderBase { public: - bool validate(const tensorflow::NodeDef &) const override; void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override; - -private: - void buildCanonical(const tensorflow::NodeDef &node, GraphBuilderContext *context) const; - void buildTF(const tensorflow::NodeDef &node, GraphBuilderContext *context) const; }; -bool Conv2DGraphBuilder::validate(const tensorflow::NodeDef &node) const +bool Conv2DGraphBuilderBase::validate(const tensorflow::NodeDef &node) const { assert(node.input_size() == 2); @@ -141,13 +138,19 @@ void Conv2DGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderCont assert(context != nullptr); if (moco::tf::get()) - buildTF(node, context); + { + Conv2DGraphBuilderImpl builder; + builder.build(node, context); + } else - buildCanonical(node, context); + { + Conv2DGraphBuilderImpl builder; + builder.build(node, context); + } } -void Conv2DGraphBuilder::buildCanonical(const tensorflow::NodeDef &node, - GraphBuilderContext *context) const +void Conv2DGraphBuilderImpl::build(const tensorflow::NodeDef &node, + GraphBuilderContext *context) const { loco::Graph *graph = context->graph(); SymbolTable *tensor_names = context->tensor_names(); @@ -266,8 +269,8 @@ void Conv2DGraphBuilder::buildCanonical(const tensorflow::NodeDef &node, updates->enroll(std::move(ker_update)); } -void Conv2DGraphBuilder::buildTF(const tensorflow::NodeDef &node, - GraphBuilderContext *context) const +void Conv2DGraphBuilderImpl::build(const tensorflow::NodeDef &node, + GraphBuilderContext *context) const { loco::Graph *graph = context->graph(); SymbolTable *tensor_names = context->tensor_names(); diff --git a/contrib/moco-tf/src/Op/Conv2D.h b/contrib/moco-tf/src/Op/Conv2D.h new file mode 100644 index 0000000..e88b8e3 --- /dev/null +++ b/contrib/moco-tf/src/Op/Conv2D.h @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __OP_CONV_2D_H__ +#define __OP_CONV_2D_H__ + +#include "GraphBuilder.h" +#include "ImportTarget.h" + +namespace moco +{ +namespace tf +{ + +struct Conv2DGraphBuilderBase : public GraphBuilder +{ + virtual ~Conv2DGraphBuilderBase() = default; + + bool validate(const tensorflow::NodeDef &) const final; +}; + +template class Conv2DGraphBuilderImpl; + +template <> +struct Conv2DGraphBuilderImpl final : public Conv2DGraphBuilderBase +{ + void build(const tensorflow::NodeDef &, GraphBuilderContext *) const final; +}; + +template <> +struct Conv2DGraphBuilderImpl final : public Conv2DGraphBuilderBase +{ + void build(const tensorflow::NodeDef &, GraphBuilderContext *) const final; +}; + +} // namespace tf +} // namespace moco + +#endif // __OP_CONV_2D_H__ diff --git a/contrib/moco-tf/src/Op/Conv2D.test.cpp b/contrib/moco-tf/src/Op/Conv2D.test.cpp index b971bbb..ce22b57 100644 --- a/contrib/moco-tf/src/Op/Conv2D.test.cpp +++ b/contrib/moco-tf/src/Op/Conv2D.test.cpp @@ -14,10 +14,11 @@ * limitations under the License. */ +#include "Conv2D.h" + #include "TestHelper.h" #include "Importer.h" -#include "Knob.h" #include "IR/TFConv2D.h" #include @@ -28,6 +29,7 @@ #include +using namespace moco::tf; using namespace moco::tf::test; namespace @@ -233,13 +235,32 @@ TEST(TensorFlowImport, Conv2D_01) tensorflow::GraphDef graph_def; EXPECT_TRUE(parse_graphdef(conv2d_01_pbtxtdata, graph_def)); - std::unique_ptr graph = importer.import(signature, graph_def); - // TODO remove using knob - if (moco::tf::get()) + // Test loco.TF Importer + { + using Conv2DGraphBuilder = Conv2DGraphBuilderImpl; + + moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()}; + r.add("Conv2D", stdex::make_unique()); + moco::tf::Importer importer{&r}; + + std::unique_ptr graph = importer.import(signature, graph_def); + verify_TFConv2D_01(graph.get()); - else + } + + // Test loco.Canonical Importer + { + using Conv2DGraphBuilder = Conv2DGraphBuilderImpl; + + moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()}; + r.add("Conv2D", stdex::make_unique()); + moco::tf::Importer importer{&r}; + + std::unique_ptr graph = importer.import(signature, graph_def); + verify_Conv2D_01(graph.get()); + } } namespace @@ -463,9 +484,29 @@ TEST(TensorFlowImport, Conv2D_inception_indexed_tensor_name) EXPECT_TRUE(parse_graphdef(conv2d_inception_pbtxtdata, graph_def)); std::unique_ptr graph = importer.import(signature, graph_def); - // TODO remove using knob - if (moco::tf::get()) + // Test loco.TF Importer + { + using Conv2DGraphBuilder = Conv2DGraphBuilderImpl; + + moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()}; + r.add("Conv2D", stdex::make_unique()); + moco::tf::Importer importer{&r}; + + std::unique_ptr graph = importer.import(signature, graph_def); + verify_TFConv2D_inception_indexed_tensor_name(graph.get()); - else + } + + // Test loco.Canonical Importer + { + using Conv2DGraphBuilder = Conv2DGraphBuilderImpl; + + moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()}; + r.add("Conv2D", stdex::make_unique()); + moco::tf::Importer importer{&r}; + + std::unique_ptr graph = importer.import(signature, graph_def); + verify_Conv2D_inception_indexed_tensor_name(graph.get()); + } } -- 2.7.4