852bc8b63a3a146b80acf441e0c1da9b7fd1d6d8
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / FuseAddWithTConvPass.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/FuseAddWithTConvPass.h"
18
19 #include <luci/IR/CircleNodes.h>
20 #include <luci/Profile/CircleNodeOrigin.h>
21
22 namespace
23 {
24 /**
25  *  Fuse Add to TransposeConv if possible
26  *
27  *  BEFORE
28  *                     |
29  *   [CircleConst]  [CircleTransposeConv]
30  *               \     |
31  *             [CircleAdd]
32  *                  |
33  *
34  *  AFTER
35  *                  |
36  *   [CircleConst]  |
37  *             \    |
38  *         [CircleTransposeConv]   [CircleAdd]
39  *                  |
40  *          ([CircleRelu/Relu6])
41  *                  |
42  *
43  *  Note: CircleRelu/Relu6 is inserted if Add activation is ReLU6
44  */
45 bool fuse_add_with_tconv(luci::CircleTransposeConv *tconv)
46 {
47   // check whether it has bias or not. This optimization works only if it doesn't.
48   auto bias = dynamic_cast<luci::CircleOutputExclude *>(tconv->bias());
49   if (not bias)
50     return false;
51
52   // get weight of tconv
53   auto filter = dynamic_cast<luci::CircleConst *>(tconv->filter());
54   if (not filter)
55     return false;
56   if (filter->dtype() != loco::DataType::FLOAT32)
57     return false;
58
59   // get add node
60   auto tconv_output = loco::succs(tconv);
61   assert(tconv_output.size() == 1);
62   auto add = dynamic_cast<luci::CircleAdd *>(*tconv_output.begin());
63   if (not add)
64     return false;
65   if (add->dtype() != loco::DataType::FLOAT32)
66     return false;
67   if (add->fusedActivationFunction() != luci::FusedActFunc::NONE &&
68       add->fusedActivationFunction() != luci::FusedActFunc::RELU6 &&
69       add->fusedActivationFunction() != luci::FusedActFunc::RELU)
70     return false;
71
72   // get addition
73   luci::CircleConst *addition = nullptr;
74   if (add->x() == tconv)
75     addition = dynamic_cast<luci::CircleConst *>(add->y());
76   else
77     addition = dynamic_cast<luci::CircleConst *>(add->x());
78
79   if (not addition)
80     return false;
81
82   // addition dim(0) == tconv filter channel dim
83   if (addition->rank() != 1)
84     return false;
85   auto addition_dim = addition->dim(0).value();
86   auto filter_channel_dim = filter->dim(0).value();
87   if (filter_channel_dim != addition_dim)
88     return false;
89
90   // fuse addition with transposed conv
91   tconv->bias(addition);
92
93   if (add->fusedActivationFunction() == luci::FusedActFunc::RELU6)
94   {
95     auto name = addition->name();
96     assert(name.length() > 0);
97     // separate relu op from add op
98     auto relu = add->graph()->nodes()->create<luci::CircleRelu6>();
99     relu->features(tconv);
100     relu->name(name + "/Relu6");
101     luci::add_origin(relu, luci::get_origin(add));
102
103     // remove add node
104     replace(add).with(relu);
105   }
106   else if (add->fusedActivationFunction() == luci::FusedActFunc::RELU)
107   {
108     auto name = addition->name();
109     assert(name.length() > 0);
110     // separate relu op from add op
111     auto relu = add->graph()->nodes()->create<luci::CircleRelu>();
112     relu->features(tconv);
113     relu->name(name + "/Relu");
114     luci::add_origin(relu, luci::get_origin(add));
115
116     // remove add node
117     replace(add).with(relu);
118   }
119   else
120   {
121     replace(add).with(tconv);
122   }
123
124   // set origin
125   luci::add_origin(tconv, luci::get_origin(add));
126
127   return true;
128 }
129
130 } // namespace
131
132 namespace luci
133 {
134
135 bool FuseAddWithTConvPass::run(loco::Graph *g)
136 {
137   bool changed = false;
138   for (auto node : loco::active_nodes(loco::output_nodes(g)))
139   {
140     auto tconv = dynamic_cast<luci::CircleTransposeConv *>(node);
141     if (not tconv)
142       continue;
143
144     if (fuse_add_with_tconv(tconv))
145       changed = true;
146   }
147
148   return changed;
149 }
150
151 } // namespace luci