2 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/Pass/FuseAddWithFullyConnectedPass.h"
19 #include <luci/IR/CircleNodes.h>
20 #include <luci/Service/Nodes/CircleConst.h>
21 #include <luci/Profile/CircleNodeOrigin.h>
26 * Fuse Add to FullyConnected if the added value is a channel(last dimension)-wise constant
30 * [CircleFullyConnected]
37 * [CircleFullyConnected] [CircleAdd] (dead)
// Attempts to fold the constant operand of the CircleAdd that consumes `fc` into the
// bias of `fc`, move the Add's fused activation onto `fc`, and replace the Add with `fc`.
//
// NOTE(review): this listing carries embedded line numbers with gaps (46 -> 49, 52 -> 57,
// 61 -> 64, 69 -> 75, 76 -> 80, ...). The statements guarded by the conditions below, the
// braces, and the return statements are in the elided lines; the comments here describe
// only what the visible lines do.
//
// @param fc  FullyConnected node to examine; it is dereferenced on the first visible
//            statement (a null check may exist in the elided lines 42-45 - confirm).
// @return    bool; the returned values are not visible in this listing.
41 bool fuse_add_with_fc(luci::CircleFullyConnected *fc)
// Condition: the FC node's dtype is anything other than FLOAT32.
46 if (fc->dtype() != loco::DataType::FLOAT32)
// Condition: the FC node already carries a fused activation (anything other than NONE).
49 if (fc->fusedActivationFunction() != luci::FusedActFunc::NONE)
// weights is nullptr when fc->weights() is not a CircleConst; it is dereferenced later
// (weights->dim(0)), so a null check presumably sits in the elided lines 53-56 - confirm.
52 auto weights = dynamic_cast<luci::CircleConst *>(fc->weights());
// Collect the successors of fc; the next condition fires unless there is exactly one.
57 auto fc_output = loco::succs(fc);
58 if (fc_output.size() != 1)
// The single successor is expected to be a CircleAdd.
// NOTE(review): add is dereferenced on the next visible line; lines 62-63 are elided and
// presumably hold the nullptr check - confirm.
61 auto add = dynamic_cast<luci::CircleAdd *>(*fc_output.begin());
// Condition: the Add node's dtype is anything other than FLOAT32.
64 if (add->dtype() != loco::DataType::FLOAT32)
// Pick the Add operand that is not fc and try to view it as a constant; addition is
// nullptr when that operand is not a CircleConst.
// NOTE(review): addition is dereferenced at addition->rank() below; lines 70-74 are elided
// and presumably reject the nullptr case - confirm.
68 auto addition = add->x() == fc ? dynamic_cast<luci::CircleConst *>(add->y())
69 : dynamic_cast<luci::CircleConst *>(add->x());
75 auto rank = addition->rank();
76 // TODO Support scalar addition
// Every dimension except the last is compared against 1 (the constant may only vary along
// its last dimension).
// NOTE(review): if rank() returns an unsigned type, `rank - 1` wraps when rank == 0; the
// TODO above and the 76 -> 80 gap suggest a rank == 0 guard in the elided lines - confirm.
80 for (uint32_t i = 0; i < rank - 1; i++)
82 if (addition->dim(i).value() != 1)
85 // Check that the last dimension of addition equals the number of neurons of FC
86 if (not(addition->dim(rank - 1) == weights->dim(0)))
// The new bias starts as a clone of the added constant.
// NOTE(review): presumably the clone keeps addition's rank/shape (e.g. [1, ..., N]); nothing
// visible reshapes it before it is installed as the FC bias - confirm that is acceptable.
89 auto fused_bias = luci::clone(addition);
91 // Add existing bias values
// Only a CircleConst bias is accumulated into fused_bias.
// NOTE(review): when fc->bias() is some other node type, the visible code still overwrites
// it with fused_bias below - confirm a non-constant bias cannot reach this point.
92 if (auto const_bias = dynamic_cast<luci::CircleConst *>(fc->bias()))
// The dtype and size expectations are enforced only through assert, i.e. not in builds
// that define NDEBUG.
94 assert(const_bias->dtype() == loco::DataType::FLOAT32);
96 auto bias_size = fused_bias->size<loco::DataType::FLOAT32>();
97 assert(bias_size == const_bias->size<loco::DataType::FLOAT32>());
// Element-wise: fused_bias[i] += const_bias[i].
98 for (uint32_t i = 0; i < bias_size; i++)
99 fused_bias->at<loco::DataType::FLOAT32>(i) += const_bias->at<loco::DataType::FLOAT32>(i);
// Install the combined bias and copy the Add's fused activation onto the FC node
// (fc's own activation was compared against NONE above).
102 fc->bias(fused_bias);
103 fc->fusedActivationFunction(add->fusedActivationFunction());
// Append the Add's origin to fc's origin information.
106 luci::add_origin(fc, luci::get_origin(add));
// Substitute fc for add; the file's header diagram marks the Add as dead after this step.
108 replace(add).with(fc);
118 bool FuseAddWithFullyConnectedPass::run(loco::Graph *g)
120 bool changed = false;
121 for (auto node : loco::active_nodes(loco::output_nodes(g)))
123 auto fc = dynamic_cast<luci::CircleFullyConnected *>(node);
127 if (fuse_add_with_fc(fc))