2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/Pass/FuseActivationFunctionPass.h"
19 #include <luci/IR/CircleNodes.h>
20 #include <luci/IR/CircleNodeMixins.h>
21 #include <luci/IR/CircleOpcode.h>
22 #include <luci/Profile/CircleNodeOrigin.h>
27 bool fuse_activation_function(luci::CircleNode *node)
29 auto preds = loco::preds(node);
30 assert(preds.size() == 1);
32 auto pred_node = static_cast<luci::CircleNode *>(*preds.begin());
33 if (loco::succs(pred_node).size() != 1)
36 auto node_with_fused_act =
37 dynamic_cast<luci::CircleNodeMixin<luci::CircleNodeTrait::FusedActFunc> *>(pred_node);
38 if (node_with_fused_act == nullptr)
41 // TODO remove this work-around
42 // This will skip fuse for concat as luci-interpreter doesn't support this yet
43 if (dynamic_cast<luci::CircleConcatenation *>(pred_node) != nullptr)
46 auto fused_act = node_with_fused_act->fusedActivationFunction();
48 luci::FusedActFunc target_func = luci::FusedActFunc::UNDEFINED;
50 auto opcode = node->opcode();
51 if (opcode == luci::CircleOpcode::RELU)
53 if (fused_act == luci::FusedActFunc::NONE || fused_act == luci::FusedActFunc::RELU)
54 target_func = luci::FusedActFunc::RELU;
55 else if (fused_act == luci::FusedActFunc::RELU6)
56 target_func = luci::FusedActFunc::RELU6;
60 else if (opcode == luci::CircleOpcode::RELU6)
62 if (fused_act == luci::FusedActFunc::NONE || fused_act == luci::FusedActFunc::RELU ||
63 fused_act == luci::FusedActFunc::RELU6)
64 target_func = luci::FusedActFunc::RELU6;
68 else if (opcode == luci::CircleOpcode::RELU_N1_TO_1)
70 if (fused_act == luci::FusedActFunc::NONE || fused_act == luci::FusedActFunc::RELU_N1_TO_1)
71 target_func = luci::FusedActFunc::RELU_N1_TO_1;
75 else if (opcode == luci::CircleOpcode::TANH)
77 if (fused_act == luci::FusedActFunc::NONE)
78 target_func = luci::FusedActFunc::TANH;
85 node_with_fused_act->fusedActivationFunction(target_func);
86 luci::add_origin(pred_node, luci::get_origin(node));
87 loco::replace(node).with(pred_node);
94 bool FuseActivationFunctionPass::run(loco::Graph *g)
97 for (auto node : loco::active_nodes(loco::output_nodes(g)))
99 auto circle_node = static_cast<luci::CircleNode *>(node);
100 auto opcode = circle_node->opcode();
101 if (opcode == luci::CircleOpcode::RELU || opcode == luci::CircleOpcode::RELU6 ||
102 opcode == luci::CircleOpcode::RELU_N1_TO_1 || opcode == luci::CircleOpcode::TANH)
104 if (fuse_activation_function(circle_node))