/*
 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 #include <gtest/gtest.h>
19 #include "ir/Layout.h"
20 #include "util/ShapeInference.h"
22 using namespace onert::ir;
24 TEST(ShapeInference, Elementwise)
26 Shape lhs_shape{1, 299, 299, 3};
28 auto infered_out_shape = onert::shape_inference::inferEltwiseShape(lhs_shape, rhs_shape);
30 ASSERT_EQ(infered_out_shape.rank(), 4);
31 ASSERT_EQ(infered_out_shape.dim(0), 1);
32 ASSERT_EQ(infered_out_shape.dim(1), 299);
33 ASSERT_EQ(infered_out_shape.dim(2), 299);
34 ASSERT_EQ(infered_out_shape.dim(3), 3);
37 TEST(ShapeInference, IncorrectElementwise)
39 Shape lhs_shape{1, 299, 299, 3};
40 Shape rhs_shape{5, 3};
41 ASSERT_THROW(onert::shape_inference::inferEltwiseShape(lhs_shape, rhs_shape), std::runtime_error);
44 TEST(ShapeInference, Pool2DNodeSame)
46 Shape in_shape{10, 6, 12, 20};
48 Padding padding{PaddingType::SAME};
50 operation::Pool2D::Param avg_pool_param{
51 operation::Pool2D::PoolType::AVG, 3, 6, stride, padding, Activation::NONE};
52 auto infered_out_shape = onert::shape_inference::inferPoolShape(in_shape, avg_pool_param);
54 ASSERT_EQ(infered_out_shape.rank(), 4);
55 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
56 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
57 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2);
58 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20);
60 operation::Pool2D::Param max_pool_param{
61 operation::Pool2D::PoolType::MAX, 3, 6, stride, padding, Activation::NONE};
62 infered_out_shape = onert::shape_inference::inferPoolShape(in_shape, max_pool_param);
64 ASSERT_EQ(infered_out_shape.rank(), 4);
65 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
66 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
67 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2);
68 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20);
71 TEST(ShapeInference, Pool2DNodeValid)
73 Shape in_shape{10, 6, 12, 20};
75 Padding padding{PaddingType::VALID};
77 operation::Pool2D::Param avg_pool_param{
78 operation::Pool2D::PoolType::AVG, 3, 6, stride, padding, Activation::NONE};
79 auto infered_out_shape = onert::shape_inference::inferPoolShape(in_shape, avg_pool_param);
81 ASSERT_EQ(infered_out_shape.rank(), 4);
82 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
83 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
84 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1);
85 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20);
87 operation::Pool2D::Param max_pool_param{
88 operation::Pool2D::PoolType::MAX, 3, 6, stride, padding, Activation::NONE};
89 infered_out_shape = onert::shape_inference::inferPoolShape(in_shape, max_pool_param);
91 ASSERT_EQ(infered_out_shape.rank(), 4);
92 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
93 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
94 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1);
95 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20);
98 TEST(ShapeInference, Pool2DNodeExplicit)
100 Shape in_shape{10, 3, 5, 20};
103 Padding padding{4, 3, 2, 1};
105 operation::Pool2D::Param avg_pool_param{
106 operation::Pool2D::PoolType::AVG, 3, 6, stride, padding, Activation::NONE};
107 auto infered_out_shape = onert::shape_inference::inferPoolShape(in_shape, avg_pool_param);
109 ASSERT_EQ(infered_out_shape.rank(), 4);
110 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
111 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
112 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1);
113 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20);
115 operation::Pool2D::Param max_pool_param{
116 operation::Pool2D::PoolType::MAX, 3, 6, stride, padding, Activation::NONE};
117 infered_out_shape = onert::shape_inference::inferPoolShape(in_shape, max_pool_param);
119 ASSERT_EQ(infered_out_shape.rank(), 4);
120 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
121 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
122 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1);
123 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20);
126 TEST(ShapeInference, Conv2D)
128 Shape in_shape{10, 6, 12, 20};
129 Shape ker_shape{30, 3, 6, 20};
131 operation::Conv2D::Param param{Stride{3, 7}, Padding{PaddingType::VALID}, Activation::NONE,
133 auto infered_out_shape = onert::shape_inference::inferConv2DShape(in_shape, ker_shape, param);
135 ASSERT_EQ(infered_out_shape.rank(), 4);
136 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
137 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
138 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1);
139 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 30);
141 param = operation::Conv2D::Param{Stride{3, 7}, Padding{PaddingType::SAME}, Activation::NONE,
143 infered_out_shape = onert::shape_inference::inferConv2DShape(in_shape, ker_shape, param);
145 ASSERT_EQ(infered_out_shape.rank(), 4);
146 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
147 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
148 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2);
149 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 30);
152 operation::Conv2D::Param{Stride{3, 7}, Padding{4, 3, 2, 1}, Activation::NONE, Dilation{1, 1}};
153 infered_out_shape = onert::shape_inference::inferConv2DShape(in_shape, ker_shape, param);
155 ASSERT_EQ(infered_out_shape.rank(), 4);
156 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
157 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 3);
158 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2);
159 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 30);
162 TEST(ShapeInference, DepthwiseConv2D)
164 Shape in_shape{10, 6, 12, 20};
165 Shape ker_shape{1, 3, 6, 60};
167 operation::DepthwiseConv2D::Param param{Stride{3, 7}, Padding{PaddingType::VALID}, 3,
169 auto infered_out_shape =
170 onert::shape_inference::inferDepthwiseConv2DShape(in_shape, ker_shape, param);
172 ASSERT_EQ(infered_out_shape.rank(), 4);
173 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
174 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
175 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1);
176 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 60);
178 param = operation::DepthwiseConv2D::Param{Stride{3, 7}, Padding{PaddingType::SAME}, 3,
180 infered_out_shape = onert::shape_inference::inferDepthwiseConv2DShape(in_shape, ker_shape, param);
182 ASSERT_EQ(infered_out_shape.rank(), 4);
183 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
184 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
185 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2);
186 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 60);
188 param = operation::DepthwiseConv2D::Param{Stride{3, 7}, Padding{4, 3, 2, 1}, 3, Activation::NONE};
189 infered_out_shape = onert::shape_inference::inferDepthwiseConv2DShape(in_shape, ker_shape, param);
191 ASSERT_EQ(infered_out_shape.rank(), 4);
192 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
193 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 3);
194 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2);
195 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 60);
198 TEST(ShapeInference, Concat)
201 Shape in1{10, 20, 30, 3, 50};
202 Shape in2{10, 20, 30, 2, 50};
203 Shape in3{10, 20, 30, 2, 50};
205 operation::Concat::Param param{3};
206 auto infered_out_shape = onert::shape_inference::inferConcatShape({in1, in2, in3}, param);
208 ASSERT_EQ(infered_out_shape.rank(), 5);
209 ASSERT_EQ(infered_out_shape.dim(0), 10);
210 ASSERT_EQ(infered_out_shape.dim(1), 20);
211 ASSERT_EQ(infered_out_shape.dim(2), 30);
212 ASSERT_EQ(infered_out_shape.dim(3), 7);
213 ASSERT_EQ(infered_out_shape.dim(4), 50);
216 // case 1. when axis < 0
217 Shape in1{10, 20, 2};
218 Shape in2{10, 20, 3};
220 operation::Concat::Param param{-1};
221 auto infered_out_shape = onert::shape_inference::inferConcatShape({in1, in2}, param);
223 ASSERT_EQ(infered_out_shape.rank(), 3);
224 ASSERT_EQ(infered_out_shape.dim(0), 10);
225 ASSERT_EQ(infered_out_shape.dim(1), 20);
226 ASSERT_EQ(infered_out_shape.dim(2), 5);
229 // case 2. when axis < 0
233 operation::Concat::Param param{-3};
234 auto infered_out_shape = onert::shape_inference::inferConcatShape({in1, in2}, param);
236 ASSERT_EQ(infered_out_shape.rank(), 3);
237 ASSERT_EQ(infered_out_shape.dim(0), 5);
238 ASSERT_EQ(infered_out_shape.dim(1), 20);
239 ASSERT_EQ(infered_out_shape.dim(2), 2);
243 TEST(ShapeInference, neg_Concat)
246 operation::Concat::Param param{2};
248 Shape in2{10, 2, 4}; // dim[1] should be 1 but 2
250 EXPECT_ANY_THROW(onert::shape_inference::inferConcatShape({in1, in2}, param));
253 operation::Concat::Param param{2};
254 Shape in1{10, 2, 3, 4};
255 Shape in2{10, 2, 4}; // rank should be 4
257 EXPECT_ANY_THROW(onert::shape_inference::inferConcatShape({in1, in2}, param));
261 TEST(ShapeInference, ExpandDims)
263 Shape in_shape{30, 40};
265 auto check = [&](int32_t axis, Shape &expected) {
266 auto actual = onert::shape_inference::inferExpandDimsShape(in_shape, axis);
268 ASSERT_EQ(actual.rank(), 3);
269 for (int32_t dim = 0; dim < expected.rank(); dim++)
270 ASSERT_EQ(actual.dim(dim), expected.dim(dim));
275 Shape expected{1, 30, 40};
276 check(axis, expected);
280 Shape expected{30, 40, 1};
281 check(axis, expected);
285 Shape expected{30, 1, 40};
286 check(axis, expected);
288 { // negative boundary
290 Shape expected{30, 40, 1};
291 check(axis, expected);
293 { // negative boundary
295 Shape expected{1, 30, 40};
296 check(axis, expected);
300 TEST(ShapeInference, neg_ExpandDims)
302 Shape in_shape{30, 40};
306 ASSERT_THROW(onert::shape_inference::inferExpandDimsShape(in_shape, axis), std::runtime_error);
310 ASSERT_THROW(onert::shape_inference::inferExpandDimsShape(in_shape, axis), std::runtime_error);
314 TEST(ShapeInference, FullyConnected)
316 Shape in_shape{3, 4, 5, 6};
317 Shape ker_shape{3, 10};
318 auto infered_out_shape = onert::shape_inference::inferFullyConnectedShape(in_shape, ker_shape);
320 ASSERT_EQ(infered_out_shape.rank(), 2);
321 ASSERT_EQ(infered_out_shape.dim(0), 36);
322 ASSERT_EQ(infered_out_shape.dim(1), 3);
325 TEST(ShapeInference, Transpose)
327 auto check = [&](Shape &in_shape, std::vector<int> perm, Shape &expected) {
329 ASSERT_EQ(in_shape.rank(), perm.size());
330 ASSERT_EQ(expected.rank(), perm.size());
331 auto inferred_out_shape = onert::shape_inference::inferTransposeShape(in_shape, perm);
333 ASSERT_EQ(inferred_out_shape.rank(), perm.size());
334 for (int32_t dim = 0; dim < expected.rank(); dim++)
336 ASSERT_EQ(inferred_out_shape.dim(dim), expected.dim(dim));
341 Shape in_shape{2, 3};
342 std::vector<int> perm = {1, 0};
343 Shape expected{3, 2};
345 check(in_shape, perm, expected);
349 Shape in_shape{1, 2, 3};
350 std::vector<int> perm = {2, 0, 1};
351 Shape expected{3, 1, 2};
353 check(in_shape, perm, expected);
357 Shape in_shape{1, 2, 3, 4};
358 std::vector<int> perm = {1, 3, 0, 2};
359 Shape expected{2, 4, 1, 3};
361 check(in_shape, perm, expected);
365 TEST(ShapeInference, neg_Transpose)
367 Shape in_shape{1, 2, 3};
368 // Invalid parameter size
370 std::vector<int> perm = {2, 0, 1, 0};
372 ASSERT_THROW(onert::shape_inference::inferTransposeShape(in_shape, perm), std::runtime_error);
374 // Invalid parameter value
376 std::vector<int> perm = {2, 0, 3};
378 ASSERT_THROW(onert::shape_inference::inferTransposeShape(in_shape, perm), std::runtime_error);