6e30103f9681bc188c83e96924a8ddf0b13bf1c5
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / SubstituteSplitVToSplitPass.test.cpp
1 /*
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/SubstituteSplitVToSplitPass.h"
18
19 #include <luci/test/TestIOGraph.h>
20
21 #include <gtest/gtest.h>
22
23 namespace
24 {
25
26 using namespace luci::test;
27
28 const int N = 1;
29 const int C = 32;
30 const int H = 8;
31 const int W = 8;
32
33 // Reduce duplicate codes in ResolveCustomOpMatMulPass.cpp
34 template <typename T>
35 luci::CircleConst *create_const_node(loco::Graph *g, const loco::DataType dtype,
36                                      const std::vector<uint32_t> &shape,
37                                      const std::vector<T> &values)
38 {
39   auto node = g->nodes()->create<luci::CircleConst>();
40   node->dtype(dtype);
41   node->rank(shape.size());
42
43   uint32_t size = 1;
44   for (uint32_t i = 0; i < shape.size(); ++i)
45   {
46     node->dim(i) = shape.at(i);
47     size *= shape.at(i);
48   }
49   node->shape_status(luci::ShapeStatus::VALID);
50
51 #define INIT_VALUES(DT)                          \
52   {                                              \
53     node->size<DT>(size);                        \
54     for (uint32_t i = 0; i < values.size(); ++i) \
55       node->at<DT>(i) = values[i];               \
56   }
57
58   switch (dtype)
59   {
60     case loco::DataType::U8:
61       INIT_VALUES(loco::DataType::U8);
62       break;
63     case loco::DataType::S16:
64       INIT_VALUES(loco::DataType::S16);
65       break;
66     case loco::DataType::S32:
67       INIT_VALUES(loco::DataType::S32);
68       break;
69     case loco::DataType::FLOAT32:
70       INIT_VALUES(loco::DataType::FLOAT32)
71       break;
72     default:
73       INTERNAL_EXN("create_const_node called with unsupported type");
74       break;
75   }
76   return node;
77 }
78 /**
79  *  graph having SplitV operator
80  *
81  *                [CircleInput]
82  *                      |
83  *                [CircleSplitV]
84  *                     /  \
85  *      [CircleSplitVOut] [CircleSplitVOut]
86  *             |                   |
87  *       [CircleOutput]     [CircleOutput]
88  */
89 class SplitVGraphlet
90 {
91 public:
92   SplitVGraphlet() = default;
93
94 public:
95   void init(loco::Graph *g)
96   {
97     const std::vector<int32_t> splits{16, 16};
98     auto size_splits = create_const_node(g, loco::DataType::S32, {2}, splits);
99
100     const std::vector<int32_t> dim{3};
101     auto split_dim = create_const_node(g, loco::DataType::S32, {1}, dim);
102
103     _sv = g->nodes()->create<luci::CircleSplitV>();
104     _sv->size_splits(size_splits);
105     _sv->split_dim(split_dim);
106     _sv->num_split(2);
107     _sv->name("SplitV");
108
109     _svo1 = g->nodes()->create<luci::CircleSplitVOut>();
110     _svo1->input(_sv);
111     _svo1->index(0);
112     _svo1->name("SplitV0");
113
114     _svo2 = g->nodes()->create<luci::CircleSplitVOut>();
115     _svo2->input(_sv);
116     _svo2->index(1);
117     _svo2->name("SplitV1");
118   }
119
120 public:
121   luci::CircleSplitV *split_v() { return _sv; }
122   luci::CircleSplitVOut *split_vo1() { return _svo1; }
123   luci::CircleSplitVOut *split_vo2() { return _svo2; }
124
125 protected:
126   luci::CircleSplitV *_sv = nullptr;
127   luci::CircleSplitVOut *_svo1 = nullptr;
128   luci::CircleSplitVOut *_svo2 = nullptr;
129 };
130
131 class SplitVGraph : public TestIsGraphlet<1>, public TestOsGraphlet<2>, public SplitVGraphlet
132 {
133 public:
134   SplitVGraph() = default;
135
136   void init(void)
137   {
138     TestIsGraphlet<1>::init(g(), {{N, C, H, W}});
139     TestOsGraphlet<2>::init(g(), {{N, C, H / 2, W / 2}, {N, C, H / 2, W / 2}});
140     SplitVGraphlet::init(g());
141
142     split_v()->input(input(0));
143
144     output(0)->from(split_vo1());
145     output(1)->from(split_vo2());
146   }
147 };
148
149 class SubstituteSplitVToSplitPassTest : public ::testing::Test
150 {
151 public:
152   SplitVGraph g;
153   luci::SubstituteSplitVToSplitPass pass;
154 };
155
156 } // namespace
157
158 /**
159  *  Optimized graph looks like below.
160  *
161  *                [CircleInput]
162  *                      |
163  *                [CircleSplit]
164  *                     /  \
165  *      [CircleSplitOut] [CircleSplitOut]
166  *             |                 |
167  *       [CircleOutput]   [CircleOutput]
168  */
169 TEST_F(SubstituteSplitVToSplitPassTest, simple_test)
170 {
171   g.init();
172
173   auto ret = pass.run(g.g());
174   EXPECT_EQ(true, ret);
175
176   auto so1 = dynamic_cast<luci::CircleSplitOut *>(g.output(0)->from());
177   EXPECT_NE(nullptr, so1);
178
179   auto so2 = dynamic_cast<luci::CircleSplitOut *>(g.output(1)->from());
180   EXPECT_NE(nullptr, so2);
181
182   EXPECT_EQ(so1->input(), so2->input());
183
184   auto s = dynamic_cast<luci::CircleSplit *>(so1->input());
185   EXPECT_NE(nullptr, s);
186
187   auto input = dynamic_cast<luci::CircleInput *>(s->input());
188   EXPECT_NE(nullptr, input);
189 }
190
191 TEST_F(SubstituteSplitVToSplitPassTest, wrong_condition_NEG)
192 {
193   g.init();
194
195   g.split_v()->num_split(3); // Wrong num_split
196   auto ret = pass.run(g.g());
197
198   EXPECT_EQ(false, ret);
199 }