Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / DecomposeHardSwishPass.cpp
1 /*
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/DecomposeHardSwishPass.h"
18
19 #include "helpers/NodeFiller.h"
20 #include "helpers/TypeMapper.h"
21
22 #include <luci/IR/CircleNodes.h>
23 #include <luci/Profile/CircleNodeOrigin.h>
24
25 namespace
26 {
27 /**
28  *  BEFORE
29  *        [CircleNode]
30  *              |
31  *              |
32  *      [CircleHardSwish]
33  *              |
34  *              |
35  *        [CircleNode]
36  *
37  *
38  *  AFTER
39  *
40  *      [CircleNode]  [CircleConst]
41  *          |    \       /
42  *          |     \     /
43  *          |   [CircleAdd]
44  *          |        |
45  *          |        |
46  *          \  [CircleRelu6] [CircleConst]
47  *           \        \        /
48  *            \        \      /
49  *             \      [CircleMul]
50  *              \       /
51  *               \     /
52  *             [CircleMul]
53  *                  |
54  *                  |
55  *             [CircleNode]
56  *
57  */
58 bool decompose_hardswish(luci::CircleHardSwish *hardswish)
59 {
60   if (not hardswish)
61     return false;
62
63   if (hardswish->dtype() != loco::DataType::FLOAT32)
64     return false;
65
66   auto g = hardswish->graph();
67
68   auto name = hardswish->name();
69   assert(name.length() > 0);
70
71   // Create a const for CircleAdd operation
72   auto add_const = g->nodes()->create<luci::CircleConst>();
73   add_const->shape({}); // scalar
74   add_const->dtype(loco::DataType::FLOAT32);
75   add_const->rank(0);
76   add_const->size<loco::DataType::FLOAT32>(1);
77   add_const->at<loco::DataType::FLOAT32>(0) = 3.;
78   add_const->name(name + "/Add/const");
79   luci::add_origin(add_const, luci::get_origin(hardswish));
80
81   // Create an Add operation
82   auto add = g->nodes()->create<luci::CircleAdd>();
83   add->fusedActivationFunction(luci::FusedActFunc::NONE);
84   add->x(hardswish->features());
85   add->y(add_const);
86   add->name(name + "/Add");
87   luci::add_origin(add, luci::get_origin(hardswish));
88
89   // Create a Relu6 operation
90   auto relu6 = g->nodes()->create<luci::CircleRelu6>();
91   relu6->features(add);
92   relu6->name(name + "/Relu6");
93   luci::add_origin(relu6, luci::get_origin(hardswish));
94
95   // Create a const for CircleMul operation
96   auto mul_const = g->nodes()->create<luci::CircleConst>();
97   mul_const->shape({}); // scalar
98   mul_const->dtype(loco::DataType::FLOAT32);
99   mul_const->rank(0);
100   mul_const->size<loco::DataType::FLOAT32>(1);
101   mul_const->at<loco::DataType::FLOAT32>(0) = 1. / 6.;
102   mul_const->name(name + "/Mul/const");
103   luci::add_origin(mul_const, luci::get_origin(hardswish));
104
105   // Create first Mul operation
106   auto mul1 = g->nodes()->create<luci::CircleMul>();
107   mul1->fusedActivationFunction(luci::FusedActFunc::NONE);
108   mul1->x(relu6);
109   mul1->y(mul_const);
110   mul1->name(name + "/Mul1");
111   luci::add_origin(mul1, luci::get_origin(hardswish));
112
113   // Create second Mul operation
114   auto mul2 = g->nodes()->create<luci::CircleMul>();
115   mul2->fusedActivationFunction(luci::FusedActFunc::NONE);
116   mul2->x(hardswish->features());
117   mul2->y(mul1);
118   mul2->name(name + "/Mul2");
119   luci::add_origin(mul2, luci::get_origin(hardswish));
120
121   replace(hardswish).with(mul2);
122
123   return true;
124 }
125
126 } // namespace
127
128 namespace luci
129 {
130
131 bool DecomposeHardSwishPass::run(loco::Graph *g)
132 {
133   bool changed = false;
134
135   for (auto node : loco::active_nodes(loco::output_nodes(g)))
136   {
137     if (auto hardswish = dynamic_cast<luci::CircleHardSwish *>(node))
138     {
139       if (decompose_hardswish(hardswish))
140         changed = true;
141     }
142   }
143
144   return changed;
145 }
146
147 } // namespace luci