2 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/Pass/FuseGeluPass.h"
18 #include "helpers/NodeFiller.h"
20 #include <luci/IR/CircleNodes.h>
22 #include <luci/Profile/CircleNodeOrigin.h>
23 #include <luci/Service/CircleNodeClone.h>
29 // Helper to fuse Gelu
// Approximate float equality: true when the absolute difference of a and b
// is strictly below 1e-5.
bool same(float a, float b)
{
  const auto abs_diff = fabs(a - b);
  return abs_diff < 1e-5;
}
39 GeluPatternBase(luci::CircleMul *candidate) { _pattern_last_node = candidate; }
// Virtual destructor: this class is a polymorphic base (see the pure virtual
// matched() right below).
41 virtual ~GeluPatternBase() = default;
// Implemented by each concrete pattern. fuse_gelu() only fuses when this
// returns true.
44 virtual bool matched() = 0;
// Node slots of the pattern. All start as nullptr and are filled by the
// constructors and by matched().
// Input of the pattern; create_gelu() passes it as `features` of the new Gelu.
47 luci::CircleNode *_ifm = nullptr;
// Mul that feeds the Erf custom op; its constant operand is _const_sqrt.
48 luci::CircleMul *_mul_sqrt = nullptr;
// Custom op whose custom_code is checked to be "Erf".
49 luci::CircleCustom *_erf = nullptr;
// Output node whose input() is cast to CircleCustom to obtain _erf.
50 luci::CircleCustomOut *_erf_out = nullptr;
// Add whose operands are filled into _erf_out and _const_one.
51 luci::CircleAdd *_add_one = nullptr;
// The two remaining Mul nodes; which one is the pattern's last node depends
// on the concrete pattern.
52 luci::CircleMul *_mul = nullptr;
53 luci::CircleMul *_mul_half = nullptr;
// Constants validated in matched(): sqrtf(0.5f) (within the tolerance of
// same()), exactly 1 and exactly 0.5, each a single FLOAT32 element.
54 luci::CircleConst *_const_sqrt = nullptr;
55 luci::CircleConst *_const_one = nullptr;
56 luci::CircleConst *_const_half = nullptr;
// Candidate node given to the constructor; FuseGelu::apply() replaces it.
57 luci::CircleMul *_pattern_last_node = nullptr;
61 * The diagram below shows the Gelu pattern to fuse.
62 * - Gelu(x) = 0.5 * x * (1.0 + erf(x / sqrt(2.0)))
63 * - the pattern below will be replaced with a single Gelu
71 * | mul_sqrt (1/sqrt(2) = 0.707106..)
// First Gelu pattern. The candidate Mul is treated as `_mul_half`, the Mul
// whose constant operand matched() requires to be 0.5.
89 class GeluPattern1 final : public GeluPatternBase
// Forwards the candidate to the base class, which records it as the
// pattern's last node.
92 GeluPattern1(luci::CircleMul *candidate) : GeluPatternBase(candidate)
95 _mul_half = candidate;
// Defined out of class as GeluPattern1::matched().
99 bool matched() override;
103 * The diagram below shows the Gelu pattern to fuse.
104 * - Gelu(x) = 0.5 * x * (1.0 + erf(x / sqrt(2.0)))
105 * - the pattern below will be replaced with a single Gelu
113 * | mul_sqrt (1/sqrt(2) = 0.707106..)
// Second Gelu pattern. GeluPattern2::matched() starts its walk from `_mul`.
// NOTE(review): the constructor body is not visible in this excerpt;
// presumably it assigns the candidate to `_mul` - confirm.
129 class GeluPattern2 final : public GeluPatternBase
// Forwards the candidate to the base class, which records it as the
// pattern's last node.
132 GeluPattern2(luci::CircleMul *candidate) : GeluPatternBase(candidate)
138 ~GeluPattern2() override = default;
// Defined out of class as GeluPattern2::matched().
141 bool matched() override;
// Guard used throughout the matched() implementations: when `condition` does
// not hold, the statement controlled by the `if` is executed.
// NOTE(review): that statement is not visible in this excerpt; going by the
// macro name it is presumably `return false;` - confirm.
144 #define CHECK_OR_FALSE(condition) \
145 if (not(condition)) \
// Checks whether the nodes feeding `_mul_half` (the candidate) form Gelu
// pattern 1, filling the pattern's node members on the way. Each
// CHECK_OR_FALSE is a condition the match depends on.
148 bool GeluPattern1::matched()
// Walk from the last node towards the input:
//   _mul_half : operands -> _mul and _const_half
//   _mul      : operands -> _ifm and _add_one
//   _add_one  : operands -> _erf_out and _const_one
// luci::fill(...).with_commutative_args_of(node) is expected to bind the two
// operands of `node` to the given slots in either order (helper from
// helpers/NodeFiller.h - confirm).
151 CHECK_OR_FALSE(luci::fill(&_mul, &_const_half).with_commutative_args_of(_mul_half));
152 CHECK_OR_FALSE(luci::fill(&_ifm, &_add_one).with_commutative_args_of(_mul));
153 CHECK_OR_FALSE(luci::fill(&_erf_out, &_const_one).with_commutative_args_of(_add_one));
// The producer of _erf_out has to be a CircleCustom node.
155 if (auto erf = dynamic_cast<luci::CircleCustom *>(_erf_out->input()))
158 CHECK_OR_FALSE(_erf != nullptr);
// The custom op must be "Erf" with exactly one input and one output.
161 CHECK_OR_FALSE(_erf->custom_code() == "Erf");
162 CHECK_OR_FALSE(_erf->numInputs() == 1);
163 CHECK_OR_FALSE(_erf->numOutputs() == 1);
// The single Erf input has to be a Mul.
165 if (auto mul_sqrt = dynamic_cast<luci::CircleMul *>(_erf->inputs(0)))
166 _mul_sqrt = mul_sqrt;
168 CHECK_OR_FALSE(_mul_sqrt != nullptr);
// NOTE(review): this second fill writes _ifm again, replacing the value
// obtained from _mul above; the two x() checks below are what tie both Muls
// to the same input node.
170 CHECK_OR_FALSE(luci::fill(&_ifm, &_const_sqrt).with_commutative_args_of(_mul_sqrt));
// NOTE(review): both Muls must have the input as their x operand, so a graph
// with swapped operands is rejected even if the fill accepted it - confirm
// this restriction is intended.
172 CHECK_OR_FALSE(_mul_sqrt->x() == _ifm);
173 CHECK_OR_FALSE(_mul->x() == _ifm);
175 // Check Activation to be NONE
176 CHECK_OR_FALSE(_mul_sqrt->fusedActivationFunction() == luci::FusedActFunc::NONE);
177 CHECK_OR_FALSE(_add_one->fusedActivationFunction() == luci::FusedActFunc::NONE);
178 CHECK_OR_FALSE(_mul->fusedActivationFunction() == luci::FusedActFunc::NONE);
179 CHECK_OR_FALSE(_mul_half->fusedActivationFunction() == luci::FusedActFunc::NONE);
181 // check _const_sqrt condition
// Single FLOAT32 element equal to sqrtf(0.5f) within the tolerance of same().
182 CHECK_OR_FALSE(_const_sqrt->dtype() == loco::DataType::FLOAT32);
183 CHECK_OR_FALSE(_const_sqrt->size<loco::DataType::FLOAT32>() == 1);
184 CHECK_OR_FALSE(::same(_const_sqrt->at<loco::DataType::FLOAT32>(0), sqrtf(0.5f)));
186 // check if _const_half is 0.5 (fp32)
// Exact comparison; 0.5 is exactly representable in fp32.
187 CHECK_OR_FALSE(_const_half->dtype() == loco::DataType::FLOAT32);
188 CHECK_OR_FALSE(_const_half->size<loco::DataType::FLOAT32>() == 1);
189 CHECK_OR_FALSE(_const_half->at<loco::DataType::FLOAT32>(0) == 0.5);
191 // check _const_one condition
// Single FLOAT32 element that is exactly 1.
192 CHECK_OR_FALSE(_const_one->dtype() == loco::DataType::FLOAT32);
193 CHECK_OR_FALSE(_const_one->size<loco::DataType::FLOAT32>() == 1);
194 CHECK_OR_FALSE(_const_one->at<loco::DataType::FLOAT32>(0) == 1);
// Checks whether the nodes feeding `_mul` form Gelu pattern 2, filling the
// pattern's node members on the way. Each CHECK_OR_FALSE is a condition the
// match depends on.
199 bool GeluPattern2::matched()
// Walk from `_mul` towards the input:
//   _mul      : operands -> _mul_half and _add_one
//   _mul_half : operands -> _ifm and _const_half
//   _add_one  : operands -> _erf_out and _const_one
// luci::fill(...).with_commutative_args_of(node) is expected to bind the two
// operands of `node` to the given slots in either order (helper from
// helpers/NodeFiller.h - confirm).
202 CHECK_OR_FALSE(luci::fill(&_mul_half, &_add_one).with_commutative_args_of(_mul));
203 CHECK_OR_FALSE(luci::fill(&_ifm, &_const_half).with_commutative_args_of(_mul_half));
204 CHECK_OR_FALSE(luci::fill(&_erf_out, &_const_one).with_commutative_args_of(_add_one));
// NOTE(review): the input must be the x operand of _mul_half, so a graph
// with swapped operands is rejected even if the fill accepted it - confirm
// this restriction is intended.
206 CHECK_OR_FALSE(_mul_half->x() == _ifm);
// The producer of _erf_out has to be a CircleCustom node.
208 if (auto erf = dynamic_cast<luci::CircleCustom *>(_erf_out->input()))
211 CHECK_OR_FALSE(_erf != nullptr);
// The custom op must be "Erf" with exactly one input and one output.
214 CHECK_OR_FALSE(_erf->custom_code() == "Erf");
215 CHECK_OR_FALSE(_erf->numInputs() == 1);
216 CHECK_OR_FALSE(_erf->numOutputs() == 1);
// The single Erf input has to be a Mul.
218 if (auto mul_sqrt = dynamic_cast<luci::CircleMul *>(_erf->inputs(0)))
219 _mul_sqrt = mul_sqrt;
221 CHECK_OR_FALSE(_mul_sqrt != nullptr);
// NOTE(review): this second fill writes _ifm again, after line 206 already
// compared the earlier value; the x() check below is made against the new
// value.
223 CHECK_OR_FALSE(luci::fill(&_ifm, &_const_sqrt).with_commutative_args_of(_mul_sqrt));
225 CHECK_OR_FALSE(_mul_sqrt->x() == _ifm);
227 // Check Activation to be NONE
228 CHECK_OR_FALSE(_mul_sqrt->fusedActivationFunction() == luci::FusedActFunc::NONE);
229 CHECK_OR_FALSE(_add_one->fusedActivationFunction() == luci::FusedActFunc::NONE);
230 CHECK_OR_FALSE(_mul->fusedActivationFunction() == luci::FusedActFunc::NONE);
231 CHECK_OR_FALSE(_mul_half->fusedActivationFunction() == luci::FusedActFunc::NONE);
233 // check _const_sqrt condition
// Single FLOAT32 element equal to sqrtf(0.5f) within the tolerance of same().
234 CHECK_OR_FALSE(_const_sqrt->dtype() == loco::DataType::FLOAT32);
235 CHECK_OR_FALSE(_const_sqrt->size<loco::DataType::FLOAT32>() == 1);
236 CHECK_OR_FALSE(::same(_const_sqrt->at<loco::DataType::FLOAT32>(0), sqrtf(0.5f)));
238 // check if _const_half is 0.5 (fp32)
// Exact comparison; 0.5 is exactly representable in fp32.
239 CHECK_OR_FALSE(_const_half->dtype() == loco::DataType::FLOAT32);
240 CHECK_OR_FALSE(_const_half->size<loco::DataType::FLOAT32>() == 1);
241 CHECK_OR_FALSE(_const_half->at<loco::DataType::FLOAT32>(0) == 0.5);
243 // check _const_one condition
// Single FLOAT32 element that is exactly 1.
244 CHECK_OR_FALSE(_const_one->dtype() == loco::DataType::FLOAT32);
245 CHECK_OR_FALSE(_const_one->size<loco::DataType::FLOAT32>() == 1);
246 CHECK_OR_FALSE(_const_one->at<loco::DataType::FLOAT32>(0) == 1);
251 #undef CHECK_OR_FALSE
256 FuseGelu(const GeluPatternBase *p) : _p(p) {}
// Creates the CircleGelu node in `graph`; defined out of class below.
262 luci::CircleGelu *create_gelu(loco::Graph *graph);
// Pattern whose nodes are fused; set once by the constructor, read-only here.
265 const GeluPatternBase *_p;
// Creates a CircleGelu node in `graph` for the bound pattern: its features
// input is the pattern's _ifm, approximate is false, and its name is the
// last node's name with "_gelu" appended.
// NOTE(review): the return statement is not visible in this excerpt;
// apply() uses the result as the new node, so presumably `gelu` is returned.
268 luci::CircleGelu *FuseGelu::create_gelu(loco::Graph *graph)
272 auto gelu = graph->nodes()->create<luci::CircleGelu>();
// Feed the pattern's input straight into the fused op.
273 gelu->features(_p->_ifm);
274 // TODO Support approximate = True pattern
275 gelu->approximate(false);
// Derive the name from the node that gets replaced.
276 gelu->name(_p->_pattern_last_node->name() + "_gelu");
// Performs the fusion: creates the Gelu node, attaches the origins of the
// fused nodes to it, and substitutes it for the pattern's last node.
280 void FuseGelu::apply()
// The new node is created in the graph that owns the pattern's last node.
282 auto graph = _p->_pattern_last_node->graph();
284 auto gelu = create_gelu(graph);
// Collect the origins of the five fused operator nodes and attach them to
// the Gelu as one composite origin.
287 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
288 luci::get_origin(_p->_mul_sqrt), luci::get_origin(_p->_erf), luci::get_origin(_p->_add_one),
289 luci::get_origin(_p->_mul), luci::get_origin(_p->_mul_half)};
291 luci::add_origin(gelu, luci::composite_origin(origin_vec));
// Presumably redirects every user of the last node to the Gelu node
// (replace(...).with(...) helper - confirm).
293 replace(_p->_pattern_last_node).with(gelu);
// Tries to fuse a Gelu that ends at `mul`: pattern 1 is checked first, then
// pattern 2, and a FuseGelu is built for whichever pattern matched.
// NOTE(review): the apply()/return statements are not visible in this
// excerpt; presumably the function returns true after a fusion and false
// when neither pattern matches - confirm.
301 bool fuse_gelu(luci::CircleMul *mul)
305 // check first pattern
306 GeluPattern1 pattern(mul);
307 if (pattern.matched())
309 FuseGelu fuse(&pattern);
314 // check second pattern
315 GeluPattern2 pattern2(mul);
316 if (pattern2.matched())
318 FuseGelu fuse(&pattern2);
330 bool FuseGeluPass::run(loco::Graph *g)
332 bool changed = false;
334 for (auto node : loco::active_nodes(loco::output_nodes(g)))
336 auto mul = dynamic_cast<luci::CircleMul *>(node);