63a948d7b15baaca8b50d2ca319f373f80d3b0d0
[platform/core/ml/nnfw.git] / runtime / onert / test / util / ShapeInference.cc
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <gtest/gtest.h>
18
19 #include "ir/Layout.h"
20 #include "util/ShapeInference.h"
21
22 using namespace onert::ir;
23
24 TEST(ShapeInference, Elementwise)
25 {
26   Shape lhs_shape{1, 299, 299, 3};
27   Shape rhs_shape{3};
28   auto infered_out_shape = onert::shape_inference::inferEltwiseShape(lhs_shape, rhs_shape);
29
30   ASSERT_EQ(infered_out_shape.rank(), 4);
31   ASSERT_EQ(infered_out_shape.dim(0), 1);
32   ASSERT_EQ(infered_out_shape.dim(1), 299);
33   ASSERT_EQ(infered_out_shape.dim(2), 299);
34   ASSERT_EQ(infered_out_shape.dim(3), 3);
35 }
36
37 TEST(ShapeInference, IncorrectElementwise)
38 {
39   Shape lhs_shape{1, 299, 299, 3};
40   Shape rhs_shape{5, 3};
41   ASSERT_THROW(onert::shape_inference::inferEltwiseShape(lhs_shape, rhs_shape), std::runtime_error);
42 }
43
44 TEST(ShapeInference, Pool2DNodeSame)
45 {
46   Shape in_shape{10, 6, 12, 20};
47   Stride stride{3, 7};
48   Padding padding{PaddingType::SAME};
49
50   operation::AvgPool2D::Param avg_pool_param{3, 6, stride, padding, Activation::NONE};
51   auto infered_out_shape = onert::shape_inference::inferAvgPoolShape(in_shape, avg_pool_param);
52
53   ASSERT_EQ(infered_out_shape.rank(), 4);
54   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
55   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
56   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2);
57   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20);
58
59   operation::MaxPool2D::Param max_pool_param{3, 6, stride, padding, Activation::NONE};
60   infered_out_shape = onert::shape_inference::inferMaxPoolShape(in_shape, max_pool_param);
61
62   ASSERT_EQ(infered_out_shape.rank(), 4);
63   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
64   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
65   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2);
66   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20);
67 }
68
69 TEST(ShapeInference, Pool2DNodeValid)
70 {
71   Shape in_shape{10, 6, 12, 20};
72   Stride stride{3, 7};
73   Padding padding{PaddingType::VALID};
74
75   operation::AvgPool2D::Param avg_pool_param{3, 6, stride, padding, Activation::NONE};
76   auto infered_out_shape = onert::shape_inference::inferAvgPoolShape(in_shape, avg_pool_param);
77
78   ASSERT_EQ(infered_out_shape.rank(), 4);
79   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
80   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
81   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1);
82   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20);
83
84   operation::MaxPool2D::Param max_pool_param{3, 6, stride, padding, Activation::NONE};
85   infered_out_shape = onert::shape_inference::inferMaxPoolShape(in_shape, max_pool_param);
86
87   ASSERT_EQ(infered_out_shape.rank(), 4);
88   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
89   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
90   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1);
91   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20);
92 }
93
94 TEST(ShapeInference, Pool2DNodeExplicit)
95 {
96   Shape in_shape{10, 3, 5, 20};
97
98   Stride stride{3, 7};
99   Padding padding{4, 3, 2, 1};
100
101   operation::AvgPool2D::Param avg_pool_param{3, 6, stride, padding, Activation::NONE};
102   auto infered_out_shape = onert::shape_inference::inferAvgPoolShape(in_shape, avg_pool_param);
103
104   ASSERT_EQ(infered_out_shape.rank(), 4);
105   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
106   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
107   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1);
108   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20);
109
110   operation::MaxPool2D::Param max_pool_param{3, 6, stride, padding, Activation::NONE};
111   infered_out_shape = onert::shape_inference::inferMaxPoolShape(in_shape, max_pool_param);
112
113   ASSERT_EQ(infered_out_shape.rank(), 4);
114   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
115   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
116   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1);
117   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20);
118 }
119
120 TEST(ShapeInference, Conv2D)
121 {
122   Shape in_shape{10, 6, 12, 20};
123   Shape ker_shape{30, 3, 6, 20};
124
125   operation::Conv2D::Param param{Stride{3, 7}, Padding{PaddingType::VALID}, Activation::NONE};
126   auto infered_out_shape = onert::shape_inference::inferConv2DShape(in_shape, ker_shape, param);
127
128   ASSERT_EQ(infered_out_shape.rank(), 4);
129   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
130   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
131   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1);
132   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 30);
133
134   param = operation::Conv2D::Param{Stride{3, 7}, Padding{PaddingType::SAME}, Activation::NONE};
135   infered_out_shape = onert::shape_inference::inferConv2DShape(in_shape, ker_shape, param);
136
137   ASSERT_EQ(infered_out_shape.rank(), 4);
138   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
139   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
140   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2);
141   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 30);
142
143   param = operation::Conv2D::Param{Stride{3, 7}, Padding{4, 3, 2, 1}, Activation::NONE};
144   infered_out_shape = onert::shape_inference::inferConv2DShape(in_shape, ker_shape, param);
145
146   ASSERT_EQ(infered_out_shape.rank(), 4);
147   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
148   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 3);
149   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2);
150   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 30);
151 }
152
153 TEST(ShapeInference, DepthwiseConv2D)
154 {
155   Shape in_shape{10, 6, 12, 20};
156   Shape ker_shape{1, 3, 6, 60};
157
158   operation::DepthwiseConv2D::Param param{Stride{3, 7}, Padding{PaddingType::VALID}, 3,
159                                           Activation::NONE};
160   auto infered_out_shape =
161       onert::shape_inference::inferDepthwiseConv2DShape(in_shape, ker_shape, param);
162
163   ASSERT_EQ(infered_out_shape.rank(), 4);
164   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
165   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
166   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1);
167   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 60);
168
169   param = operation::DepthwiseConv2D::Param{Stride{3, 7}, Padding{PaddingType::SAME}, 3,
170                                             Activation::NONE};
171   infered_out_shape = onert::shape_inference::inferDepthwiseConv2DShape(in_shape, ker_shape, param);
172
173   ASSERT_EQ(infered_out_shape.rank(), 4);
174   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
175   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
176   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2);
177   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 60);
178
179   param = operation::DepthwiseConv2D::Param{Stride{3, 7}, Padding{4, 3, 2, 1}, 3, Activation::NONE};
180   infered_out_shape = onert::shape_inference::inferDepthwiseConv2DShape(in_shape, ker_shape, param);
181
182   ASSERT_EQ(infered_out_shape.rank(), 4);
183   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
184   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 3);
185   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2);
186   ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 60);
187 }
188
189 TEST(ShapeInference, Concat)
190 {
191   {
192     Shape in1{10, 20, 30, 3, 50};
193     Shape in2{10, 20, 30, 2, 50};
194     Shape in3{10, 20, 30, 2, 50};
195
196     operation::Concat::Param param{3};
197     auto infered_out_shape = onert::shape_inference::inferConcatShape({in1, in2, in3}, param);
198
199     ASSERT_EQ(infered_out_shape.rank(), 5);
200     ASSERT_EQ(infered_out_shape.dim(0), 10);
201     ASSERT_EQ(infered_out_shape.dim(1), 20);
202     ASSERT_EQ(infered_out_shape.dim(2), 30);
203     ASSERT_EQ(infered_out_shape.dim(3), 7);
204     ASSERT_EQ(infered_out_shape.dim(4), 50);
205   }
206   {
207     // case 1. when axis < 0
208     Shape in1{10, 20, 2};
209     Shape in2{10, 20, 3};
210
211     operation::Concat::Param param{-1};
212     auto infered_out_shape = onert::shape_inference::inferConcatShape({in1, in2}, param);
213
214     ASSERT_EQ(infered_out_shape.rank(), 3);
215     ASSERT_EQ(infered_out_shape.dim(0), 10);
216     ASSERT_EQ(infered_out_shape.dim(1), 20);
217     ASSERT_EQ(infered_out_shape.dim(2), 5);
218   }
219   {
220     // case 2. when axis < 0
221     Shape in1{2, 20, 2};
222     Shape in2{3, 20, 2};
223
224     operation::Concat::Param param{-3};
225     auto infered_out_shape = onert::shape_inference::inferConcatShape({in1, in2}, param);
226
227     ASSERT_EQ(infered_out_shape.rank(), 3);
228     ASSERT_EQ(infered_out_shape.dim(0), 5);
229     ASSERT_EQ(infered_out_shape.dim(1), 20);
230     ASSERT_EQ(infered_out_shape.dim(2), 2);
231   }
232 }
233
234 TEST(ShapeInference, neg_Concat)
235 {
236   {
237     operation::Concat::Param param{2};
238     Shape in1{10, 1, 3};
239     Shape in2{10, 2, 4}; // dim[1] should be 1 but 2
240
241     EXPECT_ANY_THROW(onert::shape_inference::inferConcatShape({in1, in2}, param));
242   }
243   { // wrong rank
244     operation::Concat::Param param{2};
245     Shape in1{10, 2, 3, 4};
246     Shape in2{10, 2, 4}; // rank should be 4
247
248     EXPECT_ANY_THROW(onert::shape_inference::inferConcatShape({in1, in2}, param));
249   }
250 }
251
252 TEST(ShapeInference, ExpandDims)
253 {
254   Shape in_shape{30, 40};
255
256   auto check = [&](int32_t axis, Shape &expected) {
257     auto actual = onert::shape_inference::inferExpandDimsShape(in_shape, axis);
258
259     ASSERT_EQ(actual.rank(), 3);
260     for (int32_t dim = 0; dim < expected.rank(); dim++)
261       ASSERT_EQ(actual.dim(dim), expected.dim(dim));
262   };
263
264   { // boundary
265     int32_t axis = 0;
266     Shape expected{1, 30, 40};
267     check(axis, expected);
268   }
269   { // boundary
270     int32_t axis = 2;
271     Shape expected{30, 40, 1};
272     check(axis, expected);
273   }
274   { // inside
275     int32_t axis = 1;
276     Shape expected{30, 1, 40};
277     check(axis, expected);
278   }
279   { // negative boundary
280     int32_t axis = -1;
281     Shape expected{30, 40, 1};
282     check(axis, expected);
283   }
284   { // negative boundary
285     int32_t axis = -3;
286     Shape expected{1, 30, 40};
287     check(axis, expected);
288   }
289 }
290
291 TEST(ShapeInference, neg_ExpandDims)
292 {
293   Shape in_shape{30, 40};
294
295   { // over boundary
296     int32_t axis = 3;
297     ASSERT_THROW(onert::shape_inference::inferExpandDimsShape(in_shape, axis), std::runtime_error);
298   }
299   { // over boundary
300     int32_t axis = -4;
301     ASSERT_THROW(onert::shape_inference::inferExpandDimsShape(in_shape, axis), std::runtime_error);
302   }
303 }
304
305 TEST(ShapeInference, FullyConnected)
306 {
307   Shape in_shape{3, 4, 5, 6};
308   Shape ker_shape{3, 10};
309   auto infered_out_shape = onert::shape_inference::inferFullyConnectedShape(in_shape, ker_shape);
310
311   ASSERT_EQ(infered_out_shape.rank(), 2);
312   ASSERT_EQ(infered_out_shape.dim(0), 36);
313   ASSERT_EQ(infered_out_shape.dim(1), 3);
314 }
315
316 TEST(ShapeInference, Transpose)
317 {
318   auto check = [&](Shape &in_shape, std::vector<int> perm, Shape &expected) {
319     // pre-conditions
320     ASSERT_EQ(in_shape.rank(), perm.size());
321     ASSERT_EQ(expected.rank(), perm.size());
322     auto inferred_out_shape = onert::shape_inference::inferTransposeShape(in_shape, perm);
323     // post-conditions
324     ASSERT_EQ(inferred_out_shape.rank(), perm.size());
325     for (int32_t dim = 0; dim < expected.rank(); dim++)
326     {
327       ASSERT_EQ(inferred_out_shape.dim(dim), expected.dim(dim));
328     }
329   };
330   // check for 2-D
331   {
332     Shape in_shape{2, 3};
333     std::vector<int> perm = {1, 0};
334     Shape expected{3, 2};
335     // int32_t rank = 2;
336     check(in_shape, perm, expected);
337   }
338   // check for 3-D
339   {
340     Shape in_shape{1, 2, 3};
341     std::vector<int> perm = {2, 0, 1};
342     Shape expected{3, 1, 2};
343     // int32_t rank = 3;
344     check(in_shape, perm, expected);
345   }
346   // check for 4-D
347   {
348     Shape in_shape{1, 2, 3, 4};
349     std::vector<int> perm = {1, 3, 0, 2};
350     Shape expected{2, 4, 1, 3};
351     // int32_t rank = 4;
352     check(in_shape, perm, expected);
353   }
354 }
355
356 TEST(ShapeInference, neg_Transpose)
357 {
358   Shape in_shape{1, 2, 3};
359   // Invalid parameter size
360   {
361     std::vector<int> perm = {2, 0, 1, 0};
362     // int32_t rank = 3;
363     ASSERT_THROW(onert::shape_inference::inferTransposeShape(in_shape, perm), std::runtime_error);
364   }
365   // Invalid parameter value
366   {
367     std::vector<int> perm = {2, 0, 3};
368     // int32_t rank = 3;
369     ASSERT_THROW(onert::shape_inference::inferTransposeShape(in_shape, perm), std::runtime_error);
370   }
371 }