/*
 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 #include "loco/IR/Nodes.h"
18 #include "loco/IR/CanonicalDialect.h"
20 #include <gtest/gtest.h>
22 TEST(PushTest, constructor)
26 ASSERT_EQ(loco::CanonicalDialect::get(), push_node.dialect());
27 ASSERT_EQ(loco::CanonicalOpcode::Push, push_node.opcode());
29 ASSERT_FALSE(push_node.indexed());
34 const std::vector<uint32_t> dims{1, 8, 16, 3};
38 push_node.shape({dims[0], dims[1], dims[2], dims[3]});
40 ASSERT_EQ(dims.size(), push_node.rank());
41 ASSERT_EQ(dims[0], push_node.dim(0));
42 ASSERT_EQ(dims[1], push_node.dim(1));
43 ASSERT_EQ(dims[2], push_node.dim(2));
44 ASSERT_EQ(dims[3], push_node.dim(3));
47 TEST(PullTest, constructor)
51 ASSERT_EQ(loco::CanonicalDialect::get(), pull_node.dialect());
52 ASSERT_EQ(loco::CanonicalOpcode::Pull, pull_node.opcode());
54 ASSERT_FALSE(pull_node.indexed());
56 ASSERT_EQ(loco::DataType::Unknown, pull_node.dtype());
57 ASSERT_EQ(0, pull_node.rank());
62 const std::vector<uint32_t> dims{1, 8, 16, 3};
66 pull_node.shape({dims[0], dims[1], dims[2], dims[3]});
68 ASSERT_EQ(dims.size(), pull_node.rank());
69 ASSERT_EQ(dims[0], pull_node.dim(0));
70 ASSERT_EQ(dims[1], pull_node.dim(1));
71 ASSERT_EQ(dims[2], pull_node.dim(2));
72 ASSERT_EQ(dims[3], pull_node.dim(3));
75 TEST(ForwardTest, constructor)
77 loco::Forward forward_node;
79 ASSERT_EQ(loco::CanonicalDialect::get(), forward_node.dialect());
80 ASSERT_EQ(loco::CanonicalOpcode::Forward, forward_node.opcode());
82 ASSERT_EQ(nullptr, forward_node.input());
85 TEST(ReLUTest, constructor)
89 ASSERT_EQ(loco::CanonicalDialect::get(), relu_node.dialect());
90 ASSERT_EQ(loco::CanonicalOpcode::ReLU, relu_node.opcode());
92 ASSERT_EQ(nullptr, relu_node.input());
95 TEST(ReLU6Test, constructor)
97 loco::ReLU6 relu6_node;
99 ASSERT_EQ(loco::CanonicalDialect::get(), relu6_node.dialect());
100 ASSERT_EQ(loco::CanonicalOpcode::ReLU6, relu6_node.opcode());
102 ASSERT_EQ(nullptr, relu6_node.input());
105 TEST(ConstGenTest, constructor)
107 loco::ConstGen constgen_node;
109 ASSERT_EQ(loco::CanonicalDialect::get(), constgen_node.dialect());
110 ASSERT_EQ(loco::CanonicalOpcode::ConstGen, constgen_node.opcode());
112 ASSERT_EQ(loco::DataType::Unknown, constgen_node.dtype());
113 ASSERT_EQ(0, constgen_node.rank());
115 constgen_node.dtype(loco::DataType::FLOAT32);
116 ASSERT_EQ(loco::DataType::FLOAT32, constgen_node.dtype());
118 constgen_node.rank(2);
119 ASSERT_EQ(2, constgen_node.rank());
121 constgen_node.dim(0) = 2;
122 constgen_node.dim(1) = 3;
124 ASSERT_TRUE(constgen_node.dim(0).known());
125 ASSERT_TRUE(constgen_node.dim(1).known());
127 ASSERT_EQ(2, constgen_node.dim(0));
128 ASSERT_EQ(3, constgen_node.dim(1));
130 constgen_node.size<loco::DataType::FLOAT32>(6);
132 ASSERT_EQ(6, constgen_node.size<loco::DataType::FLOAT32>());
134 constgen_node.at<loco::DataType::FLOAT32>(0) = 0.0f; // Set 0,0
135 constgen_node.at<loco::DataType::FLOAT32>(1) = 1.0f; // Set 0,1
136 constgen_node.at<loco::DataType::FLOAT32>(2) = 2.0f; // Set 0,2
137 constgen_node.at<loco::DataType::FLOAT32>(3) = 3.0f; // Set 1,0
138 constgen_node.at<loco::DataType::FLOAT32>(4) = 4.0f; // Set 1,1
139 constgen_node.at<loco::DataType::FLOAT32>(5) = 5.0f; // Set 1,2
141 ASSERT_EQ(0.0f, constgen_node.at<loco::DataType::FLOAT32>(0));
142 ASSERT_EQ(1.0f, constgen_node.at<loco::DataType::FLOAT32>(1));
143 ASSERT_EQ(2.0f, constgen_node.at<loco::DataType::FLOAT32>(2));
144 ASSERT_EQ(3.0f, constgen_node.at<loco::DataType::FLOAT32>(3));
145 ASSERT_EQ(4.0f, constgen_node.at<loco::DataType::FLOAT32>(4));
146 ASSERT_EQ(5.0f, constgen_node.at<loco::DataType::FLOAT32>(5));
149 TEST(ConstGenTest, constructor_s32)
151 loco::ConstGen constgen_node;
153 ASSERT_EQ(loco::DataType::Unknown, constgen_node.dtype());
154 ASSERT_EQ(0, constgen_node.rank());
156 constgen_node.dtype(loco::DataType::S32);
157 ASSERT_EQ(loco::DataType::S32, constgen_node.dtype());
159 constgen_node.rank(2);
160 ASSERT_EQ(2, constgen_node.rank());
162 constgen_node.dim(0) = 2;
163 constgen_node.dim(1) = 3;
165 ASSERT_TRUE(constgen_node.dim(0).known());
166 ASSERT_TRUE(constgen_node.dim(1).known());
168 ASSERT_EQ(2, constgen_node.dim(0));
169 ASSERT_EQ(3, constgen_node.dim(1));
171 constgen_node.size<loco::DataType::S32>(6);
173 ASSERT_EQ(6, constgen_node.size<loco::DataType::S32>());
175 constgen_node.at<loco::DataType::S32>(0) = 0; // Set 0,0
176 constgen_node.at<loco::DataType::S32>(1) = 1; // Set 0,1
177 constgen_node.at<loco::DataType::S32>(2) = 2; // Set 0,2
178 constgen_node.at<loco::DataType::S32>(3) = -3; // Set 1,0
179 constgen_node.at<loco::DataType::S32>(4) = -4; // Set 1,1
180 constgen_node.at<loco::DataType::S32>(5) = -5; // Set 1,2
182 ASSERT_EQ(0, constgen_node.at<loco::DataType::S32>(0));
183 ASSERT_EQ(1, constgen_node.at<loco::DataType::S32>(1));
184 ASSERT_EQ(2, constgen_node.at<loco::DataType::S32>(2));
185 ASSERT_EQ(-3, constgen_node.at<loco::DataType::S32>(3));
186 ASSERT_EQ(-4, constgen_node.at<loco::DataType::S32>(4));
187 ASSERT_EQ(-5, constgen_node.at<loco::DataType::S32>(5));
190 TEST(MaxPool2DTest, constructor)
192 loco::MaxPool2D maxpool_node;
194 ASSERT_EQ(loco::CanonicalDialect::get(), maxpool_node.dialect());
195 ASSERT_EQ(loco::CanonicalOpcode::MaxPool2D, maxpool_node.opcode());
197 ASSERT_EQ(nullptr, maxpool_node.ifm());
199 ASSERT_EQ(0, maxpool_node.pad()->top());
200 ASSERT_EQ(0, maxpool_node.pad()->bottom());
201 ASSERT_EQ(0, maxpool_node.pad()->left());
202 ASSERT_EQ(0, maxpool_node.pad()->right());
204 ASSERT_EQ(1, maxpool_node.window()->vertical());
205 ASSERT_EQ(1, maxpool_node.window()->horizontal());
207 ASSERT_EQ(1, maxpool_node.stride()->vertical());
208 ASSERT_EQ(1, maxpool_node.stride()->horizontal());
211 TEST(MaxPool2DTest, pad)
213 const uint32_t t = 1;
214 const uint32_t b = 2;
215 const uint32_t l = 3;
216 const uint32_t r = 4;
218 loco::MaxPool2D maxpool_node;
220 maxpool_node.pad()->top(t);
221 ASSERT_EQ(t, maxpool_node.pad()->top());
223 maxpool_node.pad()->bottom(b);
224 ASSERT_EQ(b, maxpool_node.pad()->bottom());
226 maxpool_node.pad()->left(l);
227 ASSERT_EQ(l, maxpool_node.pad()->left());
229 maxpool_node.pad()->right(r);
230 ASSERT_EQ(r, maxpool_node.pad()->right());
233 TEST(AvgPool2DTest, constructor)
235 loco::AvgPool2D avgpool_node;
237 ASSERT_EQ(loco::CanonicalDialect::get(), avgpool_node.dialect());
238 ASSERT_EQ(loco::CanonicalOpcode::AvgPool2D, avgpool_node.opcode());
240 ASSERT_EQ(nullptr, avgpool_node.ifm());
242 ASSERT_EQ(loco::AvgPool2D::Convention::Unknown, avgpool_node.convention());
244 ASSERT_EQ(0, avgpool_node.pad()->top());
245 ASSERT_EQ(0, avgpool_node.pad()->bottom());
246 ASSERT_EQ(0, avgpool_node.pad()->left());
247 ASSERT_EQ(0, avgpool_node.pad()->right());
249 ASSERT_EQ(1, avgpool_node.window()->vertical());
250 ASSERT_EQ(1, avgpool_node.window()->horizontal());
252 ASSERT_EQ(1, avgpool_node.stride()->vertical());
253 ASSERT_EQ(1, avgpool_node.stride()->horizontal());
256 TEST(FeatureEncodeTest, constructor)
258 loco::FeatureEncode feature_encode;
260 ASSERT_EQ(loco::CanonicalDialect::get(), feature_encode.dialect());
261 ASSERT_EQ(loco::CanonicalOpcode::FeatureEncode, feature_encode.opcode());
263 ASSERT_EQ(nullptr, feature_encode.input());
264 ASSERT_EQ(nullptr, feature_encode.encoder());
267 TEST(FeatureDecodeTest, constructor)
269 loco::FeatureDecode feature_decode;
271 ASSERT_EQ(loco::CanonicalDialect::get(), feature_decode.dialect());
272 ASSERT_EQ(loco::CanonicalOpcode::FeatureDecode, feature_decode.opcode());
274 ASSERT_EQ(nullptr, feature_decode.input());
275 ASSERT_EQ(nullptr, feature_decode.decoder());
278 TEST(Reshape_Fixed_Test, constructor)
280 loco::Reshape<loco::ReshapeType::Fixed> reshape;
282 ASSERT_EQ(loco::CanonicalDialect::get(), reshape.dialect());
283 ASSERT_EQ(loco::CanonicalOpcode::FixedReshape, reshape.opcode());
285 ASSERT_EQ(0, reshape.rank());
288 TEST(Reshape_Fixed_Test, shape)
290 loco::Reshape<loco::ReshapeType::Fixed> reshape;
291 reshape.shape({2, 3});
293 ASSERT_EQ(2, reshape.rank());
294 ASSERT_EQ(2, reshape.dim(0));
295 ASSERT_EQ(3, reshape.dim(1));
298 TEST(FilterEncodeTest, constructor)
300 loco::FilterEncode filter_encode;
302 ASSERT_EQ(loco::CanonicalDialect::get(), filter_encode.dialect());
303 ASSERT_EQ(loco::CanonicalOpcode::FilterEncode, filter_encode.opcode());
305 ASSERT_EQ(nullptr, filter_encode.input());
306 ASSERT_EQ(nullptr, filter_encode.encoder());
309 TEST(FilterDecodeTest, constructor)
311 loco::FilterDecode filter_decode;
313 ASSERT_EQ(loco::CanonicalDialect::get(), filter_decode.dialect());
314 ASSERT_EQ(loco::CanonicalOpcode::FilterDecode, filter_decode.opcode());
316 ASSERT_EQ(nullptr, filter_decode.input());
317 ASSERT_EQ(nullptr, filter_decode.decoder());
320 TEST(DepthwiseFilterEncodeTest, constructor)
322 loco::DepthwiseFilterEncode dw_filter_encode;
324 ASSERT_EQ(loco::CanonicalDialect::get(), dw_filter_encode.dialect());
325 ASSERT_EQ(loco::CanonicalOpcode::DepthwiseFilterEncode, dw_filter_encode.opcode());
327 ASSERT_EQ(nullptr, dw_filter_encode.input());
328 ASSERT_EQ(nullptr, dw_filter_encode.encoder());
331 TEST(DepthwiseFilterDecodeTest, constructor)
333 loco::DepthwiseFilterDecode dw_filter_decode;
335 ASSERT_EQ(loco::CanonicalDialect::get(), dw_filter_decode.dialect());
336 ASSERT_EQ(loco::CanonicalOpcode::DepthwiseFilterDecode, dw_filter_decode.opcode());
338 ASSERT_EQ(nullptr, dw_filter_decode.input());
339 ASSERT_EQ(nullptr, dw_filter_decode.decoder());
342 TEST(TensorConcatTest, constructor)
344 loco::TensorConcat tensor_concat;
346 ASSERT_EQ(loco::CanonicalDialect::get(), tensor_concat.dialect());
347 ASSERT_EQ(loco::CanonicalOpcode::TensorConcat, tensor_concat.opcode());
349 ASSERT_EQ(nullptr, tensor_concat.lhs());
350 ASSERT_EQ(nullptr, tensor_concat.rhs());
351 ASSERT_EQ(0, tensor_concat.axis());
353 tensor_concat.axis(3);
354 ASSERT_EQ(3, tensor_concat.axis());
357 TEST(Conv2DTest, constructor)
361 ASSERT_EQ(loco::CanonicalDialect::get(), conv2d.dialect());
362 ASSERT_EQ(loco::CanonicalOpcode::Conv2D, conv2d.opcode());
364 ASSERT_EQ(nullptr, conv2d.ifm());
365 ASSERT_EQ(nullptr, conv2d.ker());
367 ASSERT_NE(conv2d.pad(), nullptr);
368 ASSERT_EQ(0, conv2d.pad()->top());
369 ASSERT_EQ(0, conv2d.pad()->bottom());
370 ASSERT_EQ(0, conv2d.pad()->left());
371 ASSERT_EQ(0, conv2d.pad()->right());
373 ASSERT_NE(conv2d.stride(), nullptr);
374 ASSERT_EQ(1, conv2d.stride()->vertical());
375 ASSERT_EQ(1, conv2d.stride()->horizontal());
378 TEST(DepthwiseConv2DTest, constructor)
380 loco::DepthwiseConv2D dw_conv2d;
382 ASSERT_EQ(loco::CanonicalDialect::get(), dw_conv2d.dialect());
383 ASSERT_EQ(loco::CanonicalOpcode::DepthwiseConv2D, dw_conv2d.opcode());
385 ASSERT_EQ(nullptr, dw_conv2d.ifm());
386 ASSERT_EQ(nullptr, dw_conv2d.ker());
388 ASSERT_NE(dw_conv2d.pad(), nullptr);
389 ASSERT_EQ(0, dw_conv2d.pad()->top());
390 ASSERT_EQ(0, dw_conv2d.pad()->bottom());
391 ASSERT_EQ(0, dw_conv2d.pad()->left());
392 ASSERT_EQ(0, dw_conv2d.pad()->right());
394 ASSERT_NE(dw_conv2d.stride(), nullptr);
395 ASSERT_EQ(1, dw_conv2d.stride()->vertical());
396 ASSERT_EQ(1, dw_conv2d.stride()->horizontal());
399 TEST(TransposedConv2DTest, constructor)
401 loco::TransposedConv2D tr_conv2d;
403 ASSERT_EQ(loco::CanonicalDialect::get(), tr_conv2d.dialect());
404 ASSERT_EQ(loco::CanonicalOpcode::TransposedConv2D, tr_conv2d.opcode());
406 ASSERT_EQ(nullptr, tr_conv2d.ifm());
407 ASSERT_EQ(nullptr, tr_conv2d.ker());
409 ASSERT_NE(tr_conv2d.pad(), nullptr);
410 ASSERT_EQ(0, tr_conv2d.pad()->top());
411 ASSERT_EQ(0, tr_conv2d.pad()->bottom());
412 ASSERT_EQ(0, tr_conv2d.pad()->left());
413 ASSERT_EQ(0, tr_conv2d.pad()->right());
415 ASSERT_NE(tr_conv2d.stride(), nullptr);
416 ASSERT_EQ(1, tr_conv2d.stride()->vertical());
417 ASSERT_EQ(1, tr_conv2d.stride()->horizontal());
420 TEST(BiasEncodeTest, constructor)
422 loco::BiasEncode bias_encode;
424 ASSERT_EQ(loco::CanonicalDialect::get(), bias_encode.dialect());
425 ASSERT_EQ(loco::CanonicalOpcode::BiasEncode, bias_encode.opcode());
427 ASSERT_EQ(nullptr, bias_encode.input());
430 TEST(TensorBiasAddTest, constructor)
432 loco::BiasAdd<loco::Domain::Tensor> bias_add;
434 ASSERT_EQ(loco::CanonicalDialect::get(), bias_add.dialect());
435 ASSERT_EQ(loco::CanonicalOpcode::TensorBiasAdd, bias_add.opcode());
437 ASSERT_EQ(nullptr, bias_add.value());
438 ASSERT_EQ(nullptr, bias_add.bias());
439 ASSERT_EQ(0, bias_add.axis());
442 TEST(TensorBiasAddTest, alias)
444 loco::TensorBiasAdd bias_add;
449 TEST(FeatureBiasAddTest, constructor)
451 loco::BiasAdd<loco::Domain::Feature> bias_add;
453 ASSERT_EQ(loco::CanonicalDialect::get(), bias_add.dialect());
454 ASSERT_EQ(loco::CanonicalOpcode::FeatureBiasAdd, bias_add.opcode());
456 ASSERT_EQ(nullptr, bias_add.value());
457 ASSERT_EQ(nullptr, bias_add.bias());
460 TEST(FeatureBiasAddTest, alias)
462 loco::FeatureBiasAdd bias_add;
467 TEST(EltwiseAddTest, constructor)
469 loco::EltwiseAdd eltwise_add;
474 TEST(EltwiseMaxTest, constructor)
476 loco::EltwiseMax eltwise_max;
481 TEST(EltwiseMulTest, constructor)
483 loco::EltwiseMul eltwise_mul;
488 TEST(EltwiseSubTest, constructor)
490 loco::EltwiseSub eltwise_sub;
495 TEST(EltwiseDivTest, constructor)
497 loco::EltwiseDiv eltwise_div;
502 TEST(EltwiseSqrtTest, constructor)
504 loco::EltwiseSqrt sqrt_node;
506 ASSERT_EQ(loco::CanonicalDialect::get(), sqrt_node.dialect());
507 ASSERT_EQ(loco::CanonicalOpcode::EltwiseSqrt, sqrt_node.opcode());
509 ASSERT_EQ(nullptr, sqrt_node.input());
512 TEST(TensorBroadcastTest, constructor)
514 loco::TensorBroadcast tensor_broadcast_node;
516 ASSERT_EQ(loco::CanonicalDialect::get(), tensor_broadcast_node.dialect());
517 ASSERT_EQ(loco::CanonicalOpcode::TensorBroadcast, tensor_broadcast_node.opcode());
519 ASSERT_EQ(nullptr, tensor_broadcast_node.input());
522 TEST(TensorBroadcastTest, mapping)
524 loco::TensorBroadcast tensor_broadcast_node;
526 ASSERT_EQ(false, tensor_broadcast_node.mapping()->defined(0));
528 tensor_broadcast_node.mapping()->dim(0) = 3;
530 ASSERT_EQ(true, tensor_broadcast_node.mapping()->defined(0));
531 ASSERT_EQ(3, tensor_broadcast_node.mapping()->dim(0));
534 TEST(MatrixEncodeTest, constructor)
536 loco::MatrixEncode matrix_encode;
538 ASSERT_EQ(loco::CanonicalDialect::get(), matrix_encode.dialect());
539 ASSERT_EQ(loco::CanonicalOpcode::MatrixEncode, matrix_encode.opcode());
541 ASSERT_EQ(nullptr, matrix_encode.input());
544 TEST(MatrixDecodeTest, constructor)
546 loco::MatrixDecode matrix_decode;
548 ASSERT_EQ(loco::CanonicalDialect::get(), matrix_decode.dialect());
549 ASSERT_EQ(loco::CanonicalOpcode::MatrixDecode, matrix_decode.opcode());
551 ASSERT_EQ(nullptr, matrix_decode.input());
554 TEST(MatMulTest, constructor)
556 loco::MatMul mat_mul;
558 ASSERT_EQ(loco::CanonicalDialect::get(), mat_mul.dialect());
559 ASSERT_EQ(loco::CanonicalOpcode::MatMul, mat_mul.opcode());
561 ASSERT_EQ(nullptr, mat_mul.lhs());
562 ASSERT_EQ(nullptr, mat_mul.rhs());
565 TEST(TransposeTest, constructor)
567 loco::TensorTranspose transpose;
569 ASSERT_EQ(loco::CanonicalDialect::get(), transpose.dialect());
570 ASSERT_EQ(loco::CanonicalOpcode::TensorTranspose, transpose.opcode());
572 ASSERT_EQ(nullptr, transpose.input());
573 ASSERT_EQ(0, transpose.perm()->size());
576 TEST(TransposeTest, perm)
578 loco::TensorTranspose transpose;
580 transpose.perm()->size(3);
581 transpose.perm()->axis(0) = 1;
582 transpose.perm()->axis(1) = 2;
583 transpose.perm()->axis(2) = 0;
585 ASSERT_EQ(1, transpose.perm()->axis(0));
586 ASSERT_EQ(2, transpose.perm()->axis(1));
587 ASSERT_EQ(0, transpose.perm()->axis(2));