93024f3f77243ec9565cb992fc27d6371797109a
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / ReplaceNonConstFCWithBatchMatMulPass.test.cpp
1 /*
2  * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/ReplaceNonConstFCWithBatchMatMulPass.h"
18
19 #include <luci/test/TestIOGraph.h>
20 #include <luci/IR/CircleNodes.h>
21
22 #include <gtest/gtest.h>
23
24 namespace
25 {
26
27 using namespace luci::test;
28
29 // TODO Reduce duplicate codes in ResolveCustomOpMatMulPass.cpp
30 template <typename T>
31 luci::CircleConst *create_const_node(loco::Graph *g, const loco::DataType dtype,
32                                      const std::vector<uint32_t> &shape,
33                                      const std::vector<T> &values)
34 {
35   auto node = g->nodes()->create<luci::CircleConst>();
36   node->dtype(dtype);
37   node->rank(shape.size());
38
39   uint32_t size = 1;
40   for (uint32_t i = 0; i < shape.size(); ++i)
41   {
42     node->dim(i) = shape.at(i);
43     size *= shape.at(i);
44   }
45   node->shape_status(luci::ShapeStatus::VALID);
46
47 #define INIT_VALUES(DT)                          \
48   {                                              \
49     node->size<DT>(size);                        \
50     for (uint32_t i = 0; i < values.size(); ++i) \
51       node->at<DT>(i) = values[i];               \
52   }
53
54   switch (dtype)
55   {
56     case loco::DataType::U8:
57       INIT_VALUES(loco::DataType::U8);
58       break;
59     case loco::DataType::S16:
60       INIT_VALUES(loco::DataType::S16);
61       break;
62     case loco::DataType::S32:
63       INIT_VALUES(loco::DataType::S32);
64       break;
65     case loco::DataType::FLOAT32:
66       INIT_VALUES(loco::DataType::FLOAT32)
67       break;
68     default:
69       INTERNAL_EXN("create_const_node called with unsupported type");
70       break;
71   }
72   return node;
73 }
74
75 /**
76  *  Simple graph for test
77  *
78  *  BEFORE
79  *
80  *   [IFM1] [IFM2] [BIAS]
81  *        \   |   /
82  *          [FC]
83  *            |
84  *          [Res]
85  *
86  *  AFTER
87  *   [IFM1] [IFM2]
88  *        \   |
89  *      [BatchMatMul] [BIAS]
90  *              \      /
91  *               [Add]
92  *                 |
93  *               [Res]
94  *
95  */
96 struct FCGraphlet
97 {
98 public:
99   FCGraphlet() = default;
100   virtual ~FCGraphlet() = default;
101
102   void init(loco::Graph *g, const ShapeU32 r_shape, const float bv)
103   {
104     _tr_y = g->nodes()->create<luci::CircleTranspose>();
105     _tr_y->a(_y);
106     std::vector<int32_t> tr_val = {1, 0};
107     _tr_y->perm(create_const_node(g, loco::DataType::S32, {2}, tr_val));
108
109     _fc = g->nodes()->create<luci::CircleFullyConnected>();
110     _fc->input(_x);
111     _fc->weights(_tr_y);
112     _fc->fusedActivationFunction(luci::FusedActFunc::NONE);
113     _fc->dtype(loco::DataType::FLOAT32);
114     _fc->shape(r_shape);
115     auto l = _fc->dim(_fc->rank() - 1).value();
116     std::vector<float> bias_val(l, bv);
117     _fc->bias(create_const_node(g, loco::DataType::FLOAT32, {l}, bias_val));
118     _fc->name("fc");
119   }
120
121 public:
122   luci::CircleFullyConnected *fc() { return _fc; }
123
124 protected:
125   luci::CircleFullyConnected *_fc = nullptr;
126   luci::CircleTranspose *_tr_y = nullptr;
127   luci::CircleInput *_x = nullptr;
128   luci::CircleInput *_y = nullptr;
129 };
130
131 struct FCGraph : public TestIsGraphlet<2>, public TestOGraphlet, public FCGraphlet
132 {
133   FCGraph() = default;
134   virtual ~FCGraph() = default;
135   void init(const ShapeU32 x_shape, const ShapeU32 y_shape, const ShapeU32 r_shape, const float bv)
136   {
137     TestIsGraphlet<2>::init(g(), {x_shape, y_shape});
138     TestOGraphlet::init(g(), r_shape);
139     _x = input(0);
140     _y = input(1);
141     FCGraphlet::init(g(), r_shape, bv);
142     output()->from(_fc);
143   }
144 };
145
146 class ReplaceNonConstFCWithBatchMatMulPassTest : public ::testing::Test
147 {
148 public:
149   FCGraph g;
150   luci::ReplaceNonConstFCWithBatchMatMulPass pass;
151 };
152
153 } // namespace
154
155 TEST_F(ReplaceNonConstFCWithBatchMatMulPassTest, simple_test)
156 {
157   g.init({2, 3}, {2, 3}, {2, 2}, 0.0f);
158
159   auto ret = pass.run(g.g());
160   EXPECT_EQ(true, ret);
161
162   auto res = dynamic_cast<luci::CircleReshape *>(g.output()->from());
163   EXPECT_NE(nullptr, res);
164 }
165
166 TEST_F(ReplaceNonConstFCWithBatchMatMulPassTest, nonzero_bias_test)
167 {
168   g.init({2, 3}, {2, 3}, {2, 2}, 1.0f);
169
170   auto ret = pass.run(g.g());
171   EXPECT_EQ(true, ret);
172
173   auto mm = dynamic_cast<luci::CircleAdd *>(g.output()->from());
174   EXPECT_NE(nullptr, mm);
175 }
176
177 TEST_F(ReplaceNonConstFCWithBatchMatMulPassTest, wrong_op_NEG)
178 {
179   loco::Graph g;
180
181   auto inp = g.nodes()->create<luci::CircleInput>();
182   auto relu = g.nodes()->create<luci::CircleRelu>();
183   relu->features(inp);
184
185   luci::ReplaceNonConstFCWithBatchMatMulPass pass;
186   auto changed = pass.run(&g);
187
188   EXPECT_EQ(false, changed);
189 }