Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / DecomposeHardSwishPass.test.cpp
1 /*
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/DecomposeHardSwishPass.h"
18
19 #include <luci/IR/CircleNodes.h>
20
21 #include <gtest/gtest.h>
22
23 namespace
24 {
25
26 /**
27  *  HardSwish graph
28  *
29  *        [CircleInput]
30  *              |
31  *              |
32  *      [CircleHardSwish]
33  *              |
34  *              |
35  *        [CircleOutput]
36  */
37 struct HardSwishGraph
38 {
39   loco::Graph _g;
40   luci::CircleInput *_input = nullptr;
41   luci::CircleHardSwish *_hardswish = nullptr;
42   luci::CircleOutput *_output = nullptr;
43 };
44
45 class DecomposeHardSwishPass : public ::testing::Test
46 {
47 protected:
48   void MakeGraph()
49   {
50     const int N = 1;
51     const int H = 4;
52     const int W = 4;
53     const int C = 3;
54
55     // graph input and output
56     auto graph_input = _hardswish_g._g.inputs()->create();
57     auto graph_output = _hardswish_g._g.outputs()->create();
58
59     // CircleInput
60     _hardswish_g._input = _hardswish_g._g.nodes()->create<luci::CircleInput>();
61     _hardswish_g._input->index(graph_input->index());
62     _hardswish_g._input->shape({N, H, W, C});
63     _hardswish_g._input->dtype(loco::DataType::FLOAT32);
64     _hardswish_g._input->name("input");
65
66     // CircleHardSwish
67     _hardswish_g._hardswish = _hardswish_g._g.nodes()->create<luci::CircleHardSwish>();
68     _hardswish_g._hardswish->features(_hardswish_g._input);
69     _hardswish_g._hardswish->shape({N, H, W, C});
70     _hardswish_g._hardswish->dtype(loco::DataType::FLOAT32);
71     _hardswish_g._hardswish->name("hardswish");
72
73     // CircleOutput
74     _hardswish_g._output = _hardswish_g._g.nodes()->create<luci::CircleOutput>();
75     _hardswish_g._output->index(graph_output->index());
76     _hardswish_g._output->from(_hardswish_g._hardswish);
77     _hardswish_g._output->shape({N, H, W, C});
78     _hardswish_g._output->dtype(loco::DataType::FLOAT32);
79     _hardswish_g._output->name("output");
80   }
81
82   void MakeInt32Graph()
83   {
84     const int N = 1;
85     const int H = 4;
86     const int W = 4;
87     const int C = 3;
88
89     // graph input and output
90     auto graph_input = _hardswish_int32_g._g.inputs()->create();
91     auto graph_output = _hardswish_int32_g._g.outputs()->create();
92
93     // CircleInput
94     _hardswish_int32_g._input = _hardswish_int32_g._g.nodes()->create<luci::CircleInput>();
95     _hardswish_int32_g._input->index(graph_input->index());
96     _hardswish_int32_g._input->shape({N, H, W, C});
97     _hardswish_int32_g._input->dtype(loco::DataType::S32);
98     _hardswish_int32_g._input->name("input");
99
100     // CircleHardSwish
101     _hardswish_int32_g._hardswish = _hardswish_int32_g._g.nodes()->create<luci::CircleHardSwish>();
102     _hardswish_int32_g._hardswish->features(_hardswish_int32_g._input);
103     _hardswish_int32_g._hardswish->shape({N, H, W, C});
104     _hardswish_int32_g._hardswish->dtype(loco::DataType::S32);
105     _hardswish_int32_g._hardswish->name("hardswish");
106
107     // CircleOutput
108     _hardswish_int32_g._output = _hardswish_int32_g._g.nodes()->create<luci::CircleOutput>();
109     _hardswish_int32_g._output->index(graph_output->index());
110     _hardswish_int32_g._output->from(_hardswish_int32_g._hardswish);
111     _hardswish_int32_g._output->shape({N, H, W, C});
112     _hardswish_int32_g._output->dtype(loco::DataType::S32);
113     _hardswish_int32_g._output->name("output");
114   }
115
116   virtual void SetUp()
117   {
118     MakeGraph();
119     MakeInt32Graph();
120   }
121
122 protected:
123   luci::DecomposeHardSwishPass _pass;
124   HardSwishGraph _hardswish_g;
125   HardSwishGraph _hardswish_int32_g;
126 };
127
128 } // namespace
129
130 TEST_F(DecomposeHardSwishPass, name)
131 {
132   auto const name = _pass.name();
133   ASSERT_NE(nullptr, name);
134 }
135
136 /**
137  *  Decomposed graph looks like below.
138  *
139  *      [CircleInput]  [CircleConst]
140  *          |    \       /
141  *          |     \     /
142  *          |   [CircleAdd]
143  *          |        |
144  *          |        |
145  *          \  [CircleRelu6] [CircleConst]
146  *           \        \        /
147  *            \        \      /
148  *             \      [CircleMul]
149  *              \       /
150  *               \     /
151  *             [CircleMul]
152  *                  |
153  *                  |
154  *             [CircleOutput]
155  *
156  */
157 TEST_F(DecomposeHardSwishPass, simple_test)
158 {
159   auto ret = _pass.run(&_hardswish_g._g);
160   EXPECT_TRUE(ret);
161
162   auto mul2 = dynamic_cast<luci::CircleMul *>(_hardswish_g._output->from());
163   EXPECT_NE(nullptr, mul2);
164
165   auto input2 = dynamic_cast<luci::CircleInput *>(mul2->x());
166   EXPECT_NE(nullptr, input2);
167
168   auto mul1 = dynamic_cast<luci::CircleMul *>(mul2->y());
169   EXPECT_NE(nullptr, mul1);
170
171   auto relu6 = dynamic_cast<luci::CircleRelu6 *>(mul1->x());
172   EXPECT_NE(nullptr, relu6);
173
174   auto mul_const = dynamic_cast<luci::CircleConst *>(mul1->y());
175   EXPECT_NE(nullptr, mul_const);
176   EXPECT_FLOAT_EQ(1. / 6., mul_const->at<loco::DataType::FLOAT32>(0));
177
178   auto add = dynamic_cast<luci::CircleAdd *>(relu6->features());
179   EXPECT_NE(nullptr, add);
180
181   auto input1 = dynamic_cast<luci::CircleInput *>(add->x());
182   EXPECT_NE(nullptr, input1);
183
184   auto add_const = dynamic_cast<luci::CircleConst *>(add->y());
185   EXPECT_NE(nullptr, add_const);
186   EXPECT_FLOAT_EQ(3., add_const->at<loco::DataType::FLOAT32>(0));
187 }
188
189 TEST_F(DecomposeHardSwishPass, check_last_node)
190 {
191   auto ret = _pass.run(&_hardswish_g._g);
192   EXPECT_TRUE(ret);
193
194   auto hardswish = dynamic_cast<luci::CircleHardSwish *>(_hardswish_g._output->from());
195   EXPECT_EQ(nullptr, hardswish);
196 }
197
198 TEST_F(DecomposeHardSwishPass, wrong_condition_NEG)
199 {
200   auto ret = _pass.run(&_hardswish_int32_g._g);
201   EXPECT_FALSE(ret);
202
203   auto hardswish = dynamic_cast<luci::CircleHardSwish *>(_hardswish_g._output->from());
204   EXPECT_NE(nullptr, hardswish);
205 }