2 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/Pass/SubstituteSplitVToSplitPass.h"
19 #include <luci/test/TestIOGraph.h>
21 #include <gtest/gtest.h>
26 using namespace luci::test;
33 // Reduce duplicate codes in ResolveCustomOpMatMulPass.cpp
35 luci::CircleConst *create_const_node(loco::Graph *g, const loco::DataType dtype,
36 const std::vector<uint32_t> &shape,
37 const std::vector<T> &values)
39 auto node = g->nodes()->create<luci::CircleConst>();
41 node->rank(shape.size());
44 for (uint32_t i = 0; i < shape.size(); ++i)
46 node->dim(i) = shape.at(i);
49 node->shape_status(luci::ShapeStatus::VALID);
51 #define INIT_VALUES(DT) \
53 node->size<DT>(size); \
54 for (uint32_t i = 0; i < values.size(); ++i) \
55 node->at<DT>(i) = values[i]; \
60 case loco::DataType::U8:
61 INIT_VALUES(loco::DataType::U8);
63 case loco::DataType::S16:
64 INIT_VALUES(loco::DataType::S16);
66 case loco::DataType::S32:
67 INIT_VALUES(loco::DataType::S32);
69 case loco::DataType::FLOAT32:
70 INIT_VALUES(loco::DataType::FLOAT32)
73 INTERNAL_EXN("create_const_node called with unsupported type");
79 * graph having SplitV operator
85 * [CircleSplitVOut] [CircleSplitVOut]
87 * [CircleOutput] [CircleOutput]
92 SplitVGraphlet() = default;
95 void init(loco::Graph *g)
97 const std::vector<int32_t> splits{16, 16};
98 auto size_splits = create_const_node(g, loco::DataType::S32, {2}, splits);
100 const std::vector<int32_t> dim{3};
101 auto split_dim = create_const_node(g, loco::DataType::S32, {1}, dim);
103 _sv = g->nodes()->create<luci::CircleSplitV>();
104 _sv->size_splits(size_splits);
105 _sv->split_dim(split_dim);
109 _svo1 = g->nodes()->create<luci::CircleSplitVOut>();
112 _svo1->name("SplitV0");
114 _svo2 = g->nodes()->create<luci::CircleSplitVOut>();
117 _svo2->name("SplitV1");
121 luci::CircleSplitV *split_v() { return _sv; }
122 luci::CircleSplitVOut *split_vo1() { return _svo1; }
123 luci::CircleSplitVOut *split_vo2() { return _svo2; }
126 luci::CircleSplitV *_sv = nullptr;
127 luci::CircleSplitVOut *_svo1 = nullptr;
128 luci::CircleSplitVOut *_svo2 = nullptr;
131 class SplitVGraph : public TestIsGraphlet<1>, public TestOsGraphlet<2>, public SplitVGraphlet
134 SplitVGraph() = default;
138 TestIsGraphlet<1>::init(g(), {{N, C, H, W}});
139 TestOsGraphlet<2>::init(g(), {{N, C, H / 2, W / 2}, {N, C, H / 2, W / 2}});
140 SplitVGraphlet::init(g());
142 split_v()->input(input(0));
144 output(0)->from(split_vo1());
145 output(1)->from(split_vo2());
149 class SubstituteSplitVToSplitPassTest : public ::testing::Test
153 luci::SubstituteSplitVToSplitPass pass;
159 * Optimized graph looks like below.
165 * [CircleSplitOut] [CircleSplitOut]
167 * [CircleOutput] [CircleOutput]
169 TEST_F(SubstituteSplitVToSplitPassTest, simple_test)
173 auto ret = pass.run(g.g());
174 EXPECT_EQ(true, ret);
176 auto so1 = dynamic_cast<luci::CircleSplitOut *>(g.output(0)->from());
177 EXPECT_NE(nullptr, so1);
179 auto so2 = dynamic_cast<luci::CircleSplitOut *>(g.output(1)->from());
180 EXPECT_NE(nullptr, so2);
182 EXPECT_EQ(so1->input(), so2->input());
184 auto s = dynamic_cast<luci::CircleSplit *>(so1->input());
185 EXPECT_NE(nullptr, s);
187 auto input = dynamic_cast<luci::CircleInput *>(s->input());
188 EXPECT_NE(nullptr, input);
191 TEST_F(SubstituteSplitVToSplitPassTest, wrong_condition_NEG)
195 g.split_v()->num_split(3); // Wrong num_split
196 auto ret = pass.run(g.g());
198 EXPECT_EQ(false, ret);