/*
 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 #include <gtest/gtest.h>
19 #include "ir/Layout.h"
20 #include "util/ShapeInference.h"
22 using namespace onert::ir;
// Positive case for onert::shape_inference::inferEltwiseShape: lhs_shape {1, 299, 299, 3}
// combined with rhs_shape must infer a rank-4 shape equal to {1, 299, 299, 3}.
// NOTE(review): the declaration of rhs_shape is not visible in this view of the file.
24 TEST(ShapeInference, Elementwise)
26 Shape lhs_shape{1, 299, 299, 3};
28 auto infered_out_shape = onert::shape_inference::inferEltwiseShape(lhs_shape, rhs_shape);
// ASSERT_EQ is fatal, so the rank is verified before any individual dimension is read.
30 ASSERT_EQ(infered_out_shape.rank(), 4);
31 ASSERT_EQ(infered_out_shape.dim(0), 1);
32 ASSERT_EQ(infered_out_shape.dim(1), 299);
33 ASSERT_EQ(infered_out_shape.dim(2), 299);
34 ASSERT_EQ(infered_out_shape.dim(3), 3);
// Negative case: inferEltwiseShape must throw std::runtime_error when given
// lhs_shape {1, 299, 299, 3} and rhs_shape {5, 3}.
// NOTE(review): presumably rejected because the two shapes are not broadcast-compatible -
// confirm against the implementation of inferEltwiseShape.
37 TEST(ShapeInference, neg_Elementwise)
39 Shape lhs_shape{1, 299, 299, 3};
40 Shape rhs_shape{5, 3};
41 ASSERT_THROW(onert::shape_inference::inferEltwiseShape(lhs_shape, rhs_shape), std::runtime_error);
// Pool2D shape inference with PaddingType::SAME on input {10, 6, 12, 20}.
// The same result (N=10, H=2, W=2, C=20 when viewed as an NHWC feature) is required for
// both AVG and MAX pooling.
// NOTE(review): `stride` is declared on a line not visible in this view; the literals 3 and 6
// are presumably the kernel height and width - confirm against operation::Pool2D::Param.
44 TEST(ShapeInference, Pool2DNodeSame)
46 Shape in_shape{10, 6, 12, 20};
48 Padding padding{PaddingType::SAME};
// Average pooling.
50 operation::Pool2D::Param avg_pool_param{
51 operation::Pool2D::PoolType::AVG, 3, 6, stride, padding, Activation::NONE};
52 auto infered_out_shape = onert::shape_inference::inferPoolShape(in_shape, avg_pool_param);
54 ASSERT_EQ(infered_out_shape.rank(), 4);
55 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
56 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
57 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2);
58 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20);
// Max pooling with otherwise identical parameters; the expected shape is unchanged.
60 operation::Pool2D::Param max_pool_param{
61 operation::Pool2D::PoolType::MAX, 3, 6, stride, padding, Activation::NONE};
62 infered_out_shape = onert::shape_inference::inferPoolShape(in_shape, max_pool_param);
64 ASSERT_EQ(infered_out_shape.rank(), 4);
65 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
66 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
67 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2);
68 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20);
// Pool2D shape inference with PaddingType::VALID on input {10, 6, 12, 20}.
// Both AVG and MAX pooling must yield N=10, H=2, W=1, C=20 when viewed as an NHWC feature.
// NOTE(review): `stride` is declared on a line not visible in this view; the literals 3 and 6
// are presumably the kernel height and width - confirm against operation::Pool2D::Param.
71 TEST(ShapeInference, Pool2DNodeValid)
73 Shape in_shape{10, 6, 12, 20};
75 Padding padding{PaddingType::VALID};
// Average pooling.
77 operation::Pool2D::Param avg_pool_param{
78 operation::Pool2D::PoolType::AVG, 3, 6, stride, padding, Activation::NONE};
79 auto infered_out_shape = onert::shape_inference::inferPoolShape(in_shape, avg_pool_param);
81 ASSERT_EQ(infered_out_shape.rank(), 4);
82 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
83 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
84 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1);
85 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20);
// Max pooling with otherwise identical parameters; the expected shape is unchanged.
87 operation::Pool2D::Param max_pool_param{
88 operation::Pool2D::PoolType::MAX, 3, 6, stride, padding, Activation::NONE};
89 infered_out_shape = onert::shape_inference::inferPoolShape(in_shape, max_pool_param);
91 ASSERT_EQ(infered_out_shape.rank(), 4);
92 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
93 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
94 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1);
95 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20);
// Pool2D shape inference with explicit padding values on input {10, 3, 5, 20}.
// Both AVG and MAX pooling must yield N=10, H=2, W=1, C=20 when viewed as an NHWC feature.
// NOTE(review): `stride` is declared on a line not visible in this view, and the meaning of
// each of the four Padding values (which side each one pads) should be confirmed against the
// Padding constructor.
98 TEST(ShapeInference, Pool2DNodeExplicit)
100 Shape in_shape{10, 3, 5, 20};
103 Padding padding{4, 3, 2, 1};
// Average pooling.
105 operation::Pool2D::Param avg_pool_param{
106 operation::Pool2D::PoolType::AVG, 3, 6, stride, padding, Activation::NONE};
107 auto infered_out_shape = onert::shape_inference::inferPoolShape(in_shape, avg_pool_param);
109 ASSERT_EQ(infered_out_shape.rank(), 4);
110 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
111 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
112 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1);
113 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20);
// Max pooling with otherwise identical parameters; the expected shape is unchanged.
115 operation::Pool2D::Param max_pool_param{
116 operation::Pool2D::PoolType::MAX, 3, 6, stride, padding, Activation::NONE};
117 infered_out_shape = onert::shape_inference::inferPoolShape(in_shape, max_pool_param);
119 ASSERT_EQ(infered_out_shape.rank(), 4);
120 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
121 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
122 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1);
123 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20);
// Negative case: inferPoolShape is expected to throw for the AVG pooling parameters below.
// NOTE(review): the test name indicates that the stride is the invalid input, but both the
// declaration of `stride` and the expected exception type (the ASSERT_THROW call continues on
// a following line) are not visible in this view.
126 TEST(ShapeInference, neg_Pool2DNode_InvalidStride)
128 Shape in_shape{10, 6, 12, 20};
130 Padding padding{PaddingType::SAME};
132 operation::Pool2D::Param avg_pool_param{
133 operation::Pool2D::PoolType::AVG, 3, 6, stride, padding, Activation::NONE};
134 ASSERT_THROW(onert::shape_inference::inferPoolShape(in_shape, avg_pool_param),
// Conv2D shape inference for input {10, 6, 12, 20} and kernel {30, 3, 6, 20} with
// Stride{3, 7}, exercised under three padding settings. Results are read back as NHWC
// features; N is 10 and C is 30 in every case.
// NOTE(review): the tails of the first two Param initializers and the beginning of the third
// assignment fall on lines not visible in this view.
138 TEST(ShapeInference, Conv2D)
140 Shape in_shape{10, 6, 12, 20};
141 Shape ker_shape{30, 3, 6, 20};
// PaddingType::VALID: expects H=2, W=1.
143 operation::Conv2D::Param param{Stride{3, 7}, Padding{PaddingType::VALID}, Activation::NONE,
145 auto infered_out_shape = onert::shape_inference::inferConv2DShape(in_shape, ker_shape, param);
147 ASSERT_EQ(infered_out_shape.rank(), 4);
148 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
149 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
150 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1);
151 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 30);
// PaddingType::SAME: expects H=2, W=2.
153 param = operation::Conv2D::Param{Stride{3, 7}, Padding{PaddingType::SAME}, Activation::NONE,
155 infered_out_shape = onert::shape_inference::inferConv2DShape(in_shape, ker_shape, param);
157 ASSERT_EQ(infered_out_shape.rank(), 4);
158 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
159 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
160 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2);
161 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 30);
// Explicit Padding{4, 3, 2, 1}: expects H=3, W=2.
164 operation::Conv2D::Param{Stride{3, 7}, Padding{4, 3, 2, 1}, Activation::NONE, Dilation{1, 1}};
165 infered_out_shape = onert::shape_inference::inferConv2DShape(in_shape, ker_shape, param);
167 ASSERT_EQ(infered_out_shape.rank(), 4);
168 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
169 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 3);
170 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2);
171 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 30);
// Negative case: inferConv2DShape is expected to throw when the stride is Stride{0, 0}.
// NOTE(review): the remainder of the Param initializer and the expected exception type (the
// ASSERT_THROW call continues on a following line) are not visible in this view.
174 TEST(ShapeInference, neg_Conv2D_InvalidStride)
176 Shape in_shape{10, 6, 12, 20};
177 Shape ker_shape{30, 3, 6, 20};
179 operation::Conv2D::Param param{Stride{0, 0}, Padding{PaddingType::VALID}, Activation::NONE,
181 ASSERT_THROW(onert::shape_inference::inferConv2DShape(in_shape, ker_shape, param),
// DepthwiseConv2D shape inference for input {10, 6, 12, 20} and kernel {1, 3, 6, 60} with
// Stride{3, 7}, exercised under three padding settings. Results are read back as NHWC
// features; N is 10 and C is 60 in every case.
// NOTE(review): the literal 3 passed after the padding is presumably the depth multiplier
// (20 input channels * 3 = 60) - confirm against operation::DepthwiseConv2D::Param.
185 TEST(ShapeInference, DepthwiseConv2D)
187 Shape in_shape{10, 6, 12, 20};
188 Shape ker_shape{1, 3, 6, 60};
// PaddingType::VALID: expects H=2, W=1.
190 operation::DepthwiseConv2D::Param param{Stride{3, 7}, Padding{PaddingType::VALID}, 3,
191 Activation::NONE, Dilation{1, 1}};
192 auto infered_out_shape =
193 onert::shape_inference::inferDepthwiseConv2DShape(in_shape, ker_shape, param);
195 ASSERT_EQ(infered_out_shape.rank(), 4);
196 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
197 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
198 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1);
199 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 60);
// PaddingType::SAME: expects H=2, W=2.
201 param = operation::DepthwiseConv2D::Param{Stride{3, 7}, Padding{PaddingType::SAME}, 3,
202 Activation::NONE, Dilation{1, 1}};
203 infered_out_shape = onert::shape_inference::inferDepthwiseConv2DShape(in_shape, ker_shape, param);
205 ASSERT_EQ(infered_out_shape.rank(), 4);
206 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
207 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
208 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2);
209 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 60);
// Explicit Padding{4, 3, 2, 1}: expects H=3, W=2.
// NOTE(review): the tail of this initializer is on a line not visible in this view.
211 param = operation::DepthwiseConv2D::Param{Stride{3, 7}, Padding{4, 3, 2, 1}, 3, Activation::NONE,
213 infered_out_shape = onert::shape_inference::inferDepthwiseConv2DShape(in_shape, ker_shape, param);
215 ASSERT_EQ(infered_out_shape.rank(), 4);
216 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
217 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 3);
218 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2);
219 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 60);
// Negative case: inferDepthwiseConv2DShape is expected to throw when the stride is
// Stride{3, 0}, i.e. only the second stride component is zero.
// NOTE(review): "InvalidSride" in the test name looks like a typo for "InvalidStride"; it is
// left as-is here because renaming changes the test identifier used by test filters.
// NOTE(review): the expected exception type (the ASSERT_THROW call continues on a following
// line) is not visible in this view.
222 TEST(ShapeInference, neg_DepthwiseConv2D_InvalidSride)
224 Shape in_shape{10, 6, 12, 20};
225 Shape ker_shape{1, 3, 6, 60};
227 operation::DepthwiseConv2D::Param param{Stride{3, 0}, Padding{PaddingType::VALID}, 3,
228 Activation::NONE, Dilation{1, 1}};
229 ASSERT_THROW(onert::shape_inference::inferDepthwiseConv2DShape(in_shape, ker_shape, param),
// Concat shape inference, covering one non-negative axis and two negative axes.
// In each case the output keeps the input rank and only the concatenated dimension changes:
// it must equal the sum of that dimension over all inputs.
233 TEST(ShapeInference, Concat)
// Axis 3 over three rank-5 inputs: dim(3) = 3 + 2 + 2 = 7, all other dims unchanged.
236 Shape in1{10, 20, 30, 3, 50};
237 Shape in2{10, 20, 30, 2, 50};
238 Shape in3{10, 20, 30, 2, 50};
240 operation::Concat::Param param{3};
241 auto infered_out_shape = onert::shape_inference::inferConcatShape({in1, in2, in3}, param);
243 ASSERT_EQ(infered_out_shape.rank(), 5);
244 ASSERT_EQ(infered_out_shape.dim(0), 10);
245 ASSERT_EQ(infered_out_shape.dim(1), 20);
246 ASSERT_EQ(infered_out_shape.dim(2), 30);
247 ASSERT_EQ(infered_out_shape.dim(3), 7);
248 ASSERT_EQ(infered_out_shape.dim(4), 50);
251 // case 1. when axis < 0
// Axis -1 on rank-3 inputs is expected to address the last dimension: dim(2) = 2 + 3 = 5.
252 Shape in1{10, 20, 2};
253 Shape in2{10, 20, 3};
255 operation::Concat::Param param{-1};
256 auto infered_out_shape = onert::shape_inference::inferConcatShape({in1, in2}, param);
258 ASSERT_EQ(infered_out_shape.rank(), 3);
259 ASSERT_EQ(infered_out_shape.dim(0), 10);
260 ASSERT_EQ(infered_out_shape.dim(1), 20);
261 ASSERT_EQ(infered_out_shape.dim(2), 5);
264 // case 2. when axis < 0
// Axis -3 on rank-3 inputs is expected to address the first dimension (dim(0) == 5).
// NOTE(review): the input shapes for this case are declared on lines not visible in this view.
268 operation::Concat::Param param{-3};
269 auto infered_out_shape = onert::shape_inference::inferConcatShape({in1, in2}, param);
271 ASSERT_EQ(infered_out_shape.rank(), 3);
272 ASSERT_EQ(infered_out_shape.dim(0), 5);
273 ASSERT_EQ(infered_out_shape.dim(1), 20);
274 ASSERT_EQ(infered_out_shape.dim(2), 2);
// Negative cases: inferConcatShape must throw (any exception type) for inconsistent inputs.
// EXPECT_ANY_THROW is non-fatal, so the second case still runs if the first one fails.
278 TEST(ShapeInference, neg_Concat)
// Case 1: with axis 2, a non-axis dimension (dim[1]) differs between the inputs.
// NOTE(review): the declaration of in1 for this case is not visible in this view.
281 operation::Concat::Param param{2};
283 Shape in2{10, 2, 4}; // dim[1] should be 1 but 2
285 EXPECT_ANY_THROW(onert::shape_inference::inferConcatShape({in1, in2}, param));
// Case 2: the inputs have different ranks (4 vs. 3).
288 operation::Concat::Param param{2};
289 Shape in1{10, 2, 3, 4};
290 Shape in2{10, 2, 4}; // rank should be 4
292 EXPECT_ANY_THROW(onert::shape_inference::inferConcatShape({in1, in2}, param));
// ExpandDims shape inference on the rank-2 input {30, 40}: each case passes one axis and
// the rank-3 shape expected after inserting a dimension of size 1.
// NOTE(review): the `axis` value of every case is declared on a line not visible in this view.
296 TEST(ShapeInference, ExpandDims)
298 Shape in_shape{30, 40};
// Helper: runs inferExpandDimsShape on in_shape (captured by reference) and compares the
// result with `expected` dimension by dimension. A failing ASSERT_EQ returns from the lambda
// only, so the remaining cases of the test still run.
300 auto check = [&](int32_t axis, Shape &expected) {
301 auto actual = onert::shape_inference::inferExpandDimsShape(in_shape, axis);
303 ASSERT_EQ(actual.rank(), 3);
304 for (int32_t dim = 0; dim < expected.rank(); dim++)
305 ASSERT_EQ(actual.dim(dim), expected.dim(dim));
// New dimension expected at the front.
310 Shape expected{1, 30, 40};
311 check(axis, expected);
// New dimension expected at the back.
315 Shape expected{30, 40, 1};
316 check(axis, expected);
// New dimension expected in the middle.
320 Shape expected{30, 1, 40};
321 check(axis, expected);
323 { // negative boundary
325 Shape expected{30, 40, 1};
326 check(axis, expected);
328 { // negative boundary
330 Shape expected{1, 30, 40};
331 check(axis, expected);
// Negative cases: inferExpandDimsShape must throw std::runtime_error for two different axis
// values applied to the rank-2 input {30, 40}.
// NOTE(review): the two `axis` values are declared on lines not visible in this view;
// presumably one lies above and one below the valid axis range - confirm.
335 TEST(ShapeInference, neg_ExpandDims)
337 Shape in_shape{30, 40};
341 ASSERT_THROW(onert::shape_inference::inferExpandDimsShape(in_shape, axis), std::runtime_error);
345 ASSERT_THROW(onert::shape_inference::inferExpandDimsShape(in_shape, axis), std::runtime_error);
// FullyConnected shape inference for a rank-4 input {3, 4, 5, 6} and kernel {3, 10}.
// The result must be rank 2 with dims {36, 3}.
// NOTE(review): 36 equals the input element count (3*4*5*6 = 360) divided by
// ker_shape.dim(1) = 10, and 3 equals ker_shape.dim(0); presumably the input is flattened to
// {batch, input_size} - confirm against inferFullyConnectedShape.
349 TEST(ShapeInference, FullyConnected)
351 Shape in_shape{3, 4, 5, 6};
352 Shape ker_shape{3, 10};
353 auto infered_out_shape = onert::shape_inference::inferFullyConnectedShape(in_shape, ker_shape);
355 ASSERT_EQ(infered_out_shape.rank(), 2);
356 ASSERT_EQ(infered_out_shape.dim(0), 36);
357 ASSERT_EQ(infered_out_shape.dim(1), 3);
// Transpose shape inference for rank-2, rank-3 and rank-4 inputs.
// In every case below, expected.dim(i) equals in_shape.dim(perm[i]).
360 TEST(ShapeInference, Transpose)
// Helper: first sanity-checks the test data itself (input and expected ranks must match the
// permutation length), then runs inferTransposeShape and compares the result with `expected`
// dimension by dimension. A failing ASSERT_EQ returns from the lambda only, so later cases of
// the test still run.
362 auto check = [&](Shape &in_shape, std::vector<int> perm, Shape &expected) {
364 ASSERT_EQ(in_shape.rank(), perm.size());
365 ASSERT_EQ(expected.rank(), perm.size());
366 auto inferred_out_shape =
367 onert::shape_inference::inferTransposeShape(in_shape, perm.data(), perm.size());
369 ASSERT_EQ(inferred_out_shape.rank(), perm.size());
370 for (int32_t dim = 0; dim < expected.rank(); dim++)
372 ASSERT_EQ(inferred_out_shape.dim(dim), expected.dim(dim));
// Rank 2: swap the two dimensions.
377 Shape in_shape{2, 3};
378 std::vector<int> perm = {1, 0};
379 Shape expected{3, 2};
381 check(in_shape, perm, expected);
// Rank 3: permutation {2, 0, 1}.
385 Shape in_shape{1, 2, 3};
386 std::vector<int> perm = {2, 0, 1};
387 Shape expected{3, 1, 2};
389 check(in_shape, perm, expected);
// Rank 4: permutation {1, 3, 0, 2}.
393 Shape in_shape{1, 2, 3, 4};
394 std::vector<int> perm = {1, 3, 0, 2};
395 Shape expected{2, 4, 1, 3};
397 check(in_shape, perm, expected);
// Negative cases: inferTransposeShape is expected to throw for malformed permutations of the
// rank-3 input {1, 2, 3}.
// NOTE(review): the expected exception type (each ASSERT_THROW call continues on a following
// line) is not visible in this view.
401 TEST(ShapeInference, neg_Transpose)
403 Shape in_shape{1, 2, 3};
404 // Invalid parameter size
// The permutation has four entries while the input has rank 3.
406 std::vector<int> perm = {2, 0, 1, 0};
408 ASSERT_THROW(onert::shape_inference::inferTransposeShape(in_shape, perm.data(), perm.size()),
411 // Invalid parameter value
// The permutation contains 3, which is not a valid axis index for a rank-3 input (0..2).
413 std::vector<int> perm = {2, 0, 3};
415 ASSERT_THROW(onert::shape_inference::inferTransposeShape(in_shape, perm.data(), perm.size()),
// Gather shape inference for several input/indices rank combinations and axes.
// NOTE(review): several `input`, `indices` and `axis` declarations are on lines not visible in
// this view, so the per-case comments below are the only description of those values.
420 TEST(ShapeInference, Gather)
// Helper: runs inferGatherShape with `rank` taken from input.rank() and compares the result
// with `expected` (rank first, then each dimension). A failing ASSERT_EQ returns from the
// lambda only, so later cases of the test still run.
422 auto check = [&](Shape &input, Shape &indices, Shape &expected, int32_t axis) {
423 int rank = input.rank();
424 auto actual = onert::shape_inference::inferGatherShape(input, indices, axis, rank);
426 ASSERT_EQ(actual.rank(), expected.rank());
428 for (int32_t dim = 0; dim < expected.rank(); dim++)
429 ASSERT_EQ(actual.dim(dim), expected.dim(dim));
432 // check for 2-D, 3-D, axis 0
435 Shape indices{1, 1, 2};
437 Shape expected{1, 1, 2, 4};
438 check(input, indices, expected, axis);
441 // check for 2-D, 3-D, axis 1
444 Shape indices{1, 2, 1};
446 Shape expected{3, 1, 2, 1};
447 check(input, indices, expected, axis);
450 // check for 3-D, 2-D, axis 0
452 Shape input{2, 3, 4};
455 Shape expected{1, 2, 3, 4};
456 check(input, indices, expected, axis);
459 // check for 3-D, 2-D, axis 2
461 Shape input{2, 3, 4};
464 Shape expected{2, 3, 2, 1};
465 check(input, indices, expected, axis);
468 // check for 4D, axis 0
470 Shape input{1, 2, 3, 4};
473 Shape expected{2, 2, 3, 4};
474 check(input, indices, expected, axis);
// BCQFullyConnected shape inference for two cluster configurations.
// NOTE(review): in both cases expected.dim(0) equals the sum of the second value of each
// {x, y} pair in `cluster` (10 + 10 + 10 = 30, and 50) and expected.dim(1) equals
// in_shape.dim(1); presumably the second value is a per-cluster size - confirm against
// inferBCQFullyConnectedShape.
478 TEST(ShapeInference, BCQFullyConnected)
// Helper: runs inferBCQFullyConnectedShape and compares the result with `expected` (rank
// first, then each dimension).
// NOTE(review): the rest of the lambda's parameter list and of the call's argument list are
// on lines not visible in this view.
480 auto check = [&](Shape &in_shape, Shape &cluster_shape, std::vector<int> cluster,
482 auto actual = onert::shape_inference::inferBCQFullyConnectedShape(in_shape, cluster_shape,
484 ASSERT_EQ(actual.rank(), expected.rank());
486 for (int32_t dim = 0; dim < expected.rank(); dim++)
487 ASSERT_EQ(actual.dim(dim), expected.dim(dim));
// Three clusters described by a {3, 2} cluster tensor.
491 Shape in_shape{10, 1};
492 Shape cluster_shape{3, 2};
493 std::vector<int> cluster = {1, 10, 2, 10, 3, 10};
495 Shape expected{30, 1};
496 check(in_shape, cluster_shape, cluster, expected);
// A single cluster described by a {1, 2} cluster tensor.
500 Shape in_shape{1, 1};
501 Shape cluster_shape{1, 2};
502 std::vector<int> cluster = {3, 50};
504 Shape expected{50, 1};
505 check(in_shape, cluster_shape, cluster, expected);
// BCQGather shape inference for two configurations that share the same indices shape,
// cluster data and hidden_size but expect different output shapes.
// NOTE(review): the `axis` and `rank` values of both cases are declared on lines not visible
// in this view; presumably the two cases differ in `axis` - confirm.
509 TEST(ShapeInference, BCQGather)
// Helper: builds operation::BCQGather::Param from hidden_size and axis, runs
// inferBCQGatherShape and compares the result with `expected` (rank first, then each
// dimension). A failing ASSERT_EQ returns from the lambda only.
511 auto check = [&](Shape &indices_shape, Shape &cluster_shape, std::vector<int> cluster,
512 uint32_t hidden_size, uint32_t axis, int rank, Shape &expected) {
513 operation::BCQGather::Param param{hidden_size, axis};
514 auto actual = onert::shape_inference::inferBCQGatherShape(indices_shape, cluster_shape,
515 cluster.data(), rank, param);
516 ASSERT_EQ(actual.rank(), expected.rank());
518 for (int32_t dim = 0; dim < expected.rank(); dim++)
519 ASSERT_EQ(actual.dim(dim), expected.dim(dim));
// Expected shape is the indices shape {5, 1} followed by hidden_size (10).
523 Shape indices_shape{5, 1};
524 Shape cluster_shape{3, 2};
525 std::vector<int> cluster = {1, 10, 2, 10, 3, 10};
526 uint32_t hidden_size = 10;
530 Shape expected{5, 1, 10};
531 check(indices_shape, cluster_shape, cluster, hidden_size, axis, rank, expected);
// Expected shape is 30 (the sum of the second value of each cluster pair, 10 + 10 + 10)
// followed by the indices shape {5, 1}.
535 Shape indices_shape{5, 1};
536 Shape cluster_shape{3, 2};
537 std::vector<int> cluster = {1, 10, 2, 10, 3, 10};
538 uint32_t hidden_size = 10;
542 Shape expected{30, 5, 1};
543 check(indices_shape, cluster_shape, cluster, hidden_size, axis, rank, expected);