/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "luci/Pass/FuseBatchNormWithTConvPass.h"

#include "helpers/NodeFiller.h"

#include <luci/IR/CircleNodes.h>
#include <luci/Profile/CircleNodeOrigin.h>

#include <cassert>
#include <string>
27 template <class CIRCLENODE>
28 void replace_with_relu(luci::CircleNode *target, luci::CircleNode *feature,
29 const std::string &relu_name)
31 assert(target != nullptr);
32 assert(feature != nullptr);
34 auto relu = target->graph()->nodes()->create<CIRCLENODE>();
35 relu->features(feature);
36 relu->name(relu_name);
37 luci::add_origin(relu, luci::get_origin(target));
39 replace(target).with(relu);
/**
 *  Fuse Mul-Add to TransposeConv if possible.
 *
 *  NOTE TF's BatchNormalization is converted to Mul and Add.
 *
 *  BEFORE
 *                     |  [CircleConst]/[CircleOutputExclude]
 *                     | /  [CircleConst]
 *                     |/  /
 *     [CircleTransposeConv]  [CircleConst]
 *                     |     /
 *               [CircleMul]  [CircleConst]
 *                     |     /
 *               [CircleAdd]
 *                     |
 *
 *  AFTER
 *                     |                  [CircleConst]/[CircleOutputExclude]
 *  +------------------+                 /  [CircleConst]
 *  |                  |                /  /
 *  |  [CircleConst]   |  [CircleTransposeConv]  [CircleConst]
 *  | /  [CircleConst] |            |           /
 *  |/  /              |      [CircleMul]  [CircleConst]
 *  [CircleTransposeConv]           |     /
 *           |                [CircleAdd]
 *  ([CircleRelu]/[CircleRelu6])
 *           |
 *
 *  Note: CircleRelu or CircleRelu6 is inserted if Add activation is ReLU/ReLU6
 */
77 bool fused_batch_norm_with_tconv(luci::CircleAdd *add)
79 assert(add != nullptr);
81 // Find the pattern of CircleTransposeConv - CircleMul - CircleAdd
82 luci::CircleConst *scale = nullptr;
83 luci::CircleConst *shift = nullptr;
84 luci::CircleTransposeConv *tconv = nullptr;
85 luci::CircleMul *mul = nullptr;
86 if (not luci::fill(&shift, &mul).with_commutative_args_of(add))
88 if (not luci::fill(&scale, &tconv).with_commutative_args_of(mul))
90 // skip if tconv has fused activation
91 if (tconv->fusedActivationFunction() != luci::FusedActFunc::NONE)
94 // check scale and shift constant attributes
95 // TODO maybe rank check is not needed
96 if (scale->rank() != 1 && scale->rank() != 4)
98 if (shift->rank() != 1 && shift->rank() != 4)
100 // check mul, add attributes
101 if (mul->dtype() != loco::DataType::FLOAT32)
103 if (add->dtype() != loco::DataType::FLOAT32)
105 if (add->fusedActivationFunction() != luci::FusedActFunc::NONE &&
106 add->fusedActivationFunction() != luci::FusedActFunc::RELU6 &&
107 add->fusedActivationFunction() != luci::FusedActFunc::RELU)
110 // tconv bias is optional
111 auto bias = dynamic_cast<luci::CircleConst *>(tconv->bias());
113 // get weight of tconv
114 auto filter = dynamic_cast<luci::CircleConst *>(tconv->filter());
117 if (filter->dtype() != loco::DataType::FLOAT32)
119 if (filter->rank() != 4)
122 auto filter_out_chn = filter->dim(0).value();
123 // allow scale/shift and bias shape of [N], [1,1,1,N]; BN works for "channel-wise"
124 auto srank = scale->rank() - 1;
125 if (filter_out_chn != scale->dim(srank).value())
127 for (uint32_t d = 0; d < srank; ++d)
129 if (1 != scale->dim(d).value())
132 srank = shift->rank() - 1;
133 if (filter_out_chn != shift->dim(srank).value())
135 for (uint32_t d = 0; d < srank; ++d)
137 if (1 != shift->dim(d).value())
142 if (bias->dtype() != loco::DataType::FLOAT32)
144 srank = bias->rank() - 1;
145 if (filter_out_chn != bias->dim(srank).value())
147 for (uint32_t d = 0; d < srank; ++d)
149 if (1 != bias->dim(d).value())
154 auto name = add->name();
155 assert(name.length() > 0);
157 loco::Graph *graph = add->graph();
158 luci::CircleTransposeConv *fused_tconv = graph->nodes()->create<luci::CircleTransposeConv>();
159 luci::CircleConst *fused_filter = graph->nodes()->create<luci::CircleConst>();
160 luci::CircleConst *fused_bias = graph->nodes()->create<luci::CircleConst>();
162 auto filter_height = filter->dim(1).value();
163 auto filter_width = filter->dim(2).value();
164 auto filter_in_chn = filter->dim(3).value();
167 fused_filter->dtype(filter->dtype());
168 fused_filter->size<loco::DataType::FLOAT32>(filter->size<loco::DataType::FLOAT32>());
169 fused_filter->rank(4);
170 fused_filter->dim(0).set(filter_out_chn);
171 fused_filter->dim(1).set(filter_height);
172 fused_filter->dim(2).set(filter_width);
173 fused_filter->dim(3).set(filter_in_chn);
174 fused_filter->shape_status(luci::ShapeStatus::VALID);
175 fused_filter->name(name + "/TransposeConv/filter");
177 // fused filter weight = filter weight * mul(scale) + add(shift)
178 for (uint32_t c = 0; c < filter_out_chn; c++)
180 for (uint32_t h = 0; h < filter_height; h++)
182 for (uint32_t w = 0; w < filter_width; w++)
184 for (uint32_t b = 0; b < filter_in_chn; b++)
186 uint32_t offset = c * filter_height * filter_width * filter_in_chn +
187 h * filter_width * filter_in_chn + w * filter_in_chn + b;
188 fused_filter->at<loco::DataType::FLOAT32>(offset) =
189 filter->at<loco::DataType::FLOAT32>(offset) * scale->at<loco::DataType::FLOAT32>(c);
195 // Copy fused_bias from shift
196 fused_bias->dtype(shift->dtype());
197 fused_bias->size<loco::DataType::FLOAT32>(shift->size<loco::DataType::FLOAT32>());
199 fused_bias->dim(0).set(filter_out_chn);
200 fused_bias->shape_status(luci::ShapeStatus::VALID);
201 for (uint32_t c = 0; c < filter_out_chn; ++c)
203 fused_bias->at<loco::DataType::FLOAT32>(c) = shift->at<loco::DataType::FLOAT32>(c);
206 fused_bias->at<loco::DataType::FLOAT32>(c) +=
207 bias->at<loco::DataType::FLOAT32>(c) * scale->at<loco::DataType::FLOAT32>(c);
210 fused_bias->name(name + "/TransposeConv/bias");
212 // set new tconv properties
213 fused_tconv->inputSizes(tconv->inputSizes());
214 fused_tconv->filter(fused_filter);
215 fused_tconv->outBackprop(tconv->outBackprop());
216 fused_tconv->bias(fused_bias);
217 fused_tconv->padding(tconv->padding());
218 fused_tconv->stride()->h(tconv->stride()->h());
219 fused_tconv->stride()->w(tconv->stride()->w());
220 fused_tconv->name(name + "/TransposeConv");
221 // TODO set activation from Add and remove adding following Relu/Relu6 Op
222 // when all of our backends supports fused activation of TransposeConv
223 fused_tconv->fusedActivationFunction(luci::FusedActFunc::NONE);
224 luci::add_origin(fused_tconv,
225 luci::composite_origin(
226 {luci::get_origin(add), luci::get_origin(mul), luci::get_origin(tconv)}));
229 luci::add_origin(fused_tconv, luci::get_origin(bias));
232 switch (add->fusedActivationFunction())
234 case luci::FusedActFunc::RELU6:
235 replace_with_relu<luci::CircleRelu6>(add, fused_tconv, name + "/Relu6");
238 case luci::FusedActFunc::RELU:
239 replace_with_relu<luci::CircleRelu>(add, fused_tconv, name + "/Relu");
242 case luci::FusedActFunc::NONE:
243 replace(add).with(fused_tconv);
259 bool FuseBatchNormWithTConvPass::run(loco::Graph *g)
261 bool changed = false;
262 for (auto node : loco::active_nodes(loco::output_nodes(g)))
264 if (auto add = dynamic_cast<luci::CircleAdd *>(node))
266 if (fused_batch_norm_with_tconv(add))