/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "luci/Pass/RemoveRedundantTransposePass.h"

#include <luci/IR/CircleNodes.h>

#include <cstdint>
#include <vector>

#include <gtest/gtest.h>
27 void setValue(luci::CircleConst *node, const std::vector<int> &v)
29 node->dtype(loco::DataType::S32);
30 node->size<loco::DataType::S32>(v.size());
32 node->dim(0).set(v.size());
33 for (int i = 0; i < v.size(); ++i)
35 node->at<loco::DataType::S32>(i) = v[i];
43 * [CircleNode] [CircleConst]
45 * [CircleTranspose] [CircleConst]
55 * --------------------------------------------
60 * [CircleNode] [CircleConst]
62 * [CircleTranspose] [CircleConst]
69 * [CircleNode] [CircleConst]
75 void create_redundunt_transpose(loco::Graph *g, const std::vector<int32_t> &perm1,
76 const std::vector<int32_t> &perm2)
80 auto input = g->nodes()->create<luci::CircleInput>();
81 auto graph_input = g->inputs()->create();
82 input->index(graph_input->index());
85 auto perm1_node = g->nodes()->create<luci::CircleConst>();
86 setValue(perm1_node, perm1);
88 auto transpose1 = g->nodes()->create<luci::CircleTranspose>();
89 transpose1->dtype(loco::DataType::FLOAT32);
91 transpose1->perm(perm1_node);
94 auto perm2_node = g->nodes()->create<luci::CircleConst>();
95 setValue(perm2_node, perm2);
97 auto transpose2 = g->nodes()->create<luci::CircleTranspose>();
98 transpose2->dtype(loco::DataType::FLOAT32);
99 transpose2->a(transpose1);
100 transpose2->perm(perm2_node);
103 auto output = g->nodes()->create<luci::CircleOutput>();
104 output->from(transpose2);
105 auto graph_output = g->outputs()->create();
106 output->index(graph_output->index());
111 TEST(RemoveRedundantTransposePass, remove_consecutive_transpose_function_type1)
113 auto graph = loco::make_graph();
114 create_redundunt_transpose(graph.get(), {1, 0, 2, 3}, {1, 0, 2, 3});
116 luci::RemoveRedundantTransposePass pass;
117 while (pass.run(graph.get()))
119 luci::CircleTranspose *transpose_node = nullptr;
120 for (auto node : loco::active_nodes(loco::output_nodes(graph.get())))
122 auto trans = dynamic_cast<luci::CircleTranspose *>(node);
125 transpose_node = trans;
128 // No transpose node is in graph.
129 ASSERT_EQ(nullptr, transpose_node);
132 TEST(RemoveRedundantTransposePass, remove_consecutive_transpose_function_type2)
134 auto graph = loco::make_graph();
135 create_redundunt_transpose(graph.get(), {0, 1, 3, 2}, {1, 0, 2, 3});
137 luci::RemoveRedundantTransposePass pass;
138 while (pass.run(graph.get()))
140 luci::CircleTranspose *transpose_node = nullptr;
141 for (auto node : loco::active_nodes(loco::output_nodes(graph.get())))
143 auto trans = dynamic_cast<luci::CircleTranspose *>(node);
146 transpose_node = trans;
149 // Just one transpose node, with updated perm constant.
150 ASSERT_NE(nullptr, transpose_node);
151 auto perm = loco::must_cast<luci::CircleConst *>(transpose_node->perm());
152 ASSERT_EQ(1, perm->at<loco::DataType::S32>(0));
153 ASSERT_EQ(0, perm->at<loco::DataType::S32>(1));
154 ASSERT_EQ(3, perm->at<loco::DataType::S32>(2));
155 ASSERT_EQ(2, perm->at<loco::DataType::S32>(3));