Imported Upstream version 1.18.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / FuseAddWithFullyConnectedPass.test.cpp
1 /*
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/FuseAddWithFullyConnectedPass.h"
18
19 #include <luci/IR/CircleNodes.h>
20
21 #include <luci/test/TestIOGraph.h>
22
23 #include <gtest/gtest.h>
24
25 namespace
26 {
27
28 using namespace luci::test;
29
30 // TODO Reduce duplicate codes in ResolveCustomOpMatMulPass.cpp
31 template <typename T>
32 luci::CircleConst *create_const_node(loco::Graph *g, const loco::DataType dtype,
33                                      const std::vector<uint32_t> &shape,
34                                      const std::vector<T> &values)
35 {
36   auto node = g->nodes()->create<luci::CircleConst>();
37   node->dtype(dtype);
38   node->rank(shape.size());
39
40   uint32_t size = 1;
41   for (uint32_t i = 0; i < shape.size(); ++i)
42   {
43     node->dim(i) = shape.at(i);
44     size *= shape.at(i);
45   }
46   node->shape_status(luci::ShapeStatus::VALID);
47
48 #define INIT_VALUES(DT)                          \
49   {                                              \
50     node->size<DT>(size);                        \
51     for (uint32_t i = 0; i < values.size(); ++i) \
52       node->at<DT>(i) = values[i];               \
53   }
54
55   switch (dtype)
56   {
57     case loco::DataType::U8:
58       INIT_VALUES(loco::DataType::U8);
59       break;
60     case loco::DataType::S16:
61       INIT_VALUES(loco::DataType::S16);
62       break;
63     case loco::DataType::S32:
64       INIT_VALUES(loco::DataType::S32);
65       break;
66     case loco::DataType::FLOAT32:
67       INIT_VALUES(loco::DataType::FLOAT32)
68       break;
69     default:
70       INTERNAL_EXN("create_const_node called with unsupported type");
71       break;
72   }
73   return node;
74 }
75
76 /**
77  *  Simple graph for test
78  *
79  *  BEFORE
80  *
81  *         [FC]
82  *           |
83  *     [Add w/ Relu]
84  *
85  *  AFTER
86  *
87  *      [FC w/ Relu] (bias updated)
88  *
89  */
90 class FCAddGraphlet
91 {
92 public:
93   FCAddGraphlet() = default;
94
95   void init(loco::Graph *g)
96   {
97     std::vector<float> weights_val(16 * 4);
98     _fc_f = create_const_node(g, loco::DataType::FLOAT32, {16, 4}, weights_val);
99
100     std::vector<float> bias_val(16);
101     _fc_b = create_const_node(g, loco::DataType::FLOAT32, {1, 16}, bias_val);
102
103     _fc = g->nodes()->create<luci::CircleFullyConnected>();
104     _fc->weights(_fc_f);
105     _fc->bias(_fc_b);
106     _fc->fusedActivationFunction(luci::FusedActFunc::NONE);
107     _fc->dtype(loco::DataType::FLOAT32);
108     _fc->shape({1, 16});
109     _fc->name("fc");
110
111     std::vector<float> addition_val;
112     for (uint32_t i = 0; i < 16; i++)
113       addition_val.push_back(static_cast<float>(i));
114     _add_c = create_const_node(g, loco::DataType::FLOAT32, {1, 16}, addition_val);
115
116     _add = g->nodes()->create<luci::CircleAdd>();
117     _add->x(_fc);
118     _add->y(_add_c);
119     _add->fusedActivationFunction(luci::FusedActFunc::RELU);
120     _add->dtype(loco::DataType::FLOAT32);
121     _add->shape({1, 16});
122     _add->name("add");
123   }
124
125 public:
126   luci::CircleFullyConnected *fc() { return _fc; }
127
128 protected:
129   luci::CircleFullyConnected *_fc = nullptr;
130   luci::CircleAdd *_add = nullptr;
131   luci::CircleConst *_fc_f = nullptr;
132   luci::CircleConst *_fc_b = nullptr;
133   luci::CircleConst *_add_c = nullptr;
134 };
135
136 class FuseAddWithFCTestGraph : public TestIOGraph, public FCAddGraphlet
137 {
138 public:
139   FuseAddWithFCTestGraph() = default;
140
141   void init(void)
142   {
143     TestIOGraph::init({1, 4}, {1, 16});
144     FCAddGraphlet::init(g());
145
146     _fc->input(input());
147
148     output()->from(_add);
149   }
150 };
151
152 class FuseAddWithFullyConnectedPassTest : public ::testing::Test
153 {
154 public:
155   FuseAddWithFCTestGraph g;
156   luci::FuseAddWithFullyConnectedPass pass;
157 };
158
159 } // namespace
160
161 TEST_F(FuseAddWithFullyConnectedPassTest, simple_test)
162 {
163   g.init();
164
165   auto ret = pass.run(g.g());
166   EXPECT_EQ(true, ret);
167
168   auto fc = dynamic_cast<luci::CircleFullyConnected *>(g.output()->from());
169   EXPECT_NE(nullptr, fc);
170
171   auto bias = loco::must_cast<luci::CircleConst *>(g.fc()->bias());
172   for (uint32_t i = 0; i < bias->size<loco::DataType::FLOAT32>(); i++)
173   {
174     EXPECT_EQ(i, bias->at<loco::DataType::FLOAT32>(i));
175   }
176 }