Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / FuseGeluPass.test.cpp
1 /*
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/FuseGeluPass.h"
18
19 #include <luci/IR/CircleNodes.h>
20
21 #include <luci/test/TestIOGraph.h>
22
23 #include <cmath>
24 #include <gtest/gtest.h>
25
26 namespace
27 {
28
29 using namespace luci::test;
30
31 class GeluGraphlet
32 {
33 public:
34   GeluGraphlet() = default;
35
36   void init(loco::Graph *g)
37   {
38     _ifm = g->nodes()->create<luci::CircleAbs>();
39     _mul_sqrt = g->nodes()->create<luci::CircleMul>();
40     _erf = g->nodes()->create<luci::CircleCustom>(1, 1);
41     _erf_out = g->nodes()->create<luci::CircleCustomOut>();
42     _add_one = g->nodes()->create<luci::CircleAdd>();
43     _mul = g->nodes()->create<luci::CircleMul>();
44     _mul_half = g->nodes()->create<luci::CircleMul>();
45     _const_sqrt = g->nodes()->create<luci::CircleConst>();
46     _const_one = g->nodes()->create<luci::CircleConst>();
47     _const_half = g->nodes()->create<luci::CircleConst>();
48
49     _mul->fusedActivationFunction(luci::FusedActFunc::NONE);
50     _mul_sqrt->fusedActivationFunction(luci::FusedActFunc::NONE);
51     _mul_half->fusedActivationFunction(luci::FusedActFunc::NONE);
52     _add_one->fusedActivationFunction(luci::FusedActFunc::NONE);
53
54     _ifm->name("ifm");
55     _mul_sqrt->name("mul_sqrt");
56     _erf->name("erf");
57     _erf_out->name("erf_out");
58     _add_one->name("add_one");
59     _mul->name("mul");
60     _mul_half->name("mul_half");
61     _const_one->name("const_one");
62     _const_sqrt->name("const_sqrt");
63     _const_half->name("const_half");
64
65     _erf->custom_code("Erf");
66
67     _const_sqrt->dtype(loco::DataType::FLOAT32);
68     _const_sqrt->size<loco::DataType::FLOAT32>(1);
69     _const_sqrt->shape({1});
70     _const_sqrt->at<loco::DataType::FLOAT32>(0) = sqrtf(0.5f);
71     _const_sqrt->shape_status(luci::ShapeStatus::VALID);
72
73     _const_one->dtype(loco::DataType::FLOAT32);
74     _const_one->size<loco::DataType::FLOAT32>(1);
75     _const_one->shape({1});
76     _const_one->at<loco::DataType::FLOAT32>(0) = 1.0;
77     _const_one->shape_status(luci::ShapeStatus::VALID);
78
79     _const_half->dtype(loco::DataType::FLOAT32);
80     _const_half->size<loco::DataType::FLOAT32>(1);
81     _const_half->shape({1});
82     _const_half->at<loco::DataType::FLOAT32>(0) = 0.5;
83     _const_half->shape_status(luci::ShapeStatus::VALID);
84   }
85
86   void invalid_half() { _const_half->at<loco::DataType::FLOAT32>(0) = 0.1; }
87   void invalid_act() { _add_one->fusedActivationFunction(luci::FusedActFunc::RELU); }
88
89 protected:
90   luci::CircleAbs *_ifm = nullptr;
91   luci::CircleMul *_mul_sqrt = nullptr;
92   luci::CircleCustom *_erf = nullptr;
93   luci::CircleCustomOut *_erf_out = nullptr;
94   luci::CircleAdd *_add_one = nullptr;
95   luci::CircleMul *_mul = nullptr;
96   luci::CircleMul *_mul_half = nullptr;
97   luci::CircleConst *_const_sqrt = nullptr;
98   luci::CircleConst *_const_one = nullptr;
99   luci::CircleConst *_const_half = nullptr;
100 };
101
102 class FuseGeluTestGraph1 : public TestIOGraph, public GeluGraphlet
103 {
104 public:
105   FuseGeluTestGraph1() = default;
106
107   void init(void)
108   {
109     TestIOGraph::init({1}, {1});
110     GeluGraphlet::init(g());
111
112     _ifm->x(input());
113     _mul_sqrt->x(_ifm);
114     _mul_sqrt->y(_const_sqrt);
115     _erf->inputs(0, _mul_sqrt);
116     _erf_out->input(_erf);
117     _add_one->x(_erf_out);
118     _add_one->y(_const_one);
119     _mul->x(_ifm);
120     _mul->y(_add_one);
121     _mul_half->x(_mul);
122     _mul_half->y(_const_half);
123
124     output()->from(_mul_half);
125   }
126 };
127
128 class FuseGeluTestGraph2 : public TestIOGraph, public GeluGraphlet
129 {
130 public:
131   FuseGeluTestGraph2() = default;
132
133   void init(void)
134   {
135     TestIOGraph::init({1}, {1});
136     GeluGraphlet::init(g());
137
138     _ifm->x(input());
139     _mul_sqrt->x(_ifm);
140     _mul_sqrt->y(_const_sqrt);
141     _erf->inputs(0, _mul_sqrt);
142     _erf_out->input(_erf);
143     _add_one->x(_erf_out);
144     _add_one->y(_const_one);
145     _mul_half->x(_ifm);
146     _mul_half->y(_const_half);
147     _mul->x(_mul_half);
148     _mul->y(_add_one);
149
150     output()->from(_mul);
151   }
152 };
153
154 class FuseGeluTestNegGraph : public TestIOGraph, public GeluGraphlet
155 {
156 public:
157   FuseGeluTestNegGraph() = default;
158
159   void init(void)
160   {
161     TestIOGraph::init({1}, {1});
162     GeluGraphlet::init(g());
163
164     _ifm->x(input());
165     _mul_sqrt->x(_ifm);
166     // NOTE y is incorrect (should be _const_sqrt)
167     _mul_sqrt->y(_ifm);
168     _erf->inputs(0, _mul_sqrt);
169     _erf_out->input(_erf);
170     _add_one->x(_erf_out);
171     _add_one->y(_const_one);
172     _mul->x(_ifm);
173     _mul->y(_add_one);
174     _mul_half->x(_mul);
175     _mul_half->y(_const_half);
176
177     output()->from(_mul_half);
178   }
179 };
180
181 } // namespace
182
183 TEST(FuseGeluPassTest, name)
184 {
185   luci::FuseGeluPass pass;
186   auto const name = pass.name();
187   ASSERT_NE(nullptr, name);
188 }
189
190 TEST(FuseGeluPassTest, fuse_pattern1)
191 {
192   FuseGeluTestGraph1 g;
193   luci::FuseGeluPass pass;
194
195   g.init();
196
197   EXPECT_TRUE(pass.run(g.g()));
198 }
199
200 TEST(FuseGeluPassTest, fuse_pattern2)
201 {
202   FuseGeluTestGraph2 g;
203   luci::FuseGeluPass pass;
204
205   g.init();
206
207   EXPECT_TRUE(pass.run(g.g()));
208 }
209
210 TEST(FuseGeluPassTest, fuse_invalid_half_NEG)
211 {
212   FuseGeluTestNegGraph g;
213   luci::FuseGeluPass pass;
214
215   g.init();
216   g.invalid_half();
217
218   EXPECT_FALSE(pass.run(g.g()));
219 }
220
221 TEST(FuseGeluPassTest, fuse_pattern2_invalid_half_NEG)
222 {
223   FuseGeluTestGraph2 g;
224   luci::FuseGeluPass pass;
225
226   g.init();
227   g.invalid_half();
228
229   EXPECT_FALSE(pass.run(g.g()));
230 }
231
232 TEST(FuseGeluPassTest, fuse_invalid_act_NEG)
233 {
234   FuseGeluTestNegGraph g;
235   luci::FuseGeluPass pass;
236
237   g.init();
238   g.invalid_act();
239
240   EXPECT_FALSE(pass.run(g.g()));
241 }
242
243 TEST(FuseGeluPassTest, fuse_NEG)
244 {
245   FuseGeluTestNegGraph g;
246   luci::FuseGeluPass pass;
247
248   g.init();
249
250   EXPECT_FALSE(pass.run(g.g()));
251 }