Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / FuseGeluPass.cpp
1 /*
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/FuseGeluPass.h"
18 #include "helpers/NodeFiller.h"
19
20 #include <luci/IR/CircleNodes.h>
21
22 #include <luci/Profile/CircleNodeOrigin.h>
23 #include <luci/Service/CircleNodeClone.h>
24
25 #include <cmath>
26
27 #include <cassert>
28
29 // Helper to fuse Gelu
30 namespace
31 {
32
33 // Float comparison
34 bool same(float a, float b) { return fabs(a - b) < 1e-5; }
35
36 class GeluPatternBase
37 {
38 public:
39   GeluPatternBase(luci::CircleMul *candidate) { _pattern_last_node = candidate; }
40
41   virtual ~GeluPatternBase() = default;
42
43 public:
44   virtual bool matched() = 0;
45
46 public:
47   luci::CircleNode *_ifm = nullptr;
48   luci::CircleMul *_mul_sqrt = nullptr;
49   luci::CircleCustom *_erf = nullptr;
50   luci::CircleCustomOut *_erf_out = nullptr;
51   luci::CircleAdd *_add_one = nullptr;
52   luci::CircleMul *_mul = nullptr;
53   luci::CircleMul *_mul_half = nullptr;
54   luci::CircleConst *_const_sqrt = nullptr;
55   luci::CircleConst *_const_one = nullptr;
56   luci::CircleConst *_const_half = nullptr;
57   luci::CircleMul *_pattern_last_node = nullptr;
58 };
59
60 /**
61  * Below diagram shows Gelu pattern to fuse.
62  * - Gelu(x) = 0.5 * x * (1.0 + erf(x / sqrt(2.0)))
63  * - the below pattern will be replaced with one Gelu
64  *
65  *           [In]
66  *            |
67  *            V
68  *     +---- ifm
69  *     |      |
70  *     |      V
71  *     |  mul_sqrt (1/sqrt(2) = 0.707106..)
72  *     |      |
73  *     |      V
74  *     |     erf
75  *     |      |
76  *     |      V
77  *     |   add_one (1.0)
78  *     |      |
79  *     |      V
80  *     +---> mul
81  *            |
82  *            V
83  *         mul_half (0.5)
84  *            |
85  *            V
86  *          [Out]
87  *
88  */
89 class GeluPattern1 final : public GeluPatternBase
90 {
91 public:
92   GeluPattern1(luci::CircleMul *candidate) : GeluPatternBase(candidate)
93   {
94     assert(candidate);
95     _mul_half = candidate;
96   }
97
98 public:
99   bool matched() override;
100 };
101
102 /**
103  * Below diagram shows Gelu pattern to fuse.
104  * - Gelu(x) = 0.5 * x * (1.0 + erf(x / sqrt(2.0)))
105  * - the below pattern will be replaced with one Gelu
106  *
107  *                  [In]
108  *                   |
109  *                   V
110  *     +----------- ifm
111  *     |             |
112  *     |             V
113  *     |          mul_sqrt (1/sqrt(2) = 0.707106..)
114  *     |             |
115  *     |             V
116  *     |            erf
117  * mul_half (0.5)    |
118  *     |             V
119  *     |         add_one (1.0)
120  *     |             |
121  *     |             V
122  *     +----------> mul
123  *                   |
124  *                   |
125  *                   V
126  *                 [Out]
127  *
128  */
129 class GeluPattern2 final : public GeluPatternBase
130 {
131 public:
132   GeluPattern2(luci::CircleMul *candidate) : GeluPatternBase(candidate)
133   {
134     assert(candidate);
135     _mul = candidate;
136   }
137
138   ~GeluPattern2() override = default;
139
140 public:
141   bool matched() override;
142 };
143
144 #define CHECK_OR_FALSE(condition) \
145   if (not(condition))             \
146     return false;
147
148 bool GeluPattern1::matched()
149 {
150   // check pattern
151   CHECK_OR_FALSE(luci::fill(&_mul, &_const_half).with_commutative_args_of(_mul_half));
152   CHECK_OR_FALSE(luci::fill(&_ifm, &_add_one).with_commutative_args_of(_mul));
153   CHECK_OR_FALSE(luci::fill(&_erf_out, &_const_one).with_commutative_args_of(_add_one));
154
155   if (auto erf = dynamic_cast<luci::CircleCustom *>(_erf_out->input()))
156     _erf = erf;
157
158   CHECK_OR_FALSE(_erf != nullptr);
159
160   // Check erf
161   CHECK_OR_FALSE(_erf->custom_code() == "Erf");
162   CHECK_OR_FALSE(_erf->numInputs() == 1);
163   CHECK_OR_FALSE(_erf->numOutputs() == 1);
164
165   if (auto mul_sqrt = dynamic_cast<luci::CircleMul *>(_erf->inputs(0)))
166     _mul_sqrt = mul_sqrt;
167
168   CHECK_OR_FALSE(_mul_sqrt != nullptr);
169
170   CHECK_OR_FALSE(luci::fill(&_ifm, &_const_sqrt).with_commutative_args_of(_mul_sqrt));
171
172   CHECK_OR_FALSE(_mul_sqrt->x() == _ifm);
173   CHECK_OR_FALSE(_mul->x() == _ifm);
174
175   // Check Activation to be NONE
176   CHECK_OR_FALSE(_mul_sqrt->fusedActivationFunction() == luci::FusedActFunc::NONE);
177   CHECK_OR_FALSE(_add_one->fusedActivationFunction() == luci::FusedActFunc::NONE);
178   CHECK_OR_FALSE(_mul->fusedActivationFunction() == luci::FusedActFunc::NONE);
179   CHECK_OR_FALSE(_mul_half->fusedActivationFunction() == luci::FusedActFunc::NONE);
180
181   // check _const_sqrt condition
182   CHECK_OR_FALSE(_const_sqrt->dtype() == loco::DataType::FLOAT32);
183   CHECK_OR_FALSE(_const_sqrt->size<loco::DataType::FLOAT32>() == 1);
184   CHECK_OR_FALSE(::same(_const_sqrt->at<loco::DataType::FLOAT32>(0), sqrtf(0.5f)));
185
186   // check if _const_half is 0.5 (fp32)
187   CHECK_OR_FALSE(_const_half->dtype() == loco::DataType::FLOAT32);
188   CHECK_OR_FALSE(_const_half->size<loco::DataType::FLOAT32>() == 1);
189   CHECK_OR_FALSE(_const_half->at<loco::DataType::FLOAT32>(0) == 0.5);
190
191   // check _const_one condition
192   CHECK_OR_FALSE(_const_one->dtype() == loco::DataType::FLOAT32);
193   CHECK_OR_FALSE(_const_one->size<loco::DataType::FLOAT32>() == 1);
194   CHECK_OR_FALSE(_const_one->at<loco::DataType::FLOAT32>(0) == 1);
195
196   return true;
197 }
198
199 bool GeluPattern2::matched()
200 {
201   // check pattern
202   CHECK_OR_FALSE(luci::fill(&_mul_half, &_add_one).with_commutative_args_of(_mul));
203   CHECK_OR_FALSE(luci::fill(&_ifm, &_const_half).with_commutative_args_of(_mul_half));
204   CHECK_OR_FALSE(luci::fill(&_erf_out, &_const_one).with_commutative_args_of(_add_one));
205
206   CHECK_OR_FALSE(_mul_half->x() == _ifm);
207
208   if (auto erf = dynamic_cast<luci::CircleCustom *>(_erf_out->input()))
209     _erf = erf;
210
211   CHECK_OR_FALSE(_erf != nullptr);
212
213   // Check erf
214   CHECK_OR_FALSE(_erf->custom_code() == "Erf");
215   CHECK_OR_FALSE(_erf->numInputs() == 1);
216   CHECK_OR_FALSE(_erf->numOutputs() == 1);
217
218   if (auto mul_sqrt = dynamic_cast<luci::CircleMul *>(_erf->inputs(0)))
219     _mul_sqrt = mul_sqrt;
220
221   CHECK_OR_FALSE(_mul_sqrt != nullptr);
222
223   CHECK_OR_FALSE(luci::fill(&_ifm, &_const_sqrt).with_commutative_args_of(_mul_sqrt));
224
225   CHECK_OR_FALSE(_mul_sqrt->x() == _ifm);
226
227   // Check Activation to be NONE
228   CHECK_OR_FALSE(_mul_sqrt->fusedActivationFunction() == luci::FusedActFunc::NONE);
229   CHECK_OR_FALSE(_add_one->fusedActivationFunction() == luci::FusedActFunc::NONE);
230   CHECK_OR_FALSE(_mul->fusedActivationFunction() == luci::FusedActFunc::NONE);
231   CHECK_OR_FALSE(_mul_half->fusedActivationFunction() == luci::FusedActFunc::NONE);
232
233   // check _const_sqrt condition
234   CHECK_OR_FALSE(_const_sqrt->dtype() == loco::DataType::FLOAT32);
235   CHECK_OR_FALSE(_const_sqrt->size<loco::DataType::FLOAT32>() == 1);
236   CHECK_OR_FALSE(::same(_const_sqrt->at<loco::DataType::FLOAT32>(0), sqrtf(0.5f)));
237
238   // check if _const_half is 0.5 (fp32)
239   CHECK_OR_FALSE(_const_half->dtype() == loco::DataType::FLOAT32);
240   CHECK_OR_FALSE(_const_half->size<loco::DataType::FLOAT32>() == 1);
241   CHECK_OR_FALSE(_const_half->at<loco::DataType::FLOAT32>(0) == 0.5);
242
243   // check _const_one condition
244   CHECK_OR_FALSE(_const_one->dtype() == loco::DataType::FLOAT32);
245   CHECK_OR_FALSE(_const_one->size<loco::DataType::FLOAT32>() == 1);
246   CHECK_OR_FALSE(_const_one->at<loco::DataType::FLOAT32>(0) == 1);
247
248   return true;
249 }
250
251 #undef CHECK_OR_FALSE
252
253 class FuseGelu final
254 {
255 public:
256   FuseGelu(const GeluPatternBase *p) : _p(p) {}
257
258 public:
259   void apply(void);
260
261 private:
262   luci::CircleGelu *create_gelu(loco::Graph *graph);
263
264 private:
265   const GeluPatternBase *_p;
266 };
267
268 luci::CircleGelu *FuseGelu::create_gelu(loco::Graph *graph)
269 {
270   assert(graph);
271
272   auto gelu = graph->nodes()->create<luci::CircleGelu>();
273   gelu->features(_p->_ifm);
274   // TODO Support approximate = True pattern
275   gelu->approximate(false);
276   gelu->name(_p->_pattern_last_node->name() + "_gelu");
277   return gelu;
278 }
279
280 void FuseGelu::apply()
281 {
282   auto graph = _p->_pattern_last_node->graph();
283
284   auto gelu = create_gelu(graph);
285
286   // set origin
287   std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
288     luci::get_origin(_p->_mul_sqrt), luci::get_origin(_p->_erf), luci::get_origin(_p->_add_one),
289     luci::get_origin(_p->_mul), luci::get_origin(_p->_mul_half)};
290
291   luci::add_origin(gelu, luci::composite_origin(origin_vec));
292
293   replace(_p->_pattern_last_node).with(gelu);
294 }
295
296 } // namespace
297
298 namespace
299 {
300
301 bool fuse_gelu(luci::CircleMul *mul)
302 {
303   assert(mul);
304
305   // check first pattern
306   GeluPattern1 pattern(mul);
307   if (pattern.matched())
308   {
309     FuseGelu fuse(&pattern);
310     fuse.apply();
311     return true;
312   }
313
314   // check second pattern
315   GeluPattern2 pattern2(mul);
316   if (pattern2.matched())
317   {
318     FuseGelu fuse(&pattern2);
319     fuse.apply();
320     return true;
321   }
322   return false;
323 }
324
325 } // namespace
326
327 namespace luci
328 {
329
330 bool FuseGeluPass::run(loco::Graph *g)
331 {
332   bool changed = false;
333
334   for (auto node : loco::active_nodes(loco::output_nodes(g)))
335   {
336     auto mul = dynamic_cast<luci::CircleMul *>(node);
337     if (not mul)
338       continue;
339
340     if (fuse_gelu(mul))
341       changed = true;
342   }
343
344   return changed;
345 }
346
347 } // namespace luci