2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/Pass/FuseAddWithTConvPass.h"
19 #include <luci/IR/CircleNodes.h>
20 #include <luci/Profile/CircleNodeOrigin.h>
25 * Fuse Add to TransposeConv if possible
29 * [CircleConst] [CircleTransposeConv]
38 * [CircleTransposeConv] [CircleAdd]
40 * ([CircleRelu/Relu6])
43 * Note: CircleRelu/Relu6 is inserted if Add activation is ReLU6
45 bool fuse_add_with_tconv(luci::CircleTransposeConv *tconv)
47 // skip if tconv has fused activation
48 if (tconv->fusedActivationFunction() != luci::FusedActFunc::NONE)
50 // check whether it has bias or not. This optimization works only if it doesn't.
51 auto bias = dynamic_cast<luci::CircleOutputExclude *>(tconv->bias());
55 // get weight of tconv
56 auto filter = dynamic_cast<luci::CircleConst *>(tconv->filter());
59 if (filter->dtype() != loco::DataType::FLOAT32)
63 auto tconv_output = loco::succs(tconv);
64 assert(tconv_output.size() == 1);
65 auto add = dynamic_cast<luci::CircleAdd *>(*tconv_output.begin());
68 if (add->dtype() != loco::DataType::FLOAT32)
70 if (add->fusedActivationFunction() != luci::FusedActFunc::NONE &&
71 add->fusedActivationFunction() != luci::FusedActFunc::RELU6 &&
72 add->fusedActivationFunction() != luci::FusedActFunc::RELU)
76 luci::CircleConst *addition = nullptr;
77 if (add->x() == tconv)
78 addition = dynamic_cast<luci::CircleConst *>(add->y());
80 addition = dynamic_cast<luci::CircleConst *>(add->x());
85 // addition dim(0) == tconv filter channel dim
86 if (addition->rank() != 1)
88 auto addition_dim = addition->dim(0).value();
89 auto filter_channel_dim = filter->dim(0).value();
90 if (filter_channel_dim != addition_dim)
93 // fuse addition with transposed conv
94 tconv->bias(addition);
96 if (add->fusedActivationFunction() == luci::FusedActFunc::RELU6)
98 auto name = addition->name();
99 assert(name.length() > 0);
100 // separate relu op from add op
101 auto relu = add->graph()->nodes()->create<luci::CircleRelu6>();
102 relu->features(tconv);
103 relu->name(name + "/Relu6");
104 luci::add_origin(relu, luci::get_origin(add));
107 replace(add).with(relu);
109 else if (add->fusedActivationFunction() == luci::FusedActFunc::RELU)
111 auto name = addition->name();
112 assert(name.length() > 0);
113 // separate relu op from add op
114 auto relu = add->graph()->nodes()->create<luci::CircleRelu>();
115 relu->features(tconv);
116 relu->name(name + "/Relu");
117 luci::add_origin(relu, luci::get_origin(add));
120 replace(add).with(relu);
124 replace(add).with(tconv);
128 luci::add_origin(tconv, luci::get_origin(add));
138 bool FuseAddWithTConvPass::run(loco::Graph *g)
140 bool changed = false;
141 for (auto node : loco::active_nodes(loco::output_nodes(g)))
143 auto tconv = dynamic_cast<luci::CircleTransposeConv *>(node);
147 if (fuse_add_with_tconv(tconv))