2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/Pass/FuseBatchNormWithTConv.h"
19 #include <luci/IR/CircleNodes.h>
24 * NOTE TF's fusedBatchNorm is converted to mul and add of Circle.
28 * [CircleTransposeConv]
35 * [CircleTransposeConv]
// Tries to fold the Mul -> Add pair that consumes `tconv` into the transposed convolution:
// the filter constant is scaled in place by the Mul's constant operand, and the Add node is
// then replaced in the graph, either by `tconv` or by a newly created Relu6 fed by `tconv`.
//
// NOTE(review): several lines of this function are not visible in this view (braces, the
// bodies of the guard `if`s, the `else`, and the final `return`). The notes below describe
// only the visible statements; the guards presumably bail out with `false` and the function
// presumably returns `true` after rewriting the graph - confirm against the full file.
37 bool fused_batch_norm_with_tconv(luci::CircleTransposeConv *tconv)
39 // check whether it has bias or not. This optimization works only if it doesn't.
// The bias input must be a CircleOutputExclude node for the cast to succeed.
40 auto bias = dynamic_cast<luci::CircleOutputExclude *>(tconv->bias());
44 // get weight of tconv
// The filter must be a constant node, because its values are rewritten below.
45 auto filter = dynamic_cast<luci::CircleConst *>(tconv->filter());
48 if (filter->dtype() != loco::DataType::FLOAT32)
// Look at the consumers of tconv and take the first one, which must be a CircleMul.
// NOTE(review): the single-consumer requirement is only an assert, so it is not enforced
// when NDEBUG is defined. With more than one consumer, the others would observe the
// rescaled filter - consider a runtime check instead.
52 auto tconv_output = loco::succs(tconv);
53 assert(tconv_output.size() == 1);
54 auto mul = dynamic_cast<luci::CircleMul *>(*tconv_output.begin());
57 if (mul->dtype() != loco::DataType::FLOAT32)
// Same pattern one step further: the consumer of the Mul must be a CircleAdd.
// NOTE(review): as above, the single-consumer requirement is assert-only.
61 auto mul_output = loco::succs(mul);
62 assert(mul_output.size() == 1);
63 auto add = dynamic_cast<luci::CircleAdd *>(*mul_output.begin());
66 if (add->dtype() != loco::DataType::FLOAT32)
// This condition is true when the Add's fused activation is neither NONE nor RELU6, i.e.
// only those two activations are accepted by the code that follows.
68 if (add->fusedActivationFunction() != luci::FusedActFunc::NONE &&
69 add->fusedActivationFunction() != luci::FusedActFunc::RELU6)
72 // get scale of batchnorm
// Only the `y` operand of the Mul is inspected, and it must be a constant.
// NOTE(review): a constant passed as `x` instead would not match - confirm this is intended.
73 auto scale = dynamic_cast<luci::CircleConst *>(mul->y());
77 // scale dim(0) == tconv filter channel dim
// The filter must be rank 4 and the scale rank 1, with scale length == filter dim(3).
// NOTE(review): dim(3) is treated as the channel axis here. If the transposed-conv filter
// layout is [out, H, W, in], a per-output-channel scale should be matched against dim(0)
// instead - TODO confirm the layout.
78 if (filter->rank() != 4)
80 auto filter_channel_dim = filter->dim(3).value();
81 if (scale->rank() != 1)
83 auto scale_dim = scale->dim(0).value();
84 if (filter_channel_dim != scale_dim)
87 // get shift of batchnorm
// Only the `y` operand of the Add is inspected, and it must be a constant.
88 auto shift = dynamic_cast<luci::CircleConst *>(add->y());
92 // shift dim(0) == tconv filter channel dim
93 if (shift->rank() != 1)
95 auto shift_dim = shift->dim(0).value();
96 if (filter_channel_dim != shift_dim)
99 // filter weight = filter weight * mul(scale) + add(shift)
// Note: the visible loop below only multiplies filter elements by the scale. Nothing in it
// adds `shift` to the filter values, despite what the comment above says.
// NOTE(review): the filter constant is modified in place. If the same constant feeds any
// other node, that node would see the scaled values too - confirm it is not shared.
100 uint32_t filter_batch_dim = filter->dim(0).value();
101 uint32_t filter_height_dim = filter->dim(1).value();
102 uint32_t filter_width_dim = filter->dim(2).value();
// Visit every (n, h, w, c) element. `offset` is the flat index for a row-major
// [dim0, dim1, dim2, dim3] buffer, with c (dim 3) as the innermost axis.
103 for (uint32_t c = 0; c < filter_channel_dim; c++)
105 for (uint32_t n = 0; n < filter_batch_dim; n++)
107 for (uint32_t h = 0; h < filter_height_dim; h++)
109 for (uint32_t w = 0; w < filter_width_dim; w++)
111 uint32_t offset = n * filter_height_dim * filter_width_dim * filter_channel_dim +
112 h * filter_width_dim * filter_channel_dim + w * filter_channel_dim + c;
// Scale this element by the factor for its dim-3 index.
113 filter->at<loco::DataType::FLOAT32>(offset) *= scale->at<loco::DataType::FLOAT32>(c);
119 // fuse shift with transposed conv
// NOTE(review): the statement that performs this step is not visible in this view;
// presumably `shift` is attached to tconv as its bias - confirm.
122 if (add->fusedActivationFunction() == luci::FusedActFunc::RELU6)
124 // separate relu op from add op
// The activation that was fused into the Add becomes a standalone Relu6 node, created in
// the Add's graph and fed by tconv; the Add is then replaced by that node.
125 auto relu = add->graph()->nodes()->create<luci::CircleRelu6>();
126 relu->features(tconv);
129 replace(add).with(relu);
// Presumably the `else` branch (no fused activation): the Add is replaced by tconv directly.
133 replace(add).with(tconv);
// Pass entry point: walks the nodes that loco::active_nodes yields for the graph's output
// nodes and applies fused_batch_norm_with_tconv to each CircleTransposeConv candidate.
// `changed` accumulates (bitwise OR) the result of every attempt.
//
// NOTE(review): the loop braces, the guard that skips nodes for which the cast yields
// nullptr, and the final `return` are not visible in this view; presumably the function
// returns `changed` - confirm against the full file.
144 bool FuseBatchNormWithTConvPass::run(loco::Graph *g)
146 bool changed = false;
147 for (auto node : loco::active_nodes(loco::output_nodes(g)))
// dynamic_cast yields nullptr for any node that is not a CircleTransposeConv.
149 auto tconv = dynamic_cast<luci::CircleTransposeConv *>(node);
// `|=` keeps `changed` true once any fusion attempt reports success.
153 changed |= fused_batch_norm_with_tconv(tconv);