From be1b789be3b69833d24cc39766160988a6e60acb Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=9C=A4=ED=98=84=EC=8B=9D/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Thu, 19 Sep 2019 14:22:20 +0900 Subject: [PATCH] [exo-tflite] test case for FeatureBiasAddConverter (#7583) * [exo-tflite] test case for FeatureBiasAddConverter An unit test to check `FeatureBiasAddConverter` is added. Signed-off-by: Hyun Sik Yoon * fix typo and wrong index of constgen->at(1) --- .../Conversion/FeatureBiasAddConverter.test.cpp | 103 +++++++++++++++++++++ compiler/exo-tflite/src/TestGraph.h | 40 ++++++++ 2 files changed, 143 insertions(+) create mode 100644 compiler/exo-tflite/src/Conversion/FeatureBiasAddConverter.test.cpp diff --git a/compiler/exo-tflite/src/Conversion/FeatureBiasAddConverter.test.cpp b/compiler/exo-tflite/src/Conversion/FeatureBiasAddConverter.test.cpp new file mode 100644 index 0000000..b00d0be --- /dev/null +++ b/compiler/exo-tflite/src/Conversion/FeatureBiasAddConverter.test.cpp @@ -0,0 +1,103 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "FeatureBiasAddConverter.h" + +#include "GraphBlock.h" +#include "Dialect/IR/TFLNodes.h" + +#include "TestGraph.h" +#include "TestHelper.h" + +#include + +#include + +TEST(FeatureBiasAddConverterTest, basic_test) +{ + exo::test::ExampleGraph g; + g.build(); + + { // attrib setting + // pull + g.pull->dtype(loco::DataType::FLOAT32); + g.pull->shape({1, 2, 2, 3}); + + // bias value + g.constgen->dtype(loco::DataType::FLOAT32); + g.constgen->shape({3}); + g.constgen->size(3); + + g.constgen->at(0) = 0.5; + g.constgen->at(1) = 1; + g.constgen->at(2) = 1.5; + } + + EXO_TEST_ASSERT_NODE_COUNT({g.push}, 7); // sanity check + + // let's convert!! + { + exo::test::TypeShapeReadyPhase test_phase; + + test_phase.add_pass(); + + test_phase.run(g.graph()); + + /* + Expected: + + Pull - FeatureEncoder - FeatureDecode - TFLAdd - FeatureEncode - FeatureDecode - Push + | + ConstGen - BiasEncode - BiasDecode ---+ + */ + } + + // check surroundings + auto tfl_add = exo::test::find_first_node_bytype(g.graph()); + { + ASSERT_TRUE(tfl_add != nullptr); + + // input x and its pred + { + auto actual_fea_dec = dynamic_cast(tfl_add->x()); + ASSERT_TRUE(actual_fea_dec != nullptr); + + auto actual_fea_enc = dynamic_cast(actual_fea_dec->input()); + ASSERT_TRUE(actual_fea_enc != nullptr); + ASSERT_TRUE(actual_fea_enc == g.fea_enc); + } + + // input y and its pred + { + auto actual_bias_dec = dynamic_cast(tfl_add->y()); + ASSERT_TRUE(actual_bias_dec != nullptr); + + auto actual_bias_enc = dynamic_cast(actual_bias_dec->input()); + ASSERT_TRUE(actual_bias_enc != nullptr); + ASSERT_TRUE(actual_bias_enc == g.bias_enc); + } + + // output check + { + auto actual_fea_enc = exo::test::get_only_succ(tfl_add); + ASSERT_TRUE(actual_fea_enc != nullptr); + + auto actual_fea_dec = exo::test::get_only_succ(actual_fea_enc); + ASSERT_TRUE(actual_fea_dec != nullptr); + ASSERT_TRUE(actual_fea_dec == g.fea_dec); + } + } +} diff --git a/compiler/exo-tflite/src/TestGraph.h b/compiler/exo-tflite/src/TestGraph.h index 867c515..0a00094 100644 --- a/compiler/exo-tflite/src/TestGraph.h +++ b/compiler/exo-tflite/src/TestGraph.h @@ -18,6 +18,7 @@ #define __TEST_GRAPH_H__ #include "Dialect/IR/TFLNodes.h" +#include "GraphBlock.h" #include @@ -143,6 +144,45 @@ private: loco::Node *_next_input; }; +enum class ExampleGraphType +{ + FeatureBiasAdd, +}; + +template class ExampleGraph; + +/** + * @brief Class to create the following: + * + * Pull - FeatureEncoder - FeatureBiasAdd - FeatureDecode - Push + * | + * ConstGen - BiasEncode --+ + */ +template <> class ExampleGraph : public TestGraph +{ +public: + loco::FeatureEncode *fea_enc = nullptr; + loco::ConstGen *constgen = nullptr; + loco::BiasEncode *bias_enc = nullptr; + loco::FeatureBiasAdd *fea_bias_add = nullptr; + loco::FeatureDecode *fea_dec = nullptr; + +public: + ExampleGraph() = default; + + loco::Graph *graph() { return g.get(); } + + void build() + { + fea_enc = exo::make_feature_encode(pull); + constgen = append(); + bias_enc = append(constgen); + fea_bias_add = append(fea_enc, bias_enc); + fea_dec = exo::make_feature_decode(fea_bias_add); + complete(fea_dec); + } +}; + } // namespace test } // namespace exo -- 2.7.4