Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / SubstituteSplitVToSplitPass.test.cpp
1 /*
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/SubstituteSplitVToSplitPass.h"
18
19 #include "helpers/CreateCircleConst.h"
20
21 #include <luci/test/TestIOGraph.h>
22
23 #include <gtest/gtest.h>
24
25 namespace
26 {
27
28 using namespace luci::test;
29
30 const int N = 1;
31 const int C = 32;
32 const int H = 8;
33 const int W = 8;
34
35 /**
36  *  graph having SplitV operator
37  *
38  *                [CircleInput]
39  *                      |
40  *                [CircleSplitV]
41  *                     /  \
42  *      [CircleSplitVOut] [CircleSplitVOut]
43  *             |                   |
44  *       [CircleOutput]     [CircleOutput]
45  */
46 class SplitVGraphlet
47 {
48 public:
49   SplitVGraphlet() = default;
50
51 public:
52   void init(loco::Graph *g)
53   {
54     const std::vector<int32_t> splits{16, 16};
55     auto size_splits = luci::create_const_node(g, loco::DataType::S32, {2}, splits);
56
57     const std::vector<int32_t> dim{3};
58     auto split_dim = luci::create_const_node(g, loco::DataType::S32, {1}, dim);
59
60     _sv = g->nodes()->create<luci::CircleSplitV>();
61     _sv->size_splits(size_splits);
62     _sv->split_dim(split_dim);
63     _sv->num_split(2);
64     _sv->name("SplitV");
65
66     _svo1 = g->nodes()->create<luci::CircleSplitVOut>();
67     _svo1->input(_sv);
68     _svo1->index(0);
69     _svo1->name("SplitV0");
70
71     _svo2 = g->nodes()->create<luci::CircleSplitVOut>();
72     _svo2->input(_sv);
73     _svo2->index(1);
74     _svo2->name("SplitV1");
75   }
76
77 public:
78   luci::CircleSplitV *split_v() { return _sv; }
79   luci::CircleSplitVOut *split_vo1() { return _svo1; }
80   luci::CircleSplitVOut *split_vo2() { return _svo2; }
81
82 protected:
83   luci::CircleSplitV *_sv = nullptr;
84   luci::CircleSplitVOut *_svo1 = nullptr;
85   luci::CircleSplitVOut *_svo2 = nullptr;
86 };
87
88 class SplitVGraph : public TestIsGraphlet<1>, public TestOsGraphlet<2>, public SplitVGraphlet
89 {
90 public:
91   SplitVGraph() = default;
92
93   void init(void)
94   {
95     TestIsGraphlet<1>::init(g(), {{N, C, H, W}});
96     TestOsGraphlet<2>::init(g(), {{N, C, H / 2, W / 2}, {N, C, H / 2, W / 2}});
97     SplitVGraphlet::init(g());
98
99     split_v()->input(input(0));
100
101     output(0)->from(split_vo1());
102     output(1)->from(split_vo2());
103   }
104 };
105
106 class SubstituteSplitVToSplitPassTest : public ::testing::Test
107 {
108 public:
109   SplitVGraph g;
110   luci::SubstituteSplitVToSplitPass pass;
111 };
112
113 } // namespace
114
115 /**
116  *  Optimized graph looks like below.
117  *
118  *                [CircleInput]
119  *                      |
120  *                [CircleSplit]
121  *                     /  \
122  *      [CircleSplitOut] [CircleSplitOut]
123  *             |                 |
124  *       [CircleOutput]   [CircleOutput]
125  */
126 TEST_F(SubstituteSplitVToSplitPassTest, simple_test)
127 {
128   g.init();
129
130   auto ret = pass.run(g.g());
131   EXPECT_EQ(true, ret);
132
133   auto so1 = dynamic_cast<luci::CircleSplitOut *>(g.output(0)->from());
134   EXPECT_NE(nullptr, so1);
135
136   auto so2 = dynamic_cast<luci::CircleSplitOut *>(g.output(1)->from());
137   EXPECT_NE(nullptr, so2);
138
139   EXPECT_EQ(so1->input(), so2->input());
140
141   auto s = dynamic_cast<luci::CircleSplit *>(so1->input());
142   EXPECT_NE(nullptr, s);
143
144   auto input = dynamic_cast<luci::CircleInput *>(s->input());
145   EXPECT_NE(nullptr, input);
146 }
147
148 TEST_F(SubstituteSplitVToSplitPassTest, wrong_condition_NEG)
149 {
150   g.init();
151
152   g.split_v()->num_split(3); // Wrong num_split
153   auto ret = pass.run(g.g());
154
155   EXPECT_EQ(false, ret);
156 }