/*
 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 #include <gtest/gtest.h>
19 #include "ir/Layout.h"
20 #include "util/ShapeInference.h"
22 using namespace onert::ir;
24 TEST(ShapeInference, Elementwise)
26 Shape lhs_shape{1, 299, 299, 3};
28 auto infered_out_shape = onert::shape_inference::inferEltwiseShape(lhs_shape, rhs_shape);
30 ASSERT_EQ(infered_out_shape.rank(), 4);
31 ASSERT_EQ(infered_out_shape.dim(0), 1);
32 ASSERT_EQ(infered_out_shape.dim(1), 299);
33 ASSERT_EQ(infered_out_shape.dim(2), 299);
34 ASSERT_EQ(infered_out_shape.dim(3), 3);
37 TEST(ShapeInference, IncorrectElementwise)
39 Shape lhs_shape{1, 299, 299, 3};
40 Shape rhs_shape{5, 3};
41 ASSERT_THROW(onert::shape_inference::inferEltwiseShape(lhs_shape, rhs_shape), std::runtime_error);
44 TEST(ShapeInference, Pool2DNodeSame)
46 Shape in_shape{10, 6, 12, 20};
48 Padding padding{PaddingType::SAME};
50 operation::AvgPool2D::Param avg_pool_param{3, 6, stride, padding, Activation::NONE};
51 auto infered_out_shape = onert::shape_inference::inferAvgPoolShape(in_shape, avg_pool_param);
53 ASSERT_EQ(infered_out_shape.rank(), 4);
54 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
55 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
56 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2);
57 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20);
59 operation::MaxPool2D::Param max_pool_param{3, 6, stride, padding, Activation::NONE};
60 infered_out_shape = onert::shape_inference::inferMaxPoolShape(in_shape, max_pool_param);
62 ASSERT_EQ(infered_out_shape.rank(), 4);
63 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
64 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
65 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2);
66 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20);
69 TEST(ShapeInference, Pool2DNodeValid)
71 Shape in_shape{10, 6, 12, 20};
73 Padding padding{PaddingType::VALID};
75 operation::AvgPool2D::Param avg_pool_param{3, 6, stride, padding, Activation::NONE};
76 auto infered_out_shape = onert::shape_inference::inferAvgPoolShape(in_shape, avg_pool_param);
78 ASSERT_EQ(infered_out_shape.rank(), 4);
79 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
80 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
81 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1);
82 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20);
84 operation::MaxPool2D::Param max_pool_param{3, 6, stride, padding, Activation::NONE};
85 infered_out_shape = onert::shape_inference::inferMaxPoolShape(in_shape, max_pool_param);
87 ASSERT_EQ(infered_out_shape.rank(), 4);
88 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
89 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
90 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1);
91 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20);
94 TEST(ShapeInference, Pool2DNodeExplicit)
96 Shape in_shape{10, 3, 5, 20};
99 Padding padding{4, 3, 2, 1};
101 operation::AvgPool2D::Param avg_pool_param{3, 6, stride, padding, Activation::NONE};
102 auto infered_out_shape = onert::shape_inference::inferAvgPoolShape(in_shape, avg_pool_param);
104 ASSERT_EQ(infered_out_shape.rank(), 4);
105 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
106 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
107 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1);
108 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20);
110 operation::MaxPool2D::Param max_pool_param{3, 6, stride, padding, Activation::NONE};
111 infered_out_shape = onert::shape_inference::inferMaxPoolShape(in_shape, max_pool_param);
113 ASSERT_EQ(infered_out_shape.rank(), 4);
114 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
115 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
116 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1);
117 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20);
120 TEST(ShapeInference, Conv2D)
122 Shape in_shape{10, 6, 12, 20};
123 Shape ker_shape{30, 3, 6, 20};
125 operation::Conv2D::Param param{Stride{3, 7}, Padding{PaddingType::VALID}, Activation::NONE};
126 auto infered_out_shape = onert::shape_inference::inferConv2DShape(in_shape, ker_shape, param);
128 ASSERT_EQ(infered_out_shape.rank(), 4);
129 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
130 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
131 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1);
132 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 30);
134 param = operation::Conv2D::Param{Stride{3, 7}, Padding{PaddingType::SAME}, Activation::NONE};
135 infered_out_shape = onert::shape_inference::inferConv2DShape(in_shape, ker_shape, param);
137 ASSERT_EQ(infered_out_shape.rank(), 4);
138 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
139 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
140 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2);
141 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 30);
143 param = operation::Conv2D::Param{Stride{3, 7}, Padding{4, 3, 2, 1}, Activation::NONE};
144 infered_out_shape = onert::shape_inference::inferConv2DShape(in_shape, ker_shape, param);
146 ASSERT_EQ(infered_out_shape.rank(), 4);
147 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
148 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 3);
149 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2);
150 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 30);
153 TEST(ShapeInference, DepthwiseConv2D)
155 Shape in_shape{10, 6, 12, 20};
156 Shape ker_shape{1, 3, 6, 60};
158 operation::DepthwiseConv2D::Param param{Stride{3, 7}, Padding{PaddingType::VALID}, 3,
160 auto infered_out_shape =
161 onert::shape_inference::inferDepthwiseConv2DShape(in_shape, ker_shape, param);
163 ASSERT_EQ(infered_out_shape.rank(), 4);
164 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
165 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
166 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1);
167 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 60);
169 param = operation::DepthwiseConv2D::Param{Stride{3, 7}, Padding{PaddingType::SAME}, 3,
171 infered_out_shape = onert::shape_inference::inferDepthwiseConv2DShape(in_shape, ker_shape, param);
173 ASSERT_EQ(infered_out_shape.rank(), 4);
174 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
175 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
176 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2);
177 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 60);
179 param = operation::DepthwiseConv2D::Param{Stride{3, 7}, Padding{4, 3, 2, 1}, 3, Activation::NONE};
180 infered_out_shape = onert::shape_inference::inferDepthwiseConv2DShape(in_shape, ker_shape, param);
182 ASSERT_EQ(infered_out_shape.rank(), 4);
183 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
184 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 3);
185 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2);
186 ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 60);
189 TEST(ShapeInference, Concat)
192 Shape in1{10, 20, 30, 3, 50};
193 Shape in2{10, 20, 30, 2, 50};
194 Shape in3{10, 20, 30, 2, 50};
196 operation::Concat::Param param{3};
197 auto infered_out_shape = onert::shape_inference::inferConcatShape({in1, in2, in3}, param);
199 ASSERT_EQ(infered_out_shape.rank(), 5);
200 ASSERT_EQ(infered_out_shape.dim(0), 10);
201 ASSERT_EQ(infered_out_shape.dim(1), 20);
202 ASSERT_EQ(infered_out_shape.dim(2), 30);
203 ASSERT_EQ(infered_out_shape.dim(3), 7);
204 ASSERT_EQ(infered_out_shape.dim(4), 50);
207 // case 1. when axis < 0
208 Shape in1{10, 20, 2};
209 Shape in2{10, 20, 3};
211 operation::Concat::Param param{-1};
212 auto infered_out_shape = onert::shape_inference::inferConcatShape({in1, in2}, param);
214 ASSERT_EQ(infered_out_shape.rank(), 3);
215 ASSERT_EQ(infered_out_shape.dim(0), 10);
216 ASSERT_EQ(infered_out_shape.dim(1), 20);
217 ASSERT_EQ(infered_out_shape.dim(2), 5);
220 // case 2. when axis < 0
224 operation::Concat::Param param{-3};
225 auto infered_out_shape = onert::shape_inference::inferConcatShape({in1, in2}, param);
227 ASSERT_EQ(infered_out_shape.rank(), 3);
228 ASSERT_EQ(infered_out_shape.dim(0), 5);
229 ASSERT_EQ(infered_out_shape.dim(1), 20);
230 ASSERT_EQ(infered_out_shape.dim(2), 2);
234 TEST(ShapeInference, neg_Concat)
237 operation::Concat::Param param{2};
239 Shape in2{10, 2, 4}; // dim[1] should be 1 but 2
241 EXPECT_ANY_THROW(onert::shape_inference::inferConcatShape({in1, in2}, param));
244 operation::Concat::Param param{2};
245 Shape in1{10, 2, 3, 4};
246 Shape in2{10, 2, 4}; // rank should be 4
248 EXPECT_ANY_THROW(onert::shape_inference::inferConcatShape({in1, in2}, param));
252 TEST(ShapeInference, ExpandDims)
254 Shape in_shape{30, 40};
256 auto check = [&](int32_t axis, Shape &expected) {
257 auto actual = onert::shape_inference::inferExpandDimsShape(in_shape, axis);
259 ASSERT_EQ(actual.rank(), 3);
260 for (int32_t dim = 0; dim < expected.rank(); dim++)
261 ASSERT_EQ(actual.dim(dim), expected.dim(dim));
266 Shape expected{1, 30, 40};
267 check(axis, expected);
271 Shape expected{30, 40, 1};
272 check(axis, expected);
276 Shape expected{30, 1, 40};
277 check(axis, expected);
279 { // negative boundary
281 Shape expected{30, 40, 1};
282 check(axis, expected);
284 { // negative boundary
286 Shape expected{1, 30, 40};
287 check(axis, expected);
291 TEST(ShapeInference, neg_ExpandDims)
293 Shape in_shape{30, 40};
297 ASSERT_THROW(onert::shape_inference::inferExpandDimsShape(in_shape, axis), std::runtime_error);
301 ASSERT_THROW(onert::shape_inference::inferExpandDimsShape(in_shape, axis), std::runtime_error);
305 TEST(ShapeInference, FullyConnected)
307 Shape in_shape{3, 4, 5, 6};
308 Shape ker_shape{3, 10};
309 auto infered_out_shape = onert::shape_inference::inferFullyConnectedShape(in_shape, ker_shape);
311 ASSERT_EQ(infered_out_shape.rank(), 2);
312 ASSERT_EQ(infered_out_shape.dim(0), 36);
313 ASSERT_EQ(infered_out_shape.dim(1), 3);
316 TEST(ShapeInference, Transpose)
318 auto check = [&](Shape &in_shape, std::vector<int> perm, Shape &expected) {
320 ASSERT_EQ(in_shape.rank(), perm.size());
321 ASSERT_EQ(expected.rank(), perm.size());
322 auto inferred_out_shape = onert::shape_inference::inferTransposeShape(in_shape, perm);
324 ASSERT_EQ(inferred_out_shape.rank(), perm.size());
325 for (int32_t dim = 0; dim < expected.rank(); dim++)
327 ASSERT_EQ(inferred_out_shape.dim(dim), expected.dim(dim));
332 Shape in_shape{2, 3};
333 std::vector<int> perm = {1, 0};
334 Shape expected{3, 2};
336 check(in_shape, perm, expected);
340 Shape in_shape{1, 2, 3};
341 std::vector<int> perm = {2, 0, 1};
342 Shape expected{3, 1, 2};
344 check(in_shape, perm, expected);
348 Shape in_shape{1, 2, 3, 4};
349 std::vector<int> perm = {1, 3, 0, 2};
350 Shape expected{2, 4, 1, 3};
352 check(in_shape, perm, expected);
356 TEST(ShapeInference, neg_Transpose)
358 Shape in_shape{1, 2, 3};
359 // Invalid parameter size
361 std::vector<int> perm = {2, 0, 1, 0};
363 ASSERT_THROW(onert::shape_inference::inferTransposeShape(in_shape, perm), std::runtime_error);
365 // Invalid parameter value
367 std::vector<int> perm = {2, 0, 3};
369 ASSERT_THROW(onert::shape_inference::inferTransposeShape(in_shape, perm), std::runtime_error);