From: 윤현식/On-Device Lab(SR)/Principal Engineer/삼성전자 Date: Thu, 19 Sep 2019 06:06:17 +0000 (+0900) Subject: [exo-tflite] adding converter for loco::ConstGen (#7632) X-Git-Tag: submit/tizen/20191205.083104~1151 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=5964314e7308e66e3fa777730f33f68baea1a03a;p=platform%2Fcore%2Fml%2Fnnfw.git [exo-tflite] adding converter for loco::ConstGen (#7632) converter and its test for loco::ConstGen (to locoex::TFLConst) was added. Signed-off-by: Hyun Sik Yoon --- diff --git a/compiler/exo-tflite/src/Conversion/CanonicalNodeConverter.cpp b/compiler/exo-tflite/src/Conversion/CanonicalNodeConverter.cpp index 25dc080..e4396c2 100644 --- a/compiler/exo-tflite/src/Conversion/CanonicalNodeConverter.cpp +++ b/compiler/exo-tflite/src/Conversion/CanonicalNodeConverter.cpp @@ -46,7 +46,7 @@ bool CanonicalNodeConverter::run(loco::Graph *graph) // template instantiation template bool CanonicalNodeConverter::run(loco::Graph *graph); -// TODO loco::ConstGen +template bool CanonicalNodeConverter::run(loco::Graph *graph); // TODO loco::Conv2D // TODO loco::DepthwiseConv2D // TODO loco::DepthwiseFilterEncode diff --git a/compiler/exo-tflite/src/Conversion/ConstGenConverter.cpp b/compiler/exo-tflite/src/Conversion/ConstGenConverter.cpp new file mode 100644 index 0000000..e576d17 --- /dev/null +++ b/compiler/exo-tflite/src/Conversion/ConstGenConverter.cpp @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConstGenConverter.h" + +#include "Dialect/IR/TFLNodes.h" +#include "Check.h" + +#include + +namespace exo +{ + +bool ConstGenConverter::convert(loco::ConstGen *constgen) +{ + auto *graph = constgen->graph(); + + auto tfl_const = graph->nodes()->create(); + { + if (constgen->dtype() == loco::DataType::FLOAT32) + { + tfl_const->dtype(loco::DataType::FLOAT32); + + tfl_const->rank(constgen->rank()); + for (uint32_t axis = 0; axis < constgen->rank(); axis++) + tfl_const->dim(axis) = constgen->dim(axis); + + auto size = constgen->size(); + tfl_const->size(size); + + for (uint32_t i = 0; i < size; ++i) + { + tfl_const->at(i) = constgen->at(i); + } + } + else + EXO_THROW("Unsupported DataType"); + } + + loco::replace(constgen).with(tfl_const); + + return true; +} + +} // namespace exo diff --git a/compiler/exo-tflite/src/Conversion/ConstGenConverter.h b/compiler/exo-tflite/src/Conversion/ConstGenConverter.h new file mode 100644 index 0000000..613ccd0 --- /dev/null +++ b/compiler/exo-tflite/src/Conversion/ConstGenConverter.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __CONVERSION_CONSTGEN_CONVERTER_H__ +#define __CONVERSION_CONSTGEN_CONVERTER_H__ + +#include "CanonicalNodeConverter.h" + +#include + +namespace exo +{ + +class ConstGenConverter : public CanonicalNodeConverter +{ +public: + const char *name(void) const final { return "exo::ConstGenConverter"; } + +public: + bool convert(loco::ConstGen *constgen) final; +}; + +} // namespace exo + +#endif // __CONVERSION_CONSTGEN_CONVERTER_H__ diff --git a/compiler/exo-tflite/src/Conversion/ConstGenConverter.test.cpp b/compiler/exo-tflite/src/Conversion/ConstGenConverter.test.cpp new file mode 100644 index 0000000..9d46545 --- /dev/null +++ b/compiler/exo-tflite/src/Conversion/ConstGenConverter.test.cpp @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConstGenConverter.h" +#include "ReluConverter.h" + +#include "Dialect/IR/TFLNodes.h" +#include "TestGraph.h" +#include "TestHelper.h" + +#include + +#include + +TEST(TFLConstGenConverterTest, ConstGen_Relu) +{ + exo::test::ExampleGraph g; + g.build(); + + // set constgen + { + g.constgen->dtype(loco::DataType::FLOAT32); + g.constgen->shape({2, 1}); + g.constgen->size(2); + + g.constgen->at(0) = 0.5; + g.constgen->at(1) = -0.5; + } + + // let's convert + { + exo::test::TypeShapeReadyPhase test_phase; + + test_phase.add_pass(); + test_phase.add_pass(); + + test_phase.run(g.graph()); + } + + auto tfl_const = exo::test::find_first_node_bytype(g.graph()); + auto tfl_relu = exo::test::find_first_node_bytype(g.graph()); + + ASSERT_TRUE(tfl_const != nullptr and tfl_relu != nullptr); + ASSERT_TRUE(tfl_relu->features() == tfl_const); + + ASSERT_TRUE(tfl_const->rank() == g.constgen->rank()); + ASSERT_TRUE(tfl_const->dim(0) == g.constgen->dim(0)); + ASSERT_TRUE(tfl_const->dim(1) == g.constgen->dim(1)); + ASSERT_TRUE(tfl_const->at(0) == + g.constgen->at(0)); + ASSERT_TRUE(tfl_const->at(1) == + g.constgen->at(1)); +} diff --git a/compiler/exo-tflite/src/TestGraph.h b/compiler/exo-tflite/src/TestGraph.h index 0a00094..f2f6fef 100644 --- a/compiler/exo-tflite/src/TestGraph.h +++ b/compiler/exo-tflite/src/TestGraph.h @@ -19,6 +19,7 @@ #include "Dialect/IR/TFLNodes.h" #include "GraphBlock.h" +#include "TestHelper.h" #include @@ -147,6 +148,7 @@ private: enum class ExampleGraphType { FeatureBiasAdd, + ConstGen_ReLU }; template class ExampleGraph; @@ -183,6 +185,32 @@ public: } }; +/** + * @brief Class to creates the following: + * + * ConstGen -- ReLU -- Push + */ +template <> class ExampleGraph : public TestGraph +{ +public: + loco::ConstGen *constgen = nullptr; + loco::ReLU *relu = nullptr; + +public: + ExampleGraph() = default; + + loco::Graph *graph() { return g.get(); } + + void build() + { + constgen = append(); + relu = append(constgen); + complete(relu); + + EXO_TEST_ASSERT_NODE_COUNT({push}, 3); // sanity check + } +}; + } // namespace test } // namespace exo