2 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/Pass/DynamicBatchToSingleBatchPass.h"
21 #include <luci/IR/CircleNodes.h>
23 #include <gtest/gtest.h>
28 std::unique_ptr<loco::TensorShape> make_tshape(std::initializer_list<uint32_t> dims)
30 auto tensor_shape = std::make_unique<loco::TensorShape>();
32 tensor_shape->rank(dims.size());
34 for (auto it = dims.begin(); it != dims.end(); ++it)
36 tensor_shape->dim(axis++) = *it;
40 return std::move(tensor_shape);
45 TEST(DynamicBatchToSingleBatchPassTest, simple)
47 luci::DynamicBatchToSingleBatchPass pass;
49 auto g = loco::make_graph();
51 auto graph_input = g->inputs()->create();
53 auto tensor_shape = make_tshape({1, 5, 5, 3});
54 tensor_shape->dim(0).unset();
55 graph_input->shape(std::move(tensor_shape));
58 // Create nodes to make relu traversed first
59 auto input = g->nodes()->create<luci::CircleInput>();
62 input->shape({1, 5, 5, 3});
63 input->dim(0).unset();
66 EXPECT_FALSE(graph_input->shape()->dim(0).known());
67 EXPECT_FALSE(input->dim(0).known());
69 EXPECT_TRUE(pass.run(g.get()));
71 // Check input is knwon
72 EXPECT_TRUE(graph_input->shape()->dim(0).known());
73 EXPECT_EQ(1, graph_input->shape()->dim(0));
74 EXPECT_TRUE(input->dim(0).known());
75 EXPECT_EQ(1, input->dim(0));
78 TEST(DynamicBatchToSingleBatchPassTest, simple_NEG)
80 luci::DynamicBatchToSingleBatchPass pass;
82 auto g = loco::make_graph();
84 auto graph_input = g->inputs()->create();
86 graph_input->shape({1, 5, 5, 3});
89 // Create nodes to make relu traversed first
90 auto input = g->nodes()->create<luci::CircleInput>();
93 input->shape({1, 5, 5, 3});
96 EXPECT_FALSE(pass.run(g.get()));
99 // Remove this test if we support rank 1 in this pass
100 TEST(DynamicBatchToSingleBatchPassTest, rank1_NEG)
102 luci::DynamicBatchToSingleBatchPass pass;
104 auto g = loco::make_graph();
106 auto graph_input = g->inputs()->create();
108 auto tensor_shape = make_tshape({1});
109 tensor_shape->dim(0).unset();
110 graph_input->shape(std::move(tensor_shape));
113 // Create nodes to make relu traversed first
114 auto input = g->nodes()->create<luci::CircleInput>();
118 input->dim(0).unset();
121 EXPECT_FALSE(graph_input->shape()->dim(0).known());
122 EXPECT_FALSE(input->dim(0).known());
124 // Rank 1 is unsupported for now
125 EXPECT_ANY_THROW(pass.run(g.get()));