From: 박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 Date: Tue, 30 Jul 2019 06:18:07 +0000 (+0900) Subject: [moco-tf] Import as TFAvgPool with a knob (#6006) X-Git-Tag: submit/tizen/20190809.050447~332 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=7ac7e65514e7cd6f5a3967e000b3d07081e16261;p=platform%2Fcore%2Fml%2Fnnfw.git [moco-tf] Import as TFAvgPool with a knob (#6006) This will add a Knob to import AvgPool node as TFAvgPool or AvgPool2D Signed-off-by: SaeHie Park --- diff --git a/compiler/moco-tf/src/Knob.lst b/compiler/moco-tf/src/Knob.lst index 45ad016..b241d52 100644 --- a/compiler/moco-tf/src/Knob.lst +++ b/compiler/moco-tf/src/Knob.lst @@ -5,6 +5,7 @@ // KNOB_BOOL(NAME, DEFAULT_VALUE, DESCRIPTION) // Imports +KNOB_BOOL(ImportAsTFAvgPool, false, Import AvgPool2D node as TFAvgPool node) KNOB_BOOL(ImportAsTFBiasAdd, true, Import BiasAdd node as TFBiasAdd node) KNOB_BOOL(ImportAsTFConst, true, Import Const node as TFConst node) KNOB_BOOL(ImportAsTFConv2D, true, Import Conv2D node as TFConv2D node) diff --git a/compiler/moco-tf/src/Op/AvgPool.cpp b/compiler/moco-tf/src/Op/AvgPool.cpp index 1fadc5d..043e2ae 100644 --- a/compiler/moco-tf/src/Op/AvgPool.cpp +++ b/compiler/moco-tf/src/Op/AvgPool.cpp @@ -14,9 +14,14 @@ * limitations under the License. */ +#include "AvgPool.h" + #include "Convert.h" #include "GraphBuilder.h" #include "GraphBuilderContext.h" +#include "Knob.h" + +#include "IR/TFAvgPool.h" #include "Annotations/PaddingData.h" @@ -37,10 +42,9 @@ namespace tf /** * @brief GraphBuilder for AvgPool node */ -class AvgPoolGraphBuilder final : public GraphBuilder +class AvgPoolGraphBuilder final : public AvgPoolGraphBuilderBase { public: - bool validate(const tensorflow::NodeDef &) const override; void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override; }; @@ -59,7 +63,22 @@ private: const TensorName _input_name; }; -bool AvgPoolGraphBuilder::validate(const tensorflow::NodeDef &node) const +class TFAvgPoolGraphUpdate final : public GraphUpdate +{ +public: + TFAvgPoolGraphUpdate(moco::tf::TFAvgPool *node, const TensorName &name) + : _avgpool_node(node), _value_name(name) + { + } + + void input(const SymbolTable *) const override; + +private: + moco::tf::TFAvgPool *_avgpool_node; + const TensorName _value_name; +}; + +bool AvgPoolGraphBuilderBase::validate(const tensorflow::NodeDef &node) const { // note: even though "data_format" is not entered when a model is written, // TF seems to generate "data_format" field into a pb file @@ -74,10 +93,25 @@ bool AvgPoolGraphBuilder::validate(const tensorflow::NodeDef &node) const void AvgPoolGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderContext *context) const { - using plier::tf::DataLayout; - assert(context != nullptr); + if (moco::tf::get()) + { + AvgPoolGraphBuilderImpl builder; + return builder.build(node, context); + } + else + { + AvgPoolGraphBuilderImpl builder; + return builder.build(node, context); + } +} + +void AvgPoolGraphBuilderImpl::build(const tensorflow::NodeDef &node, + GraphBuilderContext *context) const +{ + using plier::tf::DataLayout; + loco::Graph *graph = context->graph(); SymbolTable *tensor_names = context->tensor_names(); UpdateQueue *updates = context->updates(); @@ -214,6 +248,61 @@ void AvgPoolGraphUpdate::input(const SymbolTable *node_table) const _encode_node->input(input_node); } +void AvgPoolGraphBuilderImpl::build(const tensorflow::NodeDef &node, + GraphBuilderContext *context) const +{ + loco::Graph *graph = context->graph(); + SymbolTable *tensor_names = context->tensor_names(); + UpdateQueue *updates = context->updates(); + + // name of loco nodes + ::std::string avgPool2d_name = node.name(); + + // tensorflow data_format: one of NHWC or NCHW. + auto data_layout = get_string_attr(node, "data_format"); + auto avgPool_node = graph->nodes()->create(); + avgPool_node->data_layout(data_layout); + + // padding + auto padding = moco::str_toupper(get_string_attr(node, "padding")); + avgPool_node->padding(padding); + + // ksize + auto tf_ksize = get_list_attr(node, "ksize"); + auto ksize = as_int64_list(tf_ksize); + if (ksize.size() != 4) + { + // TODO support ksize length for 1 and 2 + throw std::runtime_error("AvgPool only supports ksize length 4"); + } + avgPool_node->ksize(ksize); + + // strides + auto tf_strides = get_list_attr(node, "strides"); + auto strides = as_int64_list(tf_strides); + if (strides.size() != 4) + { + // TODO support strides length for 1 and 2 + throw std::runtime_error("AvgPool only supports strides length 4"); + } + avgPool_node->strides(strides); + + // To set the input node of encode_node with avgPool2d_name + TensorName output_name(avgPool2d_name, 0); + tensor_names->enroll(output_name, avgPool_node); + + // Record ifm inputs to featureEncode_node + auto update = stdex::make_unique(avgPool_node, TensorName(node.input(0))); + + updates->enroll(std::move(update)); +} + +void TFAvgPoolGraphUpdate::input(const SymbolTable *node_table) const +{ + loco::Node *value_node = node_table->node(_value_name); + _avgpool_node->value(value_node); +} + } // namespace tf } // namespace moco diff --git a/compiler/moco-tf/src/Op/AvgPool.h b/compiler/moco-tf/src/Op/AvgPool.h new file mode 100644 index 0000000..ec9075a --- /dev/null +++ b/compiler/moco-tf/src/Op/AvgPool.h @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __OP_AVG_POOL_H__ +#define __OP_AVG_POOL_H__ + +#include "GraphBuilder.h" +#include "ImportTarget.h" + +namespace moco +{ +namespace tf +{ + +struct AvgPoolGraphBuilderBase : public GraphBuilder +{ + virtual ~AvgPoolGraphBuilderBase() = default; + + bool validate(const tensorflow::NodeDef &) const final; +}; + +template class AvgPoolGraphBuilderImpl; + +template <> +struct AvgPoolGraphBuilderImpl final : public AvgPoolGraphBuilderBase +{ + void build(const tensorflow::NodeDef &, GraphBuilderContext *) const final; +}; + +template <> +struct AvgPoolGraphBuilderImpl final : public AvgPoolGraphBuilderBase +{ + void build(const tensorflow::NodeDef &, GraphBuilderContext *) const final; +}; + +} // namespace tf +} // namespace moco + +#endif // __OP_AVG_POOL2D_H__ diff --git a/compiler/moco-tf/src/Op/AvgPool.test.cpp b/compiler/moco-tf/src/Op/AvgPool.test.cpp index cfc21df..bfa07c7 100644 --- a/compiler/moco-tf/src/Op/AvgPool.test.cpp +++ b/compiler/moco-tf/src/Op/AvgPool.test.cpp @@ -14,6 +14,10 @@ * limitations under the License. */ +#include "AvgPool.h" + +#include "IR/TFAvgPool.h" + #include "TestHelper.h" #include "Importer.h" @@ -27,6 +31,7 @@ #include +using namespace moco::tf; using namespace moco::tf::test; namespace @@ -125,7 +130,11 @@ TEST(TensorFlowImport, AvgPool_01) tensorflow::GraphDef graph_def; EXPECT_TRUE(plier::tf::parse_graphdef(avgpool_01_pbtxtdata, graph_def)); - std::unique_ptr graph = importer.import(signature, graph_def); + + // Test "AvgPoolGraphBuilderImpl" + { + // TODO: fix indentation + // clang-format off // what to test: // - there should exist AvgPool2D @@ -134,6 +143,14 @@ TEST(TensorFlowImport, AvgPool_01) // - stride values should match // - window values should match + using AvgPoolGraphBuilder = AvgPoolGraphBuilderImpl; + + moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()}; + r.add("AvgPool", stdex::make_unique()); + moco::tf::Importer importer{&r}; + + std::unique_ptr graph = importer.import(signature, graph_def); + loco::AvgPool2D *avgpool2d_node = moco::tf::test::find_first_node_bytype(graph.get()); ASSERT_NE(avgpool2d_node, nullptr); @@ -162,4 +179,37 @@ TEST(TensorFlowImport, AvgPool_01) // window ASSERT_EQ(avgpool2d->window()->vertical(), 2); ASSERT_EQ(avgpool2d->window()->horizontal(), 3); + // clang-format on + } + + // Test "AvgPoolGraphBuilderImpl" + { + // what to test: + // - there should exist TFAvgPool + // - attributes value should match + + using AvgPoolGraphBuilder = AvgPoolGraphBuilderImpl; + + moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()}; + r.add("AvgPool", stdex::make_unique()); + moco::tf::Importer importer{&r}; + + std::unique_ptr graph = importer.import(signature, graph_def); + + moco::tf::TFAvgPool *avgpool_node = + moco::tf::test::find_first_node_bytype(graph.get()); + ASSERT_NE(avgpool_node, nullptr); + + loco::Node *previous_node = avgpool_node->value(); + auto following_nodes = loco::succs(avgpool_node); + ASSERT_EQ(following_nodes.size(), 1); + loco::Node *following_node = *following_nodes.begin(); + ASSERT_NE(following_node, nullptr); + + // attrs inside TFAvgPool2D + ASSERT_EQ(avgpool_node->data_layout(), "NHWC"); + ASSERT_EQ(avgpool_node->padding(), "VALID"); + ASSERT_EQ(avgpool_node->ksize(), std::vector({1, 2, 3, 1})); + ASSERT_EQ(avgpool_node->strides(), std::vector({1, 3, 2, 1})); + } }