From: 박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 Date: Mon, 5 Aug 2019 04:10:04 +0000 (+0900) Subject: [moco-tf] Import as TFRelu6 (#6183) X-Git-Tag: submit/tizen/20190809.050447~192 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=2750bbefd0a08987b81034ebc233fe40b382d3ae;p=platform%2Fcore%2Fml%2Fnnfw.git [moco-tf] Import as TFRelu6 (#6183) This will change import Relu6 node as TFRelu6 from the knob Signed-off-by: SaeHie Park --- diff --git a/compiler/moco-tf/src/Op/Relu6.cpp b/compiler/moco-tf/src/Op/Relu6.cpp index f4bad85..72c6f5c 100644 --- a/compiler/moco-tf/src/Op/Relu6.cpp +++ b/compiler/moco-tf/src/Op/Relu6.cpp @@ -14,7 +14,12 @@ * limitations under the License. */ +#include "Relu6.h" + #include "GraphBuilder.h" +#include "Knob.h" + +#include "IR/TFRelu6.h" #include @@ -25,10 +30,9 @@ namespace tf /** * @brief GraphBuilder for Relu6 node */ -class Relu6GraphBuilder final : public GraphBuilder +class Relu6GraphBuilder final : public Relu6GraphBuilderBase { public: - bool validate(const tensorflow::NodeDef &) const override; void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override; }; @@ -44,7 +48,19 @@ private: const TensorName _name; }; -bool Relu6GraphBuilder::validate(const tensorflow::NodeDef &node) const +class TFRelu6GraphUpdate final : public GraphUpdate +{ +public: + TFRelu6GraphUpdate(moco::tf::TFRelu6 *node, const TensorName &&name) : _node(node), _name(name) {} + + void input(const SymbolTable *) const override; + +private: + moco::tf::TFRelu6 *_node; + const TensorName _name; +}; + +bool Relu6GraphBuilderBase::validate(const tensorflow::NodeDef &node) const { // ReLU6 node SHOULD have only one input if (node.input_size() != 1) @@ -56,6 +72,21 @@ void Relu6GraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderConte { assert(context != nullptr); + if (moco::tf::get()) + { + Relu6GraphBuilderImpl builder; + return builder.build(node, context); + } + else + { + Relu6GraphBuilderImpl builder; + return builder.build(node, context); + } +} + +void Relu6GraphBuilderImpl::build(const tensorflow::NodeDef &node, + GraphBuilderContext *context) const +{ loco::Graph *graph = context->graph(); SymbolTable *tensor_names = context->tensor_names(); UpdateQueue *updates = context->updates(); @@ -72,12 +103,37 @@ void Relu6GraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderConte updates->enroll(std::move(update)); } +void Relu6GraphBuilderImpl::build(const tensorflow::NodeDef &node, + GraphBuilderContext *context) const +{ + loco::Graph *graph = context->graph(); + SymbolTable *tensor_names = context->tensor_names(); + UpdateQueue *updates = context->updates(); + + // Create a "TFRelu6" node for Relu + auto relu_node = graph->nodes()->create(); + + // register string-name to node + TensorName output_name(node.name(), 0); + tensor_names->enroll(output_name, relu_node); + + // Queue node input update + auto update = stdex::make_unique(relu_node, TensorName(node.input(0))); + updates->enroll(std::move(update)); +} + void ReLU6GraphUpdate::input(const SymbolTable *table) const { loco::Node *target = table->node(_name); _node->input(target); } +void TFRelu6GraphUpdate::input(const SymbolTable *table) const +{ + loco::Node *target = table->node(_name); + _node->features(target); +} + } // namespace tf } // namespace moco diff --git a/compiler/moco-tf/src/Op/Relu6.h b/compiler/moco-tf/src/Op/Relu6.h new file mode 100644 index 0000000..8bbadee --- /dev/null +++ b/compiler/moco-tf/src/Op/Relu6.h @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __OP_RELU6_H__ +#define __OP_RELU6_H__ + +#include "GraphBuilder.h" +#include "GraphBuilderContext.h" +#include "ImportTarget.h" + +namespace moco +{ +namespace tf +{ + +struct Relu6GraphBuilderBase : public GraphBuilder +{ + virtual ~Relu6GraphBuilderBase() = default; + + bool validate(const tensorflow::NodeDef &) const final; +}; + +template class Relu6GraphBuilderImpl; + +template <> +struct Relu6GraphBuilderImpl final : public Relu6GraphBuilderBase +{ + void build(const tensorflow::NodeDef &, GraphBuilderContext *) const final; +}; + +template <> +struct Relu6GraphBuilderImpl final : public Relu6GraphBuilderBase +{ + void build(const tensorflow::NodeDef &, GraphBuilderContext *) const final; +}; + +} // namespace tf +} // namespace moco + +#endif // __OP_RELU6_H__ diff --git a/compiler/moco-tf/src/Op/Relu6.test.cpp b/compiler/moco-tf/src/Op/Relu6.test.cpp index e0eb410..d0a4969 100644 --- a/compiler/moco-tf/src/Op/Relu6.test.cpp +++ b/compiler/moco-tf/src/Op/Relu6.test.cpp @@ -14,6 +14,10 @@ * limitations under the License. */ +#include "Relu6.h" + +#include "IR/TFRelu6.h" + #include "TestHelper.h" #include "Importer.h" @@ -25,6 +29,7 @@ #include +using namespace moco::tf; using namespace moco::tf::test; namespace @@ -81,6 +86,18 @@ TEST(TensorFlowImport, relu6_01) tensorflow::GraphDef graph_def; EXPECT_TRUE(plier::tf::parse_graphdef(relu6_01_pbtxtdata, graph_def)); + + // Test "Relu6GraphBuilderImpl" + { + // TODO: fix indentation + // clang-format off + + using ReluGraphBuilder = Relu6GraphBuilderImpl; + + moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()}; + r.add("Relu6", stdex::make_unique()); + moco::tf::Importer importer{&r}; + std::unique_ptr graph = importer.import(signature, graph_def); // what to test: @@ -91,4 +108,26 @@ TEST(TensorFlowImport, relu6_01) ASSERT_NE(relu6_node, nullptr); ASSERT_NE(relu6_node->input(), nullptr); + // clang-format on + } + + // Test "ReluGraphBuilderImpl" + { + using ReluGraphBuilder = Relu6GraphBuilderImpl; + + moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()}; + r.add("Relu6", stdex::make_unique()); + moco::tf::Importer importer{&r}; + + std::unique_ptr graph = importer.import(signature, graph_def); + + // what to test: + // - there should exist TFRelu6 + // - features node should not be null + + auto relu_node = moco::tf::test::find_first_node_bytype(graph.get()); + + ASSERT_NE(relu_node, nullptr); + ASSERT_NE(relu_node->features(), nullptr); + } }