2 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/Pass/SubstituteSplitVToSplitPass.h"
19 #include "helpers/CreateCircleConst.h"
21 #include <luci/test/TestIOGraph.h>
23 #include <gtest/gtest.h>
28 using namespace luci::test;
36 * graph having SplitV operator
42 * [CircleSplitVOut] [CircleSplitVOut]
44 * [CircleOutput] [CircleOutput]
// SplitVGraphlet: builds one CircleSplitV node and two CircleSplitVOut nodes
// (named "SplitV0" and "SplitV1") in a graph, and exposes them via accessors.
49 SplitVGraphlet() = default;
// Creates the SplitV node, its two constant operands, and its two output
// nodes inside graph g.
52 void init(loco::Graph *g)
// size_splits operand: S32 constant of shape {2} holding {16, 16}.
54 const std::vector<int32_t> splits{16, 16};
55 auto size_splits = luci::create_const_node(g, loco::DataType::S32, {2}, splits);
// split_dim operand: S32 constant of shape {1} holding {3}.
57 const std::vector<int32_t> dim{3};
58 auto split_dim = luci::create_const_node(g, loco::DataType::S32, {1}, dim);
// The SplitV node itself, wired to the constant operands above.
60 _sv = g->nodes()->create<luci::CircleSplitV>();
61 _sv->size_splits(size_splits);
62 _sv->split_dim(split_dim);
// First output node of the SplitV.
66 _svo1 = g->nodes()->create<luci::CircleSplitVOut>();
69 _svo1->name("SplitV0");
// Second output node of the SplitV.
71 _svo2 = g->nodes()->create<luci::CircleSplitVOut>();
74 _svo2->name("SplitV1");
// Accessors for the nodes created by init(); each returns nullptr until
// init() has been called.
78 luci::CircleSplitV *split_v() { return _sv; }
79 luci::CircleSplitVOut *split_vo1() { return _svo1; }
80 luci::CircleSplitVOut *split_vo2() { return _svo2; }
// Node handles populated by init(); default to nullptr.
83 luci::CircleSplitV *_sv = nullptr;
84 luci::CircleSplitVOut *_svo1 = nullptr;
85 luci::CircleSplitVOut *_svo2 = nullptr;
// Complete test graph: one graph input and two graph outputs, with the
// SplitV subgraph from SplitVGraphlet placed between them.
88 class SplitVGraph : public TestIsGraphlet<1>, public TestOsGraphlet<2>, public SplitVGraphlet
91 SplitVGraph() = default;
// Declares the input as {N, C, H, W} and each output as {N, C, H / 2, W / 2}.
// NOTE(review): split_dim is a single axis, so a split would normally change
// only that axis's extent, yet both H and W are halved here. Presumably the
// declared output shapes do not affect the pass under test; confirm.
95 TestIsGraphlet<1>::init(g(), {{N, C, H, W}});
96 TestOsGraphlet<2>::init(g(), {{N, C, H / 2, W / 2}, {N, C, H / 2, W / 2}});
97 SplitVGraphlet::init(g());
// Feed the single graph input into the SplitV node.
99 split_v()->input(input(0));
// Route each SplitVOut node to its own graph output.
101 output(0)->from(split_vo1());
102 output(1)->from(split_vo2());
// gtest fixture for the TEST_F cases below; holds the
// SubstituteSplitVToSplitPass instance that each test runs.
106 class SubstituteSplitVToSplitPassTest : public ::testing::Test
110 luci::SubstituteSplitVToSplitPass pass;
116 * Optimized graph looks like below.
122 * [CircleSplitOut] [CircleSplitOut]
124 * [CircleOutput] [CircleOutput]
// Runs the pass on the SplitV test graph and checks the rewritten structure:
// each graph output is produced by a CircleSplitOut, both share one
// CircleSplit producer, and that CircleSplit consumes the CircleInput.
126 TEST_F(SubstituteSplitVToSplitPassTest, simple_test)
// The pass is expected to return true on this graph.
130 auto ret = pass.run(g.g());
131 EXPECT_EQ(true, ret);
// Fatal ASSERT_NE (not EXPECT_NE) on the casts below: so1, so2 and s are
// dereferenced afterwards, so a null result must abort the test with a
// failure instead of crashing the test binary.
133 auto so1 = dynamic_cast<luci::CircleSplitOut *>(g.output(0)->from());
134 ASSERT_NE(nullptr, so1);
136 auto so2 = dynamic_cast<luci::CircleSplitOut *>(g.output(1)->from());
137 ASSERT_NE(nullptr, so2);
// Both outputs must be fed by the same node.
139 EXPECT_EQ(so1->input(), so2->input());
// That shared node must be a CircleSplit...
141 auto s = dynamic_cast<luci::CircleSplit *>(so1->input());
142 ASSERT_NE(nullptr, s);
// ...whose input is the original graph input.
144 auto input = dynamic_cast<luci::CircleInput *>(s->input());
145 EXPECT_NE(nullptr, input);
// Negative case: num_split is overwritten with 3 before running the pass.
// This is presumably inconsistent with the two-element size_splits constant,
// and the pass is expected to return false.
148 TEST_F(SubstituteSplitVToSplitPassTest, wrong_condition_NEG)
152 g.split_v()->num_split(3); // Wrong num_split
153 auto ret = pass.run(g.g());
155 EXPECT_EQ(false, ret);