Imported Upstream version 1.12.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / RemoveRedundantTranspose.test.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "luci/Pass/RemoveRedundantTransposePass.h"
17
18 #include <luci/IR/CircleNodes.h>
19
20 #include <vector>
21
22 #include <gtest/gtest.h>
23
24 namespace
25 {
26
27 void setValue(luci::CircleConst *node, const std::vector<int> &v)
28 {
29   node->dtype(loco::DataType::S32);
30   node->size<loco::DataType::S32>(v.size());
31   node->rank(1);
32   node->dim(0).set(v.size());
33   for (int i = 0; i < v.size(); ++i)
34   {
35     node->at<loco::DataType::S32>(i) = v[i];
36   }
37 }
38
39 /**
40  *  Type1
41  *  BEFORE
42  *         |
43  *   [CircleNode]     [CircleConst]
44  *           \              /
45  *           [CircleTranspose]  [CircleConst]
46  *                   \              /
47  *                   [CircleTranspose]
48  *                           |
49  *
50  *  AFTER
51  *         |
52  *   [CircleNode]
53  *         |   Remove Both
54  *
55  * --------------------------------------------
56  *
57  *  Type2
58  *  BEFORE
59  *         |
60  *   [CircleNode]     [CircleConst]
61  *           \              /
62  *           [CircleTranspose]  [CircleConst]
63  *                   \               /
64  *                   [CircleTranspose]
65  *                           |
66  *
67  *  AFTER
68  *          |                 |
69  *    [CircleNode]      [CircleConst]
70  *           \               /
71  *           [CircleTranspose]
72  *                   |
73  *
74  */
75 void create_redundunt_transpose(loco::Graph *g, const std::vector<int32_t> &perm1,
76                                 const std::vector<int32_t> &perm2)
77 {
78   assert(g);
79
80   auto input = g->nodes()->create<luci::CircleInput>();
81   auto graph_input = g->inputs()->create();
82   input->index(graph_input->index());
83
84   // Create perm1
85   auto perm1_node = g->nodes()->create<luci::CircleConst>();
86   setValue(perm1_node, perm1);
87
88   auto transpose1 = g->nodes()->create<luci::CircleTranspose>();
89   transpose1->dtype(loco::DataType::FLOAT32);
90   transpose1->a(input);
91   transpose1->perm(perm1_node);
92
93   // Create perm2
94   auto perm2_node = g->nodes()->create<luci::CircleConst>();
95   setValue(perm2_node, perm2);
96
97   auto transpose2 = g->nodes()->create<luci::CircleTranspose>();
98   transpose2->dtype(loco::DataType::FLOAT32);
99   transpose2->a(transpose1);
100   transpose2->perm(perm2_node);
101
102   // Output
103   auto output = g->nodes()->create<luci::CircleOutput>();
104   output->from(transpose2);
105   auto graph_output = g->outputs()->create();
106   output->index(graph_output->index());
107 }
108
109 } // namespace
110
111 TEST(RemoveRedundantTransposePass, remove_consecutive_transpose_function_type1)
112 {
113   auto graph = loco::make_graph();
114   create_redundunt_transpose(graph.get(), {1, 0, 2, 3}, {1, 0, 2, 3});
115
116   luci::RemoveRedundantTransposePass pass;
117   while (pass.run(graph.get()))
118     ;
119   luci::CircleTranspose *transpose_node = nullptr;
120   for (auto node : loco::active_nodes(loco::output_nodes(graph.get())))
121   {
122     auto trans = dynamic_cast<luci::CircleTranspose *>(node);
123     if (not trans)
124       continue;
125     transpose_node = trans;
126     break;
127   }
128   // No transpose node is in graph.
129   ASSERT_EQ(nullptr, transpose_node);
130 }
131
132 TEST(RemoveRedundantTransposePass, remove_consecutive_transpose_function_type2)
133 {
134   auto graph = loco::make_graph();
135   create_redundunt_transpose(graph.get(), {0, 1, 3, 2}, {1, 0, 2, 3});
136
137   luci::RemoveRedundantTransposePass pass;
138   while (pass.run(graph.get()))
139     ;
140   luci::CircleTranspose *transpose_node = nullptr;
141   for (auto node : loco::active_nodes(loco::output_nodes(graph.get())))
142   {
143     auto trans = dynamic_cast<luci::CircleTranspose *>(node);
144     if (not trans)
145       continue;
146     transpose_node = trans;
147     break;
148   }
149   // Just one transpose node, with updated perm constant.
150   ASSERT_NE(nullptr, transpose_node);
151   auto perm = loco::must_cast<luci::CircleConst *>(transpose_node->perm());
152   ASSERT_EQ(1, perm->at<loco::DataType::S32>(0));
153   ASSERT_EQ(0, perm->at<loco::DataType::S32>(1));
154   ASSERT_EQ(3, perm->at<loco::DataType::S32>(2));
155   ASSERT_EQ(2, perm->at<loco::DataType::S32>(3));
156 }