/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 #include "luci/Pass/FuseBatchNormWithTConvPass.h"
19 #include "helpers/NodeFiller.h"
21 #include <luci/IR/CircleNodes.h>
22 #include <luci/Profile/CircleNodeOrigin.h>
27 template <class CIRCLENODE>
28 void replace_with_relu(luci::CircleNode *target, luci::CircleNode *feature,
29 const std::string &relu_name)
31 assert(target != nullptr);
32 assert(feature != nullptr);
34 auto relu = target->graph()->nodes()->create<CIRCLENODE>();
35 relu->features(feature);
36 relu->name(relu_name);
37 luci::add_origin(relu, luci::get_origin(target));
39 replace(target).with(relu);
47 * Fuse Mul-Add to TransposeConv if possible.
49 * NOTE TF's BatchNormalization is converted to Mul and Add.
52 * | [CircleConst]/[CircleOutputExclude]
55 * [CircleTransposeConv] [CircleConst]
57 * [CircleMul] [CircleConst]
63 * | [CircleConst]/[CircleOutputExclude]
64 * +-------------------------------------+ / [CircleConst]
66 * | [CircleTransposeConv] [CircleConst]
68 * | / [CircleConst] [CircleMul] [CircleConst]
70 * [CircleTransposeConv] [CircleAdd]
72 * ([CircleRelu]/[CircleRelu6])
75 * Note: CircleRelu or CircleRelu6 is inserted if Add activation is ReLU/ReLU6
77 bool fused_batch_norm_with_tconv(luci::CircleAdd *add)
79 assert(add != nullptr);
81 // Find the pattern of CircleTransposeConv - CircleMul - CircleAdd
82 luci::CircleConst *scale = nullptr;
83 luci::CircleConst *shift = nullptr;
84 luci::CircleTransposeConv *tconv = nullptr;
85 luci::CircleMul *mul = nullptr;
86 if (not luci::fill(&shift, &mul).with_commutative_args_of(add))
88 if (not luci::fill(&scale, &tconv).with_commutative_args_of(mul))
91 // check scale and shift constant attributes
92 // TODO maybe rank check is not needed
93 if (scale->rank() != 1 && scale->rank() != 4)
95 if (shift->rank() != 1 && shift->rank() != 4)
97 // check mul, add attributes
98 if (mul->dtype() != loco::DataType::FLOAT32)
100 if (add->dtype() != loco::DataType::FLOAT32)
102 if (add->fusedActivationFunction() != luci::FusedActFunc::NONE &&
103 add->fusedActivationFunction() != luci::FusedActFunc::RELU6 &&
104 add->fusedActivationFunction() != luci::FusedActFunc::RELU)
107 // tconv bias is optional
108 auto bias = dynamic_cast<luci::CircleConst *>(tconv->bias());
110 // get weight of tconv
111 auto filter = dynamic_cast<luci::CircleConst *>(tconv->filter());
114 if (filter->dtype() != loco::DataType::FLOAT32)
116 if (filter->rank() != 4)
119 auto filter_out_chn = filter->dim(0).value();
120 // allow scale/shift and bias shape of [N], [1,1,1,N]; BN works for "channel-wise"
121 auto srank = scale->rank() - 1;
122 if (filter_out_chn != scale->dim(srank).value())
124 for (uint32_t d = 0; d < srank; ++d)
126 if (1 != scale->dim(d).value())
129 srank = shift->rank() - 1;
130 if (filter_out_chn != shift->dim(srank).value())
132 for (uint32_t d = 0; d < srank; ++d)
134 if (1 != shift->dim(d).value())
139 if (bias->dtype() != loco::DataType::FLOAT32)
141 srank = bias->rank() - 1;
142 if (filter_out_chn != bias->dim(srank).value())
144 for (uint32_t d = 0; d < srank; ++d)
146 if (1 != bias->dim(d).value())
151 auto name = add->name();
152 assert(name.length() > 0);
154 loco::Graph *graph = add->graph();
155 luci::CircleTransposeConv *fused_tconv = graph->nodes()->create<luci::CircleTransposeConv>();
156 luci::CircleConst *fused_filter = graph->nodes()->create<luci::CircleConst>();
157 luci::CircleConst *fused_bias = graph->nodes()->create<luci::CircleConst>();
159 auto filter_height = filter->dim(1).value();
160 auto filter_width = filter->dim(2).value();
161 auto filter_in_chn = filter->dim(3).value();
164 fused_filter->dtype(filter->dtype());
165 fused_filter->size<loco::DataType::FLOAT32>(filter->size<loco::DataType::FLOAT32>());
166 fused_filter->rank(4);
167 fused_filter->dim(0).set(filter_out_chn);
168 fused_filter->dim(1).set(filter_height);
169 fused_filter->dim(2).set(filter_width);
170 fused_filter->dim(3).set(filter_in_chn);
171 fused_filter->shape_status(luci::ShapeStatus::VALID);
172 fused_filter->name(name + "/TransposeConv/filter");
174 // fused filter weight = filter weight * mul(scale) + add(shift)
175 for (uint32_t c = 0; c < filter_out_chn; c++)
177 for (uint32_t h = 0; h < filter_height; h++)
179 for (uint32_t w = 0; w < filter_width; w++)
181 for (uint32_t b = 0; b < filter_in_chn; b++)
183 uint32_t offset = c * filter_height * filter_width * filter_in_chn +
184 h * filter_width * filter_in_chn + w * filter_in_chn + b;
185 fused_filter->at<loco::DataType::FLOAT32>(offset) =
186 filter->at<loco::DataType::FLOAT32>(offset) * scale->at<loco::DataType::FLOAT32>(c);
192 // Copy fused_bias from shift
193 fused_bias->dtype(shift->dtype());
194 fused_bias->size<loco::DataType::FLOAT32>(shift->size<loco::DataType::FLOAT32>());
196 fused_bias->dim(0).set(filter_out_chn);
197 fused_bias->shape_status(luci::ShapeStatus::VALID);
198 for (uint32_t c = 0; c < filter_out_chn; ++c)
200 fused_bias->at<loco::DataType::FLOAT32>(c) = shift->at<loco::DataType::FLOAT32>(c);
203 fused_bias->at<loco::DataType::FLOAT32>(c) +=
204 bias->at<loco::DataType::FLOAT32>(c) * scale->at<loco::DataType::FLOAT32>(c);
207 fused_bias->name(name + "/TransposeConv/bias");
209 // set new tconv properties
210 fused_tconv->inputSizes(tconv->inputSizes());
211 fused_tconv->filter(fused_filter);
212 fused_tconv->outBackprop(tconv->outBackprop());
213 fused_tconv->bias(fused_bias);
214 fused_tconv->padding(tconv->padding());
215 fused_tconv->stride()->h(tconv->stride()->h());
216 fused_tconv->stride()->w(tconv->stride()->w());
217 fused_tconv->name(name + "/TransposeConv");
218 luci::add_origin(fused_tconv,
219 luci::composite_origin(
220 {luci::get_origin(add), luci::get_origin(mul), luci::get_origin(tconv)}));
223 luci::add_origin(fused_tconv, luci::get_origin(bias));
226 switch (add->fusedActivationFunction())
228 case luci::FusedActFunc::RELU6:
229 replace_with_relu<luci::CircleRelu6>(add, fused_tconv, name + "/Relu6");
232 case luci::FusedActFunc::RELU:
233 replace_with_relu<luci::CircleRelu>(add, fused_tconv, name + "/Relu");
236 case luci::FusedActFunc::NONE:
237 replace(add).with(fused_tconv);
253 bool FuseBatchNormWithTConvPass::run(loco::Graph *g)
255 bool changed = false;
256 for (auto node : loco::active_nodes(loco::output_nodes(g)))
258 if (auto add = dynamic_cast<luci::CircleAdd *>(node))
260 if (fused_batch_norm_with_tconv(add))