Imported Upstream version 1.18.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / FuseActivationFunctionPass.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/FuseActivationFunctionPass.h"
18
19 #include <luci/IR/CircleNodes.h>
20 #include <luci/IR/CircleNodeMixins.h>
21 #include <luci/IR/CircleOpcode.h>
22 #include <luci/Profile/CircleNodeOrigin.h>
23
24 namespace luci
25 {
26
27 bool fuse_activation_function(luci::CircleNode *node)
28 {
29   auto preds = loco::preds(node);
30   assert(preds.size() == 1);
31
32   auto pred_node = static_cast<luci::CircleNode *>(*preds.begin());
33   if (loco::succs(pred_node).size() != 1)
34     return false;
35
36   auto node_with_fused_act =
37     dynamic_cast<luci::CircleNodeMixin<luci::CircleNodeTrait::FusedActFunc> *>(pred_node);
38   if (node_with_fused_act == nullptr)
39     return false;
40
41   // TODO remove this work-around
42   // This will skip fuse for concat as luci-interpreter doesn't support this yet
43   if (dynamic_cast<luci::CircleConcatenation *>(pred_node) != nullptr)
44     return false;
45
46   auto fused_act = node_with_fused_act->fusedActivationFunction();
47
48   luci::FusedActFunc target_func = luci::FusedActFunc::UNDEFINED;
49
50   auto opcode = node->opcode();
51   if (opcode == luci::CircleOpcode::RELU)
52   {
53     if (fused_act == luci::FusedActFunc::NONE || fused_act == luci::FusedActFunc::RELU)
54       target_func = luci::FusedActFunc::RELU;
55     else if (fused_act == luci::FusedActFunc::RELU6)
56       target_func = luci::FusedActFunc::RELU6;
57     else
58       return false;
59   }
60   else if (opcode == luci::CircleOpcode::RELU6)
61   {
62     if (fused_act == luci::FusedActFunc::NONE || fused_act == luci::FusedActFunc::RELU ||
63         fused_act == luci::FusedActFunc::RELU6)
64       target_func = luci::FusedActFunc::RELU6;
65     else
66       return false;
67   }
68   else if (opcode == luci::CircleOpcode::RELU_N1_TO_1)
69   {
70     if (fused_act == luci::FusedActFunc::NONE || fused_act == luci::FusedActFunc::RELU_N1_TO_1)
71       target_func = luci::FusedActFunc::RELU_N1_TO_1;
72     else
73       return false;
74   }
75   else if (opcode == luci::CircleOpcode::TANH)
76   {
77     if (fused_act == luci::FusedActFunc::NONE)
78       target_func = luci::FusedActFunc::TANH;
79     else
80       return false;
81   }
82   else
83     return false;
84
85   node_with_fused_act->fusedActivationFunction(target_func);
86   luci::add_origin(pred_node, luci::get_origin(node));
87   loco::replace(node).with(pred_node);
88
89   node->drop();
90
91   return true;
92 }
93
94 bool FuseActivationFunctionPass::run(loco::Graph *g)
95 {
96   bool changed = false;
97   for (auto node : loco::active_nodes(loco::output_nodes(g)))
98   {
99     auto circle_node = static_cast<luci::CircleNode *>(node);
100     auto opcode = circle_node->opcode();
101     if (opcode == luci::CircleOpcode::RELU || opcode == luci::CircleOpcode::RELU6 ||
102         opcode == luci::CircleOpcode::RELU_N1_TO_1 || opcode == luci::CircleOpcode::TANH)
103     {
104       if (fuse_activation_function(circle_node))
105         changed = true;
106     }
107   }
108
109   return changed;
110 }
111
112 } // namespace luci