2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
16 #include "luci/Pass/SubstitutePackToReshapePass.h"
18 #include <luci/IR/CircleNodes.h>
20 #include <gtest/gtest.h>
37 * [CircleNode] [CircleConst]
// Builds a minimal test graph: CircleInput -> CirclePack (single value) -> CircleOutput.
// `shape` provides the input's dimensions; shape.size() is used as the input rank.
// NOTE(review): the remaining parameter(s) are not visible in this listing; callers pass 0 and -1,
// presumably the Pack axis -- confirm against the full source.
// NOTE(review): this listing has gaps (see the jumps in the embedded line numbers, e.g. 45 -> 51,
// 55 -> 59, 60 -> 64, 64 -> 66); in the visible lines the output is never connected to the pack
// and no axis is set. Restore the missing lines from the original file.
45 void create_substitute_pack_to_reshape(loco::Graph *g, const std::initializer_list<uint32_t> shape,
// Input node: registered as a graph input and marked as having a valid (known) shape.
51 auto input = g->nodes()->create<luci::CircleInput>();
52 auto graph_input = g->inputs()->create();
53 input->index(graph_input->index());
54 input->shape_status(luci::ShapeStatus::VALID);
55 input->rank(shape.size());
// Pack node with exactly one value. Per the tests' expectations, packing a single tensor only
// inserts a unit dimension, which is why the pass can replace it with a Reshape.
59 auto pack = g->nodes()->create<luci::CirclePack>(1);
60 pack->values(0, input);
// Output node, registered as a graph output.
64 auto output = g->nodes()->create<luci::CircleOutput>();
66 auto graph_output = g->outputs()->create();
67 output->index(graph_output->index());
// Packs a {1, 2, 3, 4} input with axis argument 0 and checks that the pass replaced the Pack
// with a Reshape whose constant shape is {1, 1, 2, 3, 4}.
74 TEST(SubstitutePackToReshapePass, simple_case)
76 auto graph = loco::make_graph();
77 create_substitute_pack_to_reshape(graph.get(), {1, 2, 3, 4}, 0);
78 luci::SubstitutePackToReshapePass pass;
// Re-run the pass for as long as run() returns true (presumably: while it still changes the graph).
// NOTE(review): the loop's body/terminator (original line 80) is missing from this listing.
79 while (pass.run(graph.get()))
81 luci::CircleReshape *reshape_node = nullptr;
82 luci::CirclePack *pack_node = nullptr;
// Scan the nodes returned by loco::active_nodes(loco::output_nodes(...)) and remember any
// Reshape / Pack found there.
83 for (auto node : loco::active_nodes(loco::output_nodes(graph.get())))
85 if (auto reshape = dynamic_cast<luci::CircleReshape *>(node))
86 reshape_node = reshape;
// NOTE(review): the body of this Pack branch (original lines 88-89, presumably `pack_node = pack;`)
// is not visible; as listed, pack_node is never assigned, so ASSERT_EQ(nullptr, pack_node) cannot fail.
87 else if (auto pack = dynamic_cast<luci::CirclePack *>(node))
90 ASSERT_NE(nullptr, reshape_node);
91 ASSERT_EQ(nullptr, pack_node);
// The reshape's shape input must be a CircleConst; check its S32 values element by element.
// NOTE(review): only elements 0..4 are checked; the constant's size is not asserted.
92 auto new_shape = loco::must_cast<luci::CircleConst *>(reshape_node->shape());
93 ASSERT_EQ(1, new_shape->at<loco::DataType::S32>(0));
94 ASSERT_EQ(1, new_shape->at<loco::DataType::S32>(1));
95 ASSERT_EQ(2, new_shape->at<loco::DataType::S32>(2));
96 ASSERT_EQ(3, new_shape->at<loco::DataType::S32>(3));
97 ASSERT_EQ(4, new_shape->at<loco::DataType::S32>(4));
// Same scenario as simple_case but with axis argument -1: the expected Reshape target is
// {1, 2, 3, 4, 1}, i.e. the new unit dimension is appended after the last input dimension.
100 TEST(SubstitutePackToReshapePass, simple_case_neg_axis)
102 auto graph = loco::make_graph();
103 create_substitute_pack_to_reshape(graph.get(), {1, 2, 3, 4}, -1);
104 luci::SubstitutePackToReshapePass pass;
// Re-run the pass for as long as run() returns true (presumably: while it still changes the graph).
// NOTE(review): the loop's body/terminator (original line 106) is missing from this listing.
105 while (pass.run(graph.get()))
107 luci::CircleReshape *reshape_node = nullptr;
108 luci::CirclePack *pack_node = nullptr;
// Scan the nodes returned by loco::active_nodes(loco::output_nodes(...)) and remember any
// Reshape / Pack found there.
109 for (auto node : loco::active_nodes(loco::output_nodes(graph.get())))
111 if (auto reshape = dynamic_cast<luci::CircleReshape *>(node))
112 reshape_node = reshape;
// NOTE(review): the body of this Pack branch (original lines 114-115, presumably `pack_node = pack;`)
// is not visible; as listed, pack_node is never assigned, so ASSERT_EQ(nullptr, pack_node) cannot fail.
113 else if (auto pack = dynamic_cast<luci::CirclePack *>(node))
116 ASSERT_NE(nullptr, reshape_node);
117 ASSERT_EQ(nullptr, pack_node);
// The reshape's shape input must be a CircleConst; check its S32 values element by element.
// NOTE(review): only elements 0..4 are checked; the constant's size is not asserted.
118 auto new_shape = loco::must_cast<luci::CircleConst *>(reshape_node->shape());
119 ASSERT_EQ(1, new_shape->at<loco::DataType::S32>(0));
120 ASSERT_EQ(2, new_shape->at<loco::DataType::S32>(1));
121 ASSERT_EQ(3, new_shape->at<loco::DataType::S32>(2));
122 ASSERT_EQ(4, new_shape->at<loco::DataType::S32>(3));
123 ASSERT_EQ(1, new_shape->at<loco::DataType::S32>(4));