Imported Upstream version 1.9.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / FuseBatchNormWithTConv.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/FuseBatchNormWithTConv.h"
18
19 #include <luci/IR/CircleNodes.h>
20
21 namespace
22 {
23 /**
24  *  NOTE TF's fusedBatchNorm is converted to mul and add of Circle.
25  *
26  *  BEFORE
27  *
28  *         [CircleTransposeConv]
29  *                  |
30  *                [mul]
31  *                  |
32  *                [add]
33  *  AFTER
34  *
35  *         [CircleTransposeConv]
36  */
37 bool fused_batch_norm_with_tconv(luci::CircleTransposeConv *tconv)
38 {
39   // check whether it has bias or not. This optimization works only if it doesn't.
40   auto bias = dynamic_cast<luci::CircleOutputExclude *>(tconv->bias());
41   if (not bias)
42     return false;
43
44   // get weight of tconv
45   auto filter = dynamic_cast<luci::CircleConst *>(tconv->filter());
46   if (not filter)
47     return false;
48   if (filter->dtype() != loco::DataType::FLOAT32)
49     return false;
50
51   // get mul node
52   auto tconv_output = loco::succs(tconv);
53   assert(tconv_output.size() == 1);
54   auto mul = dynamic_cast<luci::CircleMul *>(*tconv_output.begin());
55   if (not mul)
56     return false;
57   if (mul->dtype() != loco::DataType::FLOAT32)
58     return false;
59
60   // get add node
61   auto mul_output = loco::succs(mul);
62   assert(mul_output.size() == 1);
63   auto add = dynamic_cast<luci::CircleAdd *>(*mul_output.begin());
64   if (not add)
65     return false;
66   if (add->dtype() != loco::DataType::FLOAT32)
67     return false;
68   if (add->fusedActivationFunction() != luci::FusedActFunc::NONE &&
69       add->fusedActivationFunction() != luci::FusedActFunc::RELU6)
70     return false;
71
72   // get scale of batchnorm
73   auto scale = dynamic_cast<luci::CircleConst *>(mul->y());
74   if (not scale)
75     return false;
76
77   // scale dim(0) == tconv filter channel dim
78   if (filter->rank() != 4)
79     return false;
80   auto filter_channel_dim = filter->dim(3).value();
81   if (scale->rank() != 1)
82     return false;
83   auto scale_dim = scale->dim(0).value();
84   if (filter_channel_dim != scale_dim)
85     return false;
86
87   // get shift of batchnorm
88   auto shift = dynamic_cast<luci::CircleConst *>(add->y());
89   if (not shift)
90     return false;
91
92   // shift dim(0) == tconv filter channel dim
93   if (shift->rank() != 1)
94     return false;
95   auto shift_dim = shift->dim(0).value();
96   if (filter_channel_dim != shift_dim)
97     return false;
98
99   // filter weight = filter weight * mul(scale) + add(shift)
100   uint32_t filter_batch_dim = filter->dim(0).value();
101   uint32_t filter_height_dim = filter->dim(1).value();
102   uint32_t filter_width_dim = filter->dim(2).value();
103   for (uint32_t c = 0; c < filter_channel_dim; c++)
104   {
105     for (uint32_t n = 0; n < filter_batch_dim; n++)
106     {
107       for (uint32_t h = 0; h < filter_height_dim; h++)
108       {
109         for (uint32_t w = 0; w < filter_width_dim; w++)
110         {
111           uint32_t offset = n * filter_height_dim * filter_width_dim * filter_channel_dim +
112                             h * filter_width_dim * filter_channel_dim + w * filter_channel_dim + c;
113           filter->at<loco::DataType::FLOAT32>(offset) *= scale->at<loco::DataType::FLOAT32>(c);
114         }
115       }
116     }
117   }
118
119   // fuse shift with transposed conv
120   tconv->bias(shift);
121
122   if (add->fusedActivationFunction() == luci::FusedActFunc::RELU6)
123   {
124     // separate relu op from add op
125     auto relu = add->graph()->nodes()->create<luci::CircleRelu6>();
126     relu->features(tconv);
127
128     // remove mul node
129     replace(add).with(relu);
130   }
131   else
132   {
133     replace(add).with(tconv);
134   }
135
136   return true;
137 }
138
139 } // namespace
140
141 namespace luci
142 {
143
144 bool FuseBatchNormWithTConvPass::run(loco::Graph *g)
145 {
146   bool changed = false;
147   for (auto node : loco::active_nodes(loco::output_nodes(g)))
148   {
149     auto tconv = dynamic_cast<luci::CircleTransposeConv *>(node);
150     if (not tconv)
151       continue;
152
153     changed |= fused_batch_norm_with_tconv(tconv);
154   }
155
156   return changed;
157 }
158
159 } // namespace luci