Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / FuseAddWithTConvPass.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/FuseAddWithTConvPass.h"
18
19 #include <luci/IR/CircleNodes.h>
20 #include <luci/Profile/CircleNodeOrigin.h>
21
22 namespace
23 {
24 /**
25  *  Fuse Add to TransposeConv if possible
26  *
27  *  BEFORE
28  *                     |
29  *   [CircleConst]  [CircleTransposeConv]
30  *               \     |
31  *             [CircleAdd]
32  *                  |
33  *
34  *  AFTER
35  *                  |
36  *   [CircleConst]  |
37  *             \    |
38  *         [CircleTransposeConv]   [CircleAdd]
39  *                  |
40  *          ([CircleRelu/Relu6])
41  *                  |
42  *
43  *  Note: CircleRelu/Relu6 is inserted if Add activation is ReLU6
44  */
45 bool fuse_add_with_tconv(luci::CircleTransposeConv *tconv)
46 {
47   // skip if tconv has fused activation
48   if (tconv->fusedActivationFunction() != luci::FusedActFunc::NONE)
49     return false;
50   // check whether it has bias or not. This optimization works only if it doesn't.
51   auto bias = dynamic_cast<luci::CircleOutputExclude *>(tconv->bias());
52   if (not bias)
53     return false;
54
55   // get weight of tconv
56   auto filter = dynamic_cast<luci::CircleConst *>(tconv->filter());
57   if (not filter)
58     return false;
59   if (filter->dtype() != loco::DataType::FLOAT32)
60     return false;
61
62   // get add node
63   auto tconv_output = loco::succs(tconv);
64   assert(tconv_output.size() == 1);
65   auto add = dynamic_cast<luci::CircleAdd *>(*tconv_output.begin());
66   if (not add)
67     return false;
68   if (add->dtype() != loco::DataType::FLOAT32)
69     return false;
70   if (add->fusedActivationFunction() != luci::FusedActFunc::NONE &&
71       add->fusedActivationFunction() != luci::FusedActFunc::RELU6 &&
72       add->fusedActivationFunction() != luci::FusedActFunc::RELU)
73     return false;
74
75   // get addition
76   luci::CircleConst *addition = nullptr;
77   if (add->x() == tconv)
78     addition = dynamic_cast<luci::CircleConst *>(add->y());
79   else
80     addition = dynamic_cast<luci::CircleConst *>(add->x());
81
82   if (not addition)
83     return false;
84
85   // addition dim(0) == tconv filter channel dim
86   if (addition->rank() != 1)
87     return false;
88   auto addition_dim = addition->dim(0).value();
89   auto filter_channel_dim = filter->dim(0).value();
90   if (filter_channel_dim != addition_dim)
91     return false;
92
93   // fuse addition with transposed conv
94   tconv->bias(addition);
95
96   if (add->fusedActivationFunction() == luci::FusedActFunc::RELU6)
97   {
98     auto name = addition->name();
99     assert(name.length() > 0);
100     // separate relu op from add op
101     auto relu = add->graph()->nodes()->create<luci::CircleRelu6>();
102     relu->features(tconv);
103     relu->name(name + "/Relu6");
104     luci::add_origin(relu, luci::get_origin(add));
105
106     // remove add node
107     replace(add).with(relu);
108   }
109   else if (add->fusedActivationFunction() == luci::FusedActFunc::RELU)
110   {
111     auto name = addition->name();
112     assert(name.length() > 0);
113     // separate relu op from add op
114     auto relu = add->graph()->nodes()->create<luci::CircleRelu>();
115     relu->features(tconv);
116     relu->name(name + "/Relu");
117     luci::add_origin(relu, luci::get_origin(add));
118
119     // remove add node
120     replace(add).with(relu);
121   }
122   else
123   {
124     replace(add).with(tconv);
125   }
126
127   // set origin
128   luci::add_origin(tconv, luci::get_origin(add));
129
130   return true;
131 }
132
133 } // namespace
134
135 namespace luci
136 {
137
138 bool FuseAddWithTConvPass::run(loco::Graph *g)
139 {
140   bool changed = false;
141   for (auto node : loco::active_nodes(loco::output_nodes(g)))
142   {
143     auto tconv = dynamic_cast<luci::CircleTransposeConv *>(node);
144     if (not tconv)
145       continue;
146
147     if (fuse_add_with_tconv(tconv))
148       changed = true;
149   }
150
151   return changed;
152 }
153
154 } // namespace luci