Imported Upstream version 1.12.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / SubstitutePackToReshapePass.test.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "luci/Pass/SubstitutePackToReshapePass.h"
17
18 #include <luci/IR/CircleNodes.h>
19
20 #include <gtest/gtest.h>
21
22 namespace
23 {
24
25 /**
26  *           BEFORE
27  *             |
28  *        [CircleNode]
29  *             |
30  *        [CirclePack]
31  *             |
32  *        [CircleNode]
33  *             |
34  *
35  *           AFTER
36  *      |
37  * [CircleNode]  [CircleConst]
38  *       \             /
39  *       [CircleReshape]
40  *             |
41  *        [CircleNode]
42  *             |
43  *
44  */
45 void create_substitute_pack_to_reshape(loco::Graph *g, const std::initializer_list<uint32_t> shape,
46                                        int32_t axis)
47 {
48   assert(g);
49
50   // Input Create.
51   auto input = g->nodes()->create<luci::CircleInput>();
52   auto graph_input = g->inputs()->create();
53   input->index(graph_input->index());
54   input->shape_status(luci::ShapeStatus::VALID);
55   input->rank(shape.size());
56   input->shape(shape);
57
58   // Pack Node create.
59   auto pack = g->nodes()->create<luci::CirclePack>(1);
60   pack->values(0, input);
61   pack->axis(axis);
62
63   // Output Connect.
64   auto output = g->nodes()->create<luci::CircleOutput>();
65   output->from(pack);
66   auto graph_output = g->outputs()->create();
67   output->index(graph_output->index());
68
69   return;
70 }
71
72 } // namespace
73
74 TEST(SubstitutePackToReshapePass, simple_case)
75 {
76   auto graph = loco::make_graph();
77   create_substitute_pack_to_reshape(graph.get(), {1, 2, 3, 4}, 0);
78   luci::SubstitutePackToReshapePass pass;
79   while (pass.run(graph.get()))
80     ;
81   luci::CircleReshape *reshape_node = nullptr;
82   luci::CirclePack *pack_node = nullptr;
83   for (auto node : loco::active_nodes(loco::output_nodes(graph.get())))
84   {
85     if (auto reshape = dynamic_cast<luci::CircleReshape *>(node))
86       reshape_node = reshape;
87     else if (auto pack = dynamic_cast<luci::CirclePack *>(node))
88       pack_node = pack;
89   }
90   ASSERT_NE(nullptr, reshape_node);
91   ASSERT_EQ(nullptr, pack_node);
92   auto new_shape = loco::must_cast<luci::CircleConst *>(reshape_node->shape());
93   ASSERT_EQ(1, new_shape->at<loco::DataType::S32>(0));
94   ASSERT_EQ(1, new_shape->at<loco::DataType::S32>(1));
95   ASSERT_EQ(2, new_shape->at<loco::DataType::S32>(2));
96   ASSERT_EQ(3, new_shape->at<loco::DataType::S32>(3));
97   ASSERT_EQ(4, new_shape->at<loco::DataType::S32>(4));
98 }
99
100 TEST(SubstitutePackToReshapePass, simple_case_neg_axis)
101 {
102   auto graph = loco::make_graph();
103   create_substitute_pack_to_reshape(graph.get(), {1, 2, 3, 4}, -1);
104   luci::SubstitutePackToReshapePass pass;
105   while (pass.run(graph.get()))
106     ;
107   luci::CircleReshape *reshape_node = nullptr;
108   luci::CirclePack *pack_node = nullptr;
109   for (auto node : loco::active_nodes(loco::output_nodes(graph.get())))
110   {
111     if (auto reshape = dynamic_cast<luci::CircleReshape *>(node))
112       reshape_node = reshape;
113     else if (auto pack = dynamic_cast<luci::CirclePack *>(node))
114       pack_node = pack;
115   }
116   ASSERT_NE(nullptr, reshape_node);
117   ASSERT_EQ(nullptr, pack_node);
118   auto new_shape = loco::must_cast<luci::CircleConst *>(reshape_node->shape());
119   ASSERT_EQ(1, new_shape->at<loco::DataType::S32>(0));
120   ASSERT_EQ(2, new_shape->at<loco::DataType::S32>(1));
121   ASSERT_EQ(3, new_shape->at<loco::DataType::S32>(2));
122   ASSERT_EQ(4, new_shape->at<loco::DataType::S32>(3));
123   ASSERT_EQ(1, new_shape->at<loco::DataType::S32>(4));
124 }