/*
 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "luci/Pass/DecomposeHardSwishPass.h"

#include "helpers/NodeFiller.h"
#include "helpers/TypeMapper.h"

#include <luci/IR/CircleNodes.h>
#include <luci/Profile/CircleNodeOrigin.h>

#include <cassert>
#include <string>
40 * [CircleNode] [CircleConst]
46 * \ [CircleRelu6] [CircleConst]
58 bool decompose_hardswish(luci::CircleHardSwish *hardswish)
63 if (hardswish->dtype() != loco::DataType::FLOAT32)
66 auto g = hardswish->graph();
68 auto name = hardswish->name();
69 assert(name.length() > 0);
71 // Create a const for CircleAdd operation
72 auto add_const = g->nodes()->create<luci::CircleConst>();
73 add_const->shape({}); // scalar
74 add_const->dtype(loco::DataType::FLOAT32);
76 add_const->size<loco::DataType::FLOAT32>(1);
77 add_const->at<loco::DataType::FLOAT32>(0) = 3.;
78 add_const->name(name + "/Add/const");
79 luci::add_origin(add_const, luci::get_origin(hardswish));
81 // Create an Add operation
82 auto add = g->nodes()->create<luci::CircleAdd>();
83 add->fusedActivationFunction(luci::FusedActFunc::NONE);
84 add->x(hardswish->features());
86 add->name(name + "/Add");
87 luci::add_origin(add, luci::get_origin(hardswish));
89 // Create a Relu6 operation
90 auto relu6 = g->nodes()->create<luci::CircleRelu6>();
92 relu6->name(name + "/Relu6");
93 luci::add_origin(relu6, luci::get_origin(hardswish));
95 // Create a const for CircleMul operation
96 auto mul_const = g->nodes()->create<luci::CircleConst>();
97 mul_const->shape({}); // scalar
98 mul_const->dtype(loco::DataType::FLOAT32);
100 mul_const->size<loco::DataType::FLOAT32>(1);
101 mul_const->at<loco::DataType::FLOAT32>(0) = 1. / 6.;
102 mul_const->name(name + "/Mul/const");
103 luci::add_origin(mul_const, luci::get_origin(hardswish));
105 // Create first Mul operation
106 auto mul1 = g->nodes()->create<luci::CircleMul>();
107 mul1->fusedActivationFunction(luci::FusedActFunc::NONE);
110 mul1->name(name + "/Mul1");
111 luci::add_origin(mul1, luci::get_origin(hardswish));
113 // Create second Mul operation
114 auto mul2 = g->nodes()->create<luci::CircleMul>();
115 mul2->fusedActivationFunction(luci::FusedActFunc::NONE);
116 mul2->x(hardswish->features());
118 mul2->name(name + "/Mul2");
119 luci::add_origin(mul2, luci::get_origin(hardswish));
121 replace(hardswish).with(mul2);
131 bool DecomposeHardSwishPass::run(loco::Graph *g)
133 bool changed = false;
135 for (auto node : loco::active_nodes(loco::output_nodes(g)))
137 if (auto hardswish = dynamic_cast<luci::CircleHardSwish *>(node))
139 if (decompose_hardswish(hardswish))