d83973cd5aeed671c98cb4023904de1aedf7a681
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / FuseActivationFunctionPass.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/FuseActivationFunctionPass.h"
18
19 #include <luci/IR/CircleNodes.h>
20 #include <luci/IR/CircleNodeMixins.h>
21 #include <luci/IR/CircleOpcode.h>
22 #include <luci/Profile/CircleNodeOrigin.h>
23
24 namespace luci
25 {
26
27 bool fuse_activation_function(luci::CircleNode *node)
28 {
29   auto preds = loco::preds(node);
30   assert(preds.size() == 1);
31
32   auto pred_node = static_cast<luci::CircleNode *>(*preds.begin());
33   if (loco::succs(pred_node).size() != 1)
34     return false;
35
36   auto node_with_fused_act =
37     dynamic_cast<luci::CircleNodeMixin<luci::CircleNodeTrait::FusedActFunc> *>(pred_node);
38   if (node_with_fused_act == nullptr)
39     return false;
40
41   // TODO remove this work-around
42   // This will skip fuse for concat as luci-interpreter doesn't support this yet
43   if (dynamic_cast<luci::CircleConcatenation *>(pred_node) != nullptr)
44     return false;
45
46   auto fused_act = node_with_fused_act->fusedActivationFunction();
47
48   luci::FusedActFunc target_func = luci::FusedActFunc::UNDEFINED;
49
50   auto opcode = node->opcode();
51   if (opcode == luci::CircleOpcode::RELU)
52   {
53     if (fused_act == luci::FusedActFunc::NONE || fused_act == luci::FusedActFunc::RELU)
54       target_func = luci::FusedActFunc::RELU;
55     else if (fused_act == luci::FusedActFunc::RELU6)
56       target_func = luci::FusedActFunc::RELU6;
57     else
58       return false;
59   }
60   else if (opcode == luci::CircleOpcode::RELU6)
61   {
62     if (fused_act == luci::FusedActFunc::NONE || fused_act == luci::FusedActFunc::RELU ||
63         fused_act == luci::FusedActFunc::RELU6)
64       target_func = luci::FusedActFunc::RELU6;
65     else
66       return false;
67   }
68   else if (opcode == luci::CircleOpcode::RELU_N1_TO_1)
69   {
70     if (fused_act == luci::FusedActFunc::NONE || fused_act == luci::FusedActFunc::RELU_N1_TO_1)
71       target_func = luci::FusedActFunc::RELU_N1_TO_1;
72     else
73       return false;
74   }
75   else
76     return false;
77
78   node_with_fused_act->fusedActivationFunction(target_func);
79   luci::add_origin(pred_node, luci::get_origin(node));
80   loco::replace(node).with(pred_node);
81
82   node->drop();
83
84   return true;
85 }
86
87 bool FuseActivationFunctionPass::run(loco::Graph *g)
88 {
89   bool changed = false;
90   for (auto node : loco::active_nodes(loco::output_nodes(g)))
91   {
92     auto circle_node = static_cast<luci::CircleNode *>(node);
93     auto opcode = circle_node->opcode();
94     // TANH is not supported as CONV fused with TANH is not supported in luci-interpreter
95     if (opcode == luci::CircleOpcode::RELU || opcode == luci::CircleOpcode::RELU6 ||
96         opcode == luci::CircleOpcode::RELU_N1_TO_1)
97     {
98       if (fuse_activation_function(circle_node))
99         changed = true;
100     }
101   }
102
103   return changed;
104 }
105
106 } // namespace luci