Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / FuseBatchNormWithTConvPass.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/FuseBatchNormWithTConvPass.h"
18
19 #include "helpers/NodeFiller.h"
20
21 #include <luci/IR/CircleNodes.h>
22 #include <luci/Profile/CircleNodeOrigin.h>
23
24 namespace
25 {
26
27 template <class CIRCLENODE>
28 void replace_with_relu(luci::CircleNode *target, luci::CircleNode *feature,
29                        const std::string &relu_name)
30 {
31   assert(target != nullptr);
32   assert(feature != nullptr);
33
34   auto relu = target->graph()->nodes()->create<CIRCLENODE>();
35   relu->features(feature);
36   relu->name(relu_name);
37   luci::add_origin(relu, luci::get_origin(target));
38
39   replace(target).with(relu);
40 }
41
42 } // namespace
43
44 namespace
45 {
46 /**
47  *  Fuse Mul-Add to TransposeConv if possible.
48  *
49  *  NOTE TF's BatchNormalization is converted to Mul and Add.
50  *
51  *  BEFORE
52  *                     |   [CircleConst]/[CircleOutputExclude]
53  *                     |   / [CircleConst]
54  *                     |  / /
55  *     [CircleTransposeConv]  [CircleConst]
56  *                     |     /
57  *                [CircleMul] [CircleConst]
58  *                     |     /
59  *                [CircleAdd]
60  *                     |
61  *
62  *  AFTER
63  *                     |                                         [CircleConst]/[CircleOutputExclude]
64  *                     +-------------------------------------+   / [CircleConst]
65  *                     |                                     |  / /
66  *                     |                     [CircleTransposeConv]  [CircleConst]
67  *                     |    [CircleConst]                    |     /
68  *                     |   / [CircleConst]              [CircleMul] [CircleConst]
69  *                     |  / /                                |     /
70  *     [CircleTransposeConv]                            [CircleAdd]
71  *                     |
72  *        ([CircleRelu]/[CircleRelu6])
73  *                     |
74  *
75  * Note: CircleRelu or CircleRelu6 is inserted if Add activation is ReLU/ReLU6
76  */
77 bool fused_batch_norm_with_tconv(luci::CircleAdd *add)
78 {
79   assert(add != nullptr);
80
81   // Find the pattern of CircleTransposeConv - CircleMul - CircleAdd
82   luci::CircleConst *scale = nullptr;
83   luci::CircleConst *shift = nullptr;
84   luci::CircleTransposeConv *tconv = nullptr;
85   luci::CircleMul *mul = nullptr;
86   if (not luci::fill(&shift, &mul).with_commutative_args_of(add))
87     return false;
88   if (not luci::fill(&scale, &tconv).with_commutative_args_of(mul))
89     return false;
90   // skip if tconv has fused activation
91   if (tconv->fusedActivationFunction() != luci::FusedActFunc::NONE)
92     return false;
93
94   // check scale and shift constant attributes
95   // TODO maybe rank check is not needed
96   if (scale->rank() != 1 && scale->rank() != 4)
97     return false;
98   if (shift->rank() != 1 && shift->rank() != 4)
99     return false;
100   // check mul, add attributes
101   if (mul->dtype() != loco::DataType::FLOAT32)
102     return false;
103   if (add->dtype() != loco::DataType::FLOAT32)
104     return false;
105   if (add->fusedActivationFunction() != luci::FusedActFunc::NONE &&
106       add->fusedActivationFunction() != luci::FusedActFunc::RELU6 &&
107       add->fusedActivationFunction() != luci::FusedActFunc::RELU)
108     return false;
109
110   // tconv bias is optional
111   auto bias = dynamic_cast<luci::CircleConst *>(tconv->bias());
112
113   // get weight of tconv
114   auto filter = dynamic_cast<luci::CircleConst *>(tconv->filter());
115   if (not filter)
116     return false;
117   if (filter->dtype() != loco::DataType::FLOAT32)
118     return false;
119   if (filter->rank() != 4)
120     return false;
121
122   auto filter_out_chn = filter->dim(0).value();
123   // allow scale/shift and bias shape of [N], [1,1,1,N]; BN works for "channel-wise"
124   auto srank = scale->rank() - 1;
125   if (filter_out_chn != scale->dim(srank).value())
126     return false;
127   for (uint32_t d = 0; d < srank; ++d)
128   {
129     if (1 != scale->dim(d).value())
130       return false;
131   }
132   srank = shift->rank() - 1;
133   if (filter_out_chn != shift->dim(srank).value())
134     return false;
135   for (uint32_t d = 0; d < srank; ++d)
136   {
137     if (1 != shift->dim(d).value())
138       return false;
139   }
140   if (bias)
141   {
142     if (bias->dtype() != loco::DataType::FLOAT32)
143       return false;
144     srank = bias->rank() - 1;
145     if (filter_out_chn != bias->dim(srank).value())
146       return false;
147     for (uint32_t d = 0; d < srank; ++d)
148     {
149       if (1 != bias->dim(d).value())
150         return false;
151     }
152   }
153
154   auto name = add->name();
155   assert(name.length() > 0);
156
157   loco::Graph *graph = add->graph();
158   luci::CircleTransposeConv *fused_tconv = graph->nodes()->create<luci::CircleTransposeConv>();
159   luci::CircleConst *fused_filter = graph->nodes()->create<luci::CircleConst>();
160   luci::CircleConst *fused_bias = graph->nodes()->create<luci::CircleConst>();
161
162   auto filter_height = filter->dim(1).value();
163   auto filter_width = filter->dim(2).value();
164   auto filter_in_chn = filter->dim(3).value();
165
166   // Copy filter shape
167   fused_filter->dtype(filter->dtype());
168   fused_filter->size<loco::DataType::FLOAT32>(filter->size<loco::DataType::FLOAT32>());
169   fused_filter->rank(4);
170   fused_filter->dim(0).set(filter_out_chn);
171   fused_filter->dim(1).set(filter_height);
172   fused_filter->dim(2).set(filter_width);
173   fused_filter->dim(3).set(filter_in_chn);
174   fused_filter->shape_status(luci::ShapeStatus::VALID);
175   fused_filter->name(name + "/TransposeConv/filter");
176
177   // fused filter weight = filter weight * mul(scale) + add(shift)
178   for (uint32_t c = 0; c < filter_out_chn; c++)
179   {
180     for (uint32_t h = 0; h < filter_height; h++)
181     {
182       for (uint32_t w = 0; w < filter_width; w++)
183       {
184         for (uint32_t b = 0; b < filter_in_chn; b++)
185         {
186           uint32_t offset = c * filter_height * filter_width * filter_in_chn +
187                             h * filter_width * filter_in_chn + w * filter_in_chn + b;
188           fused_filter->at<loco::DataType::FLOAT32>(offset) =
189             filter->at<loco::DataType::FLOAT32>(offset) * scale->at<loco::DataType::FLOAT32>(c);
190         }
191       }
192     }
193   }
194
195   // Copy fused_bias from shift
196   fused_bias->dtype(shift->dtype());
197   fused_bias->size<loco::DataType::FLOAT32>(shift->size<loco::DataType::FLOAT32>());
198   fused_bias->rank(1);
199   fused_bias->dim(0).set(filter_out_chn);
200   fused_bias->shape_status(luci::ShapeStatus::VALID);
201   for (uint32_t c = 0; c < filter_out_chn; ++c)
202   {
203     fused_bias->at<loco::DataType::FLOAT32>(c) = shift->at<loco::DataType::FLOAT32>(c);
204     if (bias != nullptr)
205     {
206       fused_bias->at<loco::DataType::FLOAT32>(c) +=
207         bias->at<loco::DataType::FLOAT32>(c) * scale->at<loco::DataType::FLOAT32>(c);
208     }
209   }
210   fused_bias->name(name + "/TransposeConv/bias");
211
212   // set new tconv properties
213   fused_tconv->inputSizes(tconv->inputSizes());
214   fused_tconv->filter(fused_filter);
215   fused_tconv->outBackprop(tconv->outBackprop());
216   fused_tconv->bias(fused_bias);
217   fused_tconv->padding(tconv->padding());
218   fused_tconv->stride()->h(tconv->stride()->h());
219   fused_tconv->stride()->w(tconv->stride()->w());
220   fused_tconv->name(name + "/TransposeConv");
221   // TODO set activation from Add and remove adding following Relu/Relu6 Op
222   //      when all of our backends supports fused activation of TransposeConv
223   fused_tconv->fusedActivationFunction(luci::FusedActFunc::NONE);
224   luci::add_origin(fused_tconv,
225                    luci::composite_origin(
226                      {luci::get_origin(add), luci::get_origin(mul), luci::get_origin(tconv)}));
227   if (bias != nullptr)
228   {
229     luci::add_origin(fused_tconv, luci::get_origin(bias));
230   }
231
232   switch (add->fusedActivationFunction())
233   {
234     case luci::FusedActFunc::RELU6:
235       replace_with_relu<luci::CircleRelu6>(add, fused_tconv, name + "/Relu6");
236       break;
237
238     case luci::FusedActFunc::RELU:
239       replace_with_relu<luci::CircleRelu>(add, fused_tconv, name + "/Relu");
240       break;
241
242     case luci::FusedActFunc::NONE:
243       replace(add).with(fused_tconv);
244       break;
245
246     default:
247       assert(false);
248       break;
249   }
250
251   return true;
252 }
253
254 } // namespace
255
256 namespace luci
257 {
258
259 bool FuseBatchNormWithTConvPass::run(loco::Graph *g)
260 {
261   bool changed = false;
262   for (auto node : loco::active_nodes(loco::output_nodes(g)))
263   {
264     if (auto add = dynamic_cast<luci::CircleAdd *>(node))
265     {
266       if (fused_batch_norm_with_tconv(add))
267         changed = true;
268     }
269   }
270
271   return changed;
272 }
273
274 } // namespace luci