Imported Upstream version 1.18.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / FuseAddWithFullyConnectedPass.cpp
1 /*
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/FuseAddWithFullyConnectedPass.h"
18
19 #include <luci/IR/CircleNodes.h>
20 #include <luci/Service/Nodes/CircleConst.h>
21 #include <luci/Profile/CircleNodeOrigin.h>
22
23 namespace
24 {
25 /**
26  *  Fuse Add to FullyConnected if the added value is a channel(last dimension)-wise constant
27  *
28  *  BEFORE
29  *                |
30  *      [CircleFullyConnected]
31  *                |
32  *           [CircleAdd]
33  *                |
34  *
35  *  AFTER
36  *                |
37  *       [CircleFullyConnected]   [CircleAdd] (dead)
38  *                |
39  *
40  */
41 bool fuse_add_with_fc(luci::CircleFullyConnected *fc)
42 {
43   if (not fc)
44     return false;
45
46   if (fc->dtype() != loco::DataType::FLOAT32)
47     return false;
48
49   if (fc->fusedActivationFunction() != luci::FusedActFunc::NONE)
50     return false;
51
52   auto weights = dynamic_cast<luci::CircleConst *>(fc->weights());
53   if (not weights)
54     return false;
55
56   // Get add node
57   auto fc_output = loco::succs(fc);
58   if (fc_output.size() != 1)
59     return false;
60
61   auto add = dynamic_cast<luci::CircleAdd *>(*fc_output.begin());
62   if (not add)
63     return false;
64   if (add->dtype() != loco::DataType::FLOAT32)
65     return false;
66
67   // Get addition
68   auto addition = add->x() == fc ? dynamic_cast<luci::CircleConst *>(add->y())
69                                  : dynamic_cast<luci::CircleConst *>(add->x());
70
71   // Non-const addition
72   if (not addition)
73     return false;
74
75   auto rank = addition->rank();
76   // TODO Support scalar addition
77   if (rank == 0)
78     return false;
79
80   for (uint32_t i = 0; i < rank - 1; i++)
81   {
82     if (addition->dim(i).value() != 1)
83       return false;
84   }
85   // Check the last dimesion of addition is the same with the number of neurons of FC
86   if (not(addition->dim(rank - 1) == weights->dim(0)))
87     return false;
88
89   auto fused_bias = luci::clone(addition);
90
91   // Add existing bias values
92   if (auto const_bias = dynamic_cast<luci::CircleConst *>(fc->bias()))
93   {
94     assert(const_bias->dtype() == loco::DataType::FLOAT32);
95
96     auto bias_size = fused_bias->size<loco::DataType::FLOAT32>();
97     assert(bias_size == const_bias->size<loco::DataType::FLOAT32>());
98     for (uint32_t i = 0; i < bias_size; i++)
99       fused_bias->at<loco::DataType::FLOAT32>(i) += const_bias->at<loco::DataType::FLOAT32>(i);
100   }
101
102   fc->bias(fused_bias);
103   fc->fusedActivationFunction(add->fusedActivationFunction());
104
105   // set origin
106   luci::add_origin(fc, luci::get_origin(add));
107
108   replace(add).with(fc);
109
110   return true;
111 }
112
113 } // namespace
114
115 namespace luci
116 {
117
118 bool FuseAddWithFullyConnectedPass::run(loco::Graph *g)
119 {
120   bool changed = false;
121   for (auto node : loco::active_nodes(loco::output_nodes(g)))
122   {
123     auto fc = dynamic_cast<luci::CircleFullyConnected *>(node);
124     if (not fc)
125       continue;
126
127     if (fuse_add_with_fc(fc))
128       changed = true;
129   }
130
131   return changed;
132 }
133
134 } // namespace luci