From: 박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 Date: Mon, 5 Aug 2019 04:38:11 +0000 (+0900) Subject: [moco-tf] Import as TFRelu (#6182) X-Git-Tag: submit/tizen/20190809.050447~189 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=c291155f5c698ce6a654b96c6245d45da7a575aa;p=platform%2Fcore%2Fml%2Fnnfw.git [moco-tf] Import as TFRelu (#6182) This will change import Relu node as TFRelu from the knob Signed-off-by: SaeHie Park --- diff --git a/compiler/moco-tf/src/Op/Relu.cpp b/compiler/moco-tf/src/Op/Relu.cpp index 78a2174..c3a6a82 100644 --- a/compiler/moco-tf/src/Op/Relu.cpp +++ b/compiler/moco-tf/src/Op/Relu.cpp @@ -14,8 +14,13 @@ * limitations under the License. */ +#include "Relu.h" + #include "GraphBuilder.h" #include "GraphBuilderContext.h" +#include "Knob.h" + +#include "IR/TFRelu.h" #include #include @@ -34,10 +39,9 @@ namespace tf /** * @brief GraphBuilder for Relu node */ -class ReluGraphBuilder final : public GraphBuilder +class ReluGraphBuilder final : public ReluGraphBuilderBase { public: - bool validate(const tensorflow::NodeDef &) const override; void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override; }; @@ -53,12 +57,46 @@ private: const TensorName _name; }; -bool ReluGraphBuilder::validate(const tensorflow::NodeDef &node) const { return true; } +class TFReluGraphUpdate final : public GraphUpdate +{ +public: + TFReluGraphUpdate(moco::tf::TFRelu *node, const TensorName &&name) : _node(node), _name(name) {} + + void input(const SymbolTable *) const override; + +private: + moco::tf::TFRelu *_node; + const TensorName _name; +}; + +bool ReluGraphBuilderBase::validate(const tensorflow::NodeDef &node) const +{ + // ReLU node SHOULD have only one input + if (node.input_size() != 1) + return false; + + return true; +} void ReluGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderContext *context) const { assert(context != nullptr); + if (moco::tf::get()) + { + ReluGraphBuilderImpl builder; + return builder.build(node, context); + } + else + { + ReluGraphBuilderImpl builder; + return builder.build(node, context); + } +} + +void ReluGraphBuilderImpl::build(const tensorflow::NodeDef &node, + GraphBuilderContext *context) const +{ loco::Graph *graph = context->graph(); SymbolTable *tensor_names = context->tensor_names(); UpdateQueue *updates = context->updates(); @@ -71,18 +109,41 @@ void ReluGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderContex tensor_names->enroll(output_name, relu_node); // Queue node input update - // ReLU node SHOULD have only one input - assert(node.input_size() == 1); auto update = stdex::make_unique(relu_node, TensorName(node.input(0))); updates->enroll(std::move(update)); } +void ReluGraphBuilderImpl::build(const tensorflow::NodeDef &node, + GraphBuilderContext *context) const +{ + loco::Graph *graph = context->graph(); + SymbolTable *tensor_names = context->tensor_names(); + UpdateQueue *updates = context->updates(); + + // Create a "TFRelu" node for Relu + auto relu_node = graph->nodes()->create(); + + // register string-name to node + TensorName output_name(node.name(), 0); + tensor_names->enroll(output_name, relu_node); + + // Queue node input update + auto update = stdex::make_unique(relu_node, TensorName(node.input(0))); + updates->enroll(std::move(update)); +} + void ReLUGraphUpdate::input(const SymbolTable *table) const { loco::Node *target = table->node(_name); _node->input(target); } +void TFReluGraphUpdate::input(const SymbolTable *table) const +{ + loco::Node *target = table->node(_name); + _node->features(target); +} + } // namespace tf } // namespace moco diff --git a/compiler/moco-tf/src/Op/Relu.h b/compiler/moco-tf/src/Op/Relu.h new file mode 100644 index 0000000..7d75f8a --- /dev/null +++ b/compiler/moco-tf/src/Op/Relu.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __OP_RELU_H__ +#define __OP_RELU_H__ + +#include "GraphBuilder.h" +#include "ImportTarget.h" + +namespace moco +{ +namespace tf +{ + +struct ReluGraphBuilderBase : public GraphBuilder +{ + virtual ~ReluGraphBuilderBase() = default; + + bool validate(const tensorflow::NodeDef &) const final; +}; + +template class ReluGraphBuilderImpl; + +template <> struct ReluGraphBuilderImpl final : public ReluGraphBuilderBase +{ + void build(const tensorflow::NodeDef &, GraphBuilderContext *) const final; +}; + +template <> +struct ReluGraphBuilderImpl final : public ReluGraphBuilderBase +{ + void build(const tensorflow::NodeDef &, GraphBuilderContext *) const final; +}; + +} // namespace tf +} // namespace moco + +#endif // __OP_RELU_H__ diff --git a/compiler/moco-tf/src/Op/Relu.test.cpp b/compiler/moco-tf/src/Op/Relu.test.cpp index b7b1861..bdd1152 100644 --- a/compiler/moco-tf/src/Op/Relu.test.cpp +++ b/compiler/moco-tf/src/Op/Relu.test.cpp @@ -14,6 +14,10 @@ * limitations under the License. */ +#include "Relu.h" + +#include "IR/TFRelu.h" + #include "TestHelper.h" #include "Importer.h" @@ -25,6 +29,7 @@ #include +using namespace moco::tf; using namespace moco::tf::test; namespace @@ -81,6 +86,18 @@ TEST(TensorFlowImport, relu_01) tensorflow::GraphDef graph_def; EXPECT_TRUE(plier::tf::parse_graphdef(relu_01_pbtxtdata, graph_def)); + + // Test "ReluGraphBuilderImpl" + { + // TODO: fix indentation + // clang-format off + + using ReluGraphBuilder = ReluGraphBuilderImpl; + + moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()}; + r.add("Relu", stdex::make_unique()); + moco::tf::Importer importer{&r}; + std::unique_ptr graph = importer.import(signature, graph_def); // what to test: @@ -91,4 +108,26 @@ TEST(TensorFlowImport, relu_01) ASSERT_NE(relu_node, nullptr); ASSERT_NE(relu_node->input(), nullptr); + // clang-format on + } + + // Test "ReluGraphBuilderImpl" + { + using ReluGraphBuilder = ReluGraphBuilderImpl; + + moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()}; + r.add("Relu", stdex::make_unique()); + moco::tf::Importer importer{&r}; + + std::unique_ptr graph = importer.import(signature, graph_def); + + // what to test: + // - there should exist TFRelu + // - features node should not be nullptr + + auto relu_node = moco::tf::test::find_first_node_bytype(graph.get()); + + ASSERT_NE(relu_node, nullptr); + ASSERT_NE(relu_node->features(), nullptr); + } }