265a8398bc2b8b722168d29d71f8b7e4b166b40c
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / FuseBatchNormWithTConvPass.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/FuseBatchNormWithTConvPass.h"
18
19 #include "helpers/NodeFiller.h"
20
21 #include <luci/IR/CircleNodes.h>
22 #include <luci/Profile/CircleNodeOrigin.h>
23
24 namespace
25 {
26
27 template <class CIRCLENODE>
28 void replace_with_relu(luci::CircleNode *target, luci::CircleNode *feature,
29                        const std::string &relu_name)
30 {
31   assert(target != nullptr);
32   assert(feature != nullptr);
33
34   auto relu = target->graph()->nodes()->create<CIRCLENODE>();
35   relu->features(feature);
36   relu->name(relu_name);
37   luci::add_origin(relu, luci::get_origin(target));
38
39   replace(target).with(relu);
40 }
41
42 } // namespace
43
44 namespace
45 {
46 /**
47  *  Fuse Mul-Add to TransposeConv if possible.
48  *
49  *  NOTE TF's BatchNormalization is converted to Mul and Add.
50  *
51  *  BEFORE
52  *                     |   [CircleConst]/[CircleOutputExclude]
53  *                     |   / [CircleConst]
54  *                     |  / /
55  *     [CircleTransposeConv]  [CircleConst]
56  *                     |     /
57  *                [CircleMul] [CircleConst]
58  *                     |     /
59  *                [CircleAdd]
60  *                     |
61  *
62  *  AFTER
63  *                     |                                         [CircleConst]/[CircleOutputExclude]
64  *                     +-------------------------------------+   / [CircleConst]
65  *                     |                                     |  / /
66  *                     |                     [CircleTransposeConv]  [CircleConst]
67  *                     |    [CircleConst]                    |     /
68  *                     |   / [CircleConst]              [CircleMul] [CircleConst]
69  *                     |  / /                                |     /
70  *     [CircleTransposeConv]                            [CircleAdd]
71  *                     |
72  *        ([CircleRelu]/[CircleRelu6])
73  *                     |
74  *
75  * Note: CircleRelu or CircleRelu6 is inserted if Add activation is ReLU/ReLU6
76  */
77 bool fused_batch_norm_with_tconv(luci::CircleAdd *add)
78 {
79   assert(add != nullptr);
80
81   // Find the pattern of CircleTransposeConv - CircleMul - CircleAdd
82   luci::CircleConst *scale = nullptr;
83   luci::CircleConst *shift = nullptr;
84   luci::CircleTransposeConv *tconv = nullptr;
85   luci::CircleMul *mul = nullptr;
86   if (not luci::fill(&shift, &mul).with_commutative_args_of(add))
87     return false;
88   if (not luci::fill(&scale, &tconv).with_commutative_args_of(mul))
89     return false;
90
91   // check scale and shift constant attributes
92   // TODO maybe rank check is not needed
93   if (scale->rank() != 1 && scale->rank() != 4)
94     return false;
95   if (shift->rank() != 1 && shift->rank() != 4)
96     return false;
97   // check mul, add attributes
98   if (mul->dtype() != loco::DataType::FLOAT32)
99     return false;
100   if (add->dtype() != loco::DataType::FLOAT32)
101     return false;
102   if (add->fusedActivationFunction() != luci::FusedActFunc::NONE &&
103       add->fusedActivationFunction() != luci::FusedActFunc::RELU6 &&
104       add->fusedActivationFunction() != luci::FusedActFunc::RELU)
105     return false;
106
107   // tconv bias is optional
108   auto bias = dynamic_cast<luci::CircleConst *>(tconv->bias());
109
110   // get weight of tconv
111   auto filter = dynamic_cast<luci::CircleConst *>(tconv->filter());
112   if (not filter)
113     return false;
114   if (filter->dtype() != loco::DataType::FLOAT32)
115     return false;
116   if (filter->rank() != 4)
117     return false;
118
119   auto filter_out_chn = filter->dim(0).value();
120   // allow scale/shift and bias shape of [N], [1,1,1,N]; BN works for "channel-wise"
121   auto srank = scale->rank() - 1;
122   if (filter_out_chn != scale->dim(srank).value())
123     return false;
124   for (uint32_t d = 0; d < srank; ++d)
125   {
126     if (1 != scale->dim(d).value())
127       return false;
128   }
129   srank = shift->rank() - 1;
130   if (filter_out_chn != shift->dim(srank).value())
131     return false;
132   for (uint32_t d = 0; d < srank; ++d)
133   {
134     if (1 != shift->dim(d).value())
135       return false;
136   }
137   if (bias)
138   {
139     if (bias->dtype() != loco::DataType::FLOAT32)
140       return false;
141     srank = bias->rank() - 1;
142     if (filter_out_chn != bias->dim(srank).value())
143       return false;
144     for (uint32_t d = 0; d < srank; ++d)
145     {
146       if (1 != bias->dim(d).value())
147         return false;
148     }
149   }
150
151   auto name = add->name();
152   assert(name.length() > 0);
153
154   loco::Graph *graph = add->graph();
155   luci::CircleTransposeConv *fused_tconv = graph->nodes()->create<luci::CircleTransposeConv>();
156   luci::CircleConst *fused_filter = graph->nodes()->create<luci::CircleConst>();
157   luci::CircleConst *fused_bias = graph->nodes()->create<luci::CircleConst>();
158
159   auto filter_height = filter->dim(1).value();
160   auto filter_width = filter->dim(2).value();
161   auto filter_in_chn = filter->dim(3).value();
162
163   // Copy filter shape
164   fused_filter->dtype(filter->dtype());
165   fused_filter->size<loco::DataType::FLOAT32>(filter->size<loco::DataType::FLOAT32>());
166   fused_filter->rank(4);
167   fused_filter->dim(0).set(filter_out_chn);
168   fused_filter->dim(1).set(filter_height);
169   fused_filter->dim(2).set(filter_width);
170   fused_filter->dim(3).set(filter_in_chn);
171   fused_filter->shape_status(luci::ShapeStatus::VALID);
172   fused_filter->name(name + "/TransposeConv/filter");
173
174   // fused filter weight = filter weight * mul(scale) + add(shift)
175   for (uint32_t c = 0; c < filter_out_chn; c++)
176   {
177     for (uint32_t h = 0; h < filter_height; h++)
178     {
179       for (uint32_t w = 0; w < filter_width; w++)
180       {
181         for (uint32_t b = 0; b < filter_in_chn; b++)
182         {
183           uint32_t offset = c * filter_height * filter_width * filter_in_chn +
184                             h * filter_width * filter_in_chn + w * filter_in_chn + b;
185           fused_filter->at<loco::DataType::FLOAT32>(offset) =
186             filter->at<loco::DataType::FLOAT32>(offset) * scale->at<loco::DataType::FLOAT32>(c);
187         }
188       }
189     }
190   }
191
192   // Copy fused_bias from shift
193   fused_bias->dtype(shift->dtype());
194   fused_bias->size<loco::DataType::FLOAT32>(shift->size<loco::DataType::FLOAT32>());
195   fused_bias->rank(1);
196   fused_bias->dim(0).set(filter_out_chn);
197   fused_bias->shape_status(luci::ShapeStatus::VALID);
198   for (uint32_t c = 0; c < filter_out_chn; ++c)
199   {
200     fused_bias->at<loco::DataType::FLOAT32>(c) = shift->at<loco::DataType::FLOAT32>(c);
201     if (bias != nullptr)
202     {
203       fused_bias->at<loco::DataType::FLOAT32>(c) +=
204         bias->at<loco::DataType::FLOAT32>(c) * scale->at<loco::DataType::FLOAT32>(c);
205     }
206   }
207   fused_bias->name(name + "/TransposeConv/bias");
208
209   // set new tconv properties
210   fused_tconv->inputSizes(tconv->inputSizes());
211   fused_tconv->filter(fused_filter);
212   fused_tconv->outBackprop(tconv->outBackprop());
213   fused_tconv->bias(fused_bias);
214   fused_tconv->padding(tconv->padding());
215   fused_tconv->stride()->h(tconv->stride()->h());
216   fused_tconv->stride()->w(tconv->stride()->w());
217   fused_tconv->name(name + "/TransposeConv");
218   luci::add_origin(fused_tconv,
219                    luci::composite_origin(
220                      {luci::get_origin(add), luci::get_origin(mul), luci::get_origin(tconv)}));
221   if (bias != nullptr)
222   {
223     luci::add_origin(fused_tconv, luci::get_origin(bias));
224   }
225
226   switch (add->fusedActivationFunction())
227   {
228     case luci::FusedActFunc::RELU6:
229       replace_with_relu<luci::CircleRelu6>(add, fused_tconv, name + "/Relu6");
230       break;
231
232     case luci::FusedActFunc::RELU:
233       replace_with_relu<luci::CircleRelu>(add, fused_tconv, name + "/Relu");
234       break;
235
236     case luci::FusedActFunc::NONE:
237       replace(add).with(fused_tconv);
238       break;
239
240     default:
241       assert(false);
242       break;
243   }
244
245   return true;
246 }
247
248 } // namespace
249
250 namespace luci
251 {
252
253 bool FuseBatchNormWithTConvPass::run(loco::Graph *g)
254 {
255   bool changed = false;
256   for (auto node : loco::active_nodes(loco::output_nodes(g)))
257   {
258     if (auto add = dynamic_cast<luci::CircleAdd *>(node))
259     {
260       if (fused_batch_norm_with_tconv(add))
261         changed = true;
262     }
263   }
264
265   return changed;
266 }
267
268 } // namespace luci