2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include <logo/Phase.h>
19 #include <luci/test/TestIOGraph.h>
21 #include "luci/Pass/ConvertNCHWToNHWCPass.h"
22 #include "luci/Pass/CircleShapeInferencePass.h"
24 #include <luci/IR/CircleNodes.h>
26 #include <gtest/gtest.h>
28 using namespace luci::test;
// SimpleGraph: base fixture for the single-op test graphs in this file.
// Its setup code creates one FLOAT32 graph input and one graph output, both
// shaped {1, channel_size = 16, 4, 4} (NCHW), asks the subclass to build the
// op(s) under test via insertGraphBody(), and connects the result to output.
34 * Graph with a single Op (example: Add).
37 * - All Ops including Input/Output are NCHW.
46 * - All Ops including Input/Output are NHWC.
65 SimpleGraph() = default;
70 input = g.nodes()->create<luci::CircleInput>();
71 output = g.nodes()->create<luci::CircleOutput>();
73 output->name("output");
75 auto graph_input = g.inputs()->create();
76 input->index(graph_input->index());
77 auto graph_output = g.outputs()->create();
78 output->index(graph_output->index());
80 graph_input->dtype(loco::DataType::FLOAT32);
81 input->dtype(loco::DataType::FLOAT32);
82 output->dtype(loco::DataType::FLOAT32);
83 graph_output->dtype(loco::DataType::FLOAT32);
85 uint32_t channel_size = 16;
// Graph-level and node-level I/O shapes are kept identical: NCHW {1, 16, 4, 4}.
86 graph_input->shape({1, channel_size, 4, 4});
87 input->shape({1, channel_size, 4, 4});
88 output->shape({1, channel_size, 4, 4});
89 graph_output->shape({1, channel_size, 4, 4});
91 auto graph_body = insertGraphBody(input);
92 output->from(graph_body);
95 virtual ~SimpleGraph() = default;
// Subclass hook: build the op(s) under test consuming `input` and return the
// node that should feed the graph output.
98 virtual loco::Node *insertGraphBody(loco::Node *input) = 0;
102 luci::CircleInput *input = nullptr;
103 luci::CircleOutput *output = nullptr;
// AddGraph: Add of the graph input and a FLOAT32 const `beta`.
// By default beta is per-channel, shape {1, 16, 1, 1}, holding 0..15.
106 class AddGraph final : public SimpleGraph
109 loco::Node *insertGraphBody(loco::Node *input) override
111 add = g.nodes()->create<luci::CircleAdd>();
112 beta = g.nodes()->create<luci::CircleConst>();
114 add->dtype(loco::DataType::FLOAT32);
115 beta->dtype(loco::DataType::FLOAT32);
117 uint32_t channel_size = 16;
118 add->shape({1, channel_size, 4, 4});
119 beta->shape({1, channel_size, 1, 1});
121 beta->size<loco::DataType::FLOAT32>(channel_size);
122 for (uint32_t i = 0; i < channel_size; i++)
124 beta->at<loco::DataType::FLOAT32>(i) = i;
// Turns beta into a full NCHW tensor {1, 16, 4, 4}, so the pass has to
// convert a non-broadcast const. Every one of the 16 * 4 * 4 elements is
// written, so the buffer holds defined, distinct values.
137 void update_const_shape_to_nchw(void)
139 uint32_t channel_size = 16;
140 beta->shape({1, channel_size, 4, 4});
142 beta->size<loco::DataType::FLOAT32>(channel_size * 4 * 4);
143 for (uint32_t i = 0; i < channel_size * 4 * 4; i++)
145 beta->at<loco::DataType::FLOAT32>(i) = i;
150 luci::CircleAdd *add = nullptr;
151 luci::CircleConst *beta = nullptr;
// NHWCReluGraph: input {1, 16, 4, 4} -> Reshape to {1, 4, 4, 16} -> Relu ->
// Reshape back to {1, 16, 4, 4}. Relu already sees NHWC-shaped data, and the
// NHWC_Relu test expects the pass to leave its neighbours untouched.
154 class NHWCReluGraph final : public SimpleGraph
157 loco::Node *insertGraphBody(loco::Node *input) override
159 relu = g.nodes()->create<luci::CircleRelu>();
160 pre_reshape = g.nodes()->create<luci::CircleReshape>();
161 post_reshape = g.nodes()->create<luci::CircleReshape>();
162 pre_shape = g.nodes()->create<luci::CircleConst>();
163 post_shape = g.nodes()->create<luci::CircleConst>();
165 pre_shape->dtype(loco::DataType::S32);
166 post_shape->dtype(loco::DataType::S32);
168 uint32_t channel_size = 16;
169 auto in = loco::must_cast<luci::CircleNode *>(input);
170 in->shape({1, channel_size, 4, 4});
171 pre_shape->shape({4});
172 post_shape->shape({4});
// pre_shape = {1, 4, 4, C}: target shape of the Reshape feeding Relu.
174 pre_shape->size<loco::DataType::S32>(4);
175 pre_shape->at<loco::DataType::S32>(0) = 1;
176 pre_shape->at<loco::DataType::S32>(1) = 4;
177 pre_shape->at<loco::DataType::S32>(2) = 4;
178 pre_shape->at<loco::DataType::S32>(3) = channel_size;
// post_shape = {1, C, 4, 4}: restores the NCHW shape expected at the output.
180 post_shape->size<loco::DataType::S32>(4);
181 post_shape->at<loco::DataType::S32>(0) = 1;
182 post_shape->at<loco::DataType::S32>(1) = channel_size;
183 post_shape->at<loco::DataType::S32>(2) = 4;
184 post_shape->at<loco::DataType::S32>(3) = 4;
186 pre_reshape->tensor(input);
187 pre_reshape->shape(pre_shape);
189 relu->features(pre_reshape);
191 post_reshape->tensor(relu);
192 post_reshape->shape(post_shape);
195 pre_reshape->name("pre-reshape");
196 post_reshape->name("post-reshape");
202 luci::CircleRelu *relu = nullptr;
203 luci::CircleReshape *pre_reshape = nullptr;
204 luci::CircleReshape *post_reshape = nullptr;
205 luci::CircleConst *pre_shape = nullptr;
206 luci::CircleConst *post_shape = nullptr;
// NoPostReshapeGraph: input {1, 16, 4, 4} -> Reshape to {1, 4, 4, 16} -> Relu.
// Unlike NHWCReluGraph, nothing reshapes Relu's result back afterwards.
210 * Graph with pre-Reshape but no post-Transpose/Reshape.
234 class NoPostReshapeGraph final : public SimpleGraph
237 loco::Node *insertGraphBody(loco::Node *input) override
239 relu = g.nodes()->create<luci::CircleRelu>();
240 pre_reshape = g.nodes()->create<luci::CircleReshape>();
241 pre_shape = g.nodes()->create<luci::CircleConst>();
243 pre_shape->dtype(loco::DataType::S32);
245 uint32_t channel_size = 16;
246 auto in = loco::must_cast<luci::CircleNode *>(input);
247 in->shape({1, channel_size, 4, 4});
248 pre_shape->shape({4});
// pre_shape = {1, 4, 4, C}
250 pre_shape->size<loco::DataType::S32>(4);
251 pre_shape->at<loco::DataType::S32>(0) = 1;
252 pre_shape->at<loco::DataType::S32>(1) = 4;
253 pre_shape->at<loco::DataType::S32>(2) = 4;
254 pre_shape->at<loco::DataType::S32>(3) = channel_size;
256 pre_reshape->tensor(input);
257 pre_reshape->shape(pre_shape);
258 relu->features(pre_reshape);
261 pre_reshape->name("pre-reshape");
267 luci::CircleRelu *relu = nullptr;
268 luci::CircleReshape *pre_reshape = nullptr;
269 luci::CircleConst *pre_shape = nullptr;
// ReluNotClosedGraph: Input -> Reshape(pre_shape {1, 4, 4, C}) -> Relu
//   -> Reshape(pre_shape_2 {1, 4, C, 4}) -> Reshape(post_shape {1, 4, 4, C})
//   -> Output.
// Each Reshape gets its own name; pre_reshape_2 is named explicitly rather
// than overwriting pre_reshape's name.
273 * Graph with two pre-Reshapes
305 class ReluNotClosedGraph final : public SimpleGraph
308 loco::Node *insertGraphBody(loco::Node *input) override
310 relu = g.nodes()->create<luci::CircleRelu>();
311 pre_reshape = g.nodes()->create<luci::CircleReshape>();
312 pre_reshape_2 = g.nodes()->create<luci::CircleReshape>();
313 post_reshape = g.nodes()->create<luci::CircleReshape>();
314 pre_shape = g.nodes()->create<luci::CircleConst>();
315 pre_shape_2 = g.nodes()->create<luci::CircleConst>();
316 post_shape = g.nodes()->create<luci::CircleConst>();
318 pre_shape->dtype(loco::DataType::S32);
319 pre_shape_2->dtype(loco::DataType::S32);
320 post_shape->dtype(loco::DataType::S32);
322 uint32_t channel_size = 16;
323 auto in = loco::must_cast<luci::CircleNode *>(input);
324 in->shape({1, channel_size, 4, 4});
325 pre_shape->shape({4});
326 pre_shape_2->shape({4});
327 post_shape->shape({4});
329 pre_shape->size<loco::DataType::S32>(4);
330 pre_shape->at<loco::DataType::S32>(0) = 1;
331 pre_shape->at<loco::DataType::S32>(1) = 4;
332 pre_shape->at<loco::DataType::S32>(2) = 4;
333 pre_shape->at<loco::DataType::S32>(3) = channel_size;
335 pre_shape_2->size<loco::DataType::S32>(4);
336 pre_shape_2->at<loco::DataType::S32>(0) = 1;
337 pre_shape_2->at<loco::DataType::S32>(1) = 4;
338 pre_shape_2->at<loco::DataType::S32>(2) = channel_size;
339 pre_shape_2->at<loco::DataType::S32>(3) = 4;
341 post_shape->size<loco::DataType::S32>(4);
342 post_shape->at<loco::DataType::S32>(0) = 1;
343 post_shape->at<loco::DataType::S32>(1) = 4;
344 post_shape->at<loco::DataType::S32>(2) = 4;
345 post_shape->at<loco::DataType::S32>(3) = channel_size;
347 pre_reshape->tensor(input);
348 pre_reshape->shape(pre_shape);
350 relu->features(pre_reshape);
352 pre_reshape_2->tensor(relu);
353 pre_reshape_2->shape(pre_shape_2);
355 post_reshape->tensor(pre_reshape_2);
356 post_reshape->shape(post_shape);
359 pre_reshape->name("pre-reshape");
360 pre_reshape_2->name("pre-reshape-2");
361 post_reshape->name("post-reshape");
367 luci::CircleRelu *relu = nullptr;
368 luci::CircleReshape *pre_reshape = nullptr;
369 luci::CircleReshape *pre_reshape_2 = nullptr;
370 luci::CircleReshape *post_reshape = nullptr;
371 luci::CircleConst *pre_shape = nullptr;
372 luci::CircleConst *pre_shape_2 = nullptr;
373 luci::CircleConst *post_shape = nullptr;
// AddScalarGraph: Add whose `beta` operand is a one-element FLOAT32 const
// (3.14). The AddScalar test expects it to become a {1, 1, 1, 1} const.
376 class AddScalarGraph final : public SimpleGraph
379 loco::Node *insertGraphBody(loco::Node *input) override
381 add = g.nodes()->create<luci::CircleAdd>();
382 beta = g.nodes()->create<luci::CircleConst>();
384 add->dtype(loco::DataType::FLOAT32);
385 beta->dtype(loco::DataType::FLOAT32);
387 uint32_t channel_size = 16;
388 add->shape({1, channel_size, 4, 4});
391 beta->size<loco::DataType::FLOAT32>(1);
392 beta->at<loco::DataType::FLOAT32>(0) = 3.14;
404 luci::CircleAdd *add = nullptr;
405 luci::CircleConst *beta = nullptr;
// ConcatenationGraph: a 2-input Concatenation of the graph input and a FLOAT32
// const `input2` of shape {1, 16, 4, 4} holding 0..255. The Concatenation
// test expects an NHWC result of {1, 4, 4, 32} concatenated on axis 3.
408 class ConcatenationGraph final : public SimpleGraph
411 loco::Node *insertGraphBody(loco::Node *input) override
413 concat = g.nodes()->create<luci::CircleConcatenation>(2);
414 concat->values(0, input);
417 input2 = g.nodes()->create<luci::CircleConst>();
418 input2->dtype(loco::DataType::FLOAT32);
419 input2->shape({1, 16, 4, 4});
420 input2->size<loco::DataType::FLOAT32>(16 * 4 * 4);
421 for (uint32_t i = 0; i < 16 * 4 * 4; i++)
423 input2->at<loco::DataType::FLOAT32>(i) = i;
425 concat->values(1, input2);
427 concat->name("concat");
428 input2->name("input2");
434 luci::CircleConcatenation *concat = nullptr;
435 luci::CircleConst *input2 = nullptr;
// EluGraph: a single Elu applied directly to the graph input.
438 class EluGraph final : public SimpleGraph
441 loco::Node *insertGraphBody(loco::Node *input) override
443 elu = g.nodes()->create<luci::CircleElu>();
444 elu->features(input);
451 luci::CircleElu *elu = nullptr;
// LeakyReluGraph: a single LeakyRelu applied directly to the graph input.
454 class LeakyReluGraph final : public SimpleGraph
457 loco::Node *insertGraphBody(loco::Node *input) override
459 leakyrelu = g.nodes()->create<luci::CircleLeakyRelu>();
460 leakyrelu->features(input);
461 leakyrelu->name("leakyrelu");
467 luci::CircleLeakyRelu *leakyrelu = nullptr;
// LogisticGraph: a single Logistic op named "logistic".
470 class LogisticGraph final : public SimpleGraph
473 loco::Node *insertGraphBody(loco::Node *input) override
475 logistic = g.nodes()->create<luci::CircleLogistic>();
477 logistic->name("logistic");
483 luci::CircleLogistic *logistic = nullptr;
// MaximumGraph: Maximum whose const operand `limit` is a single FLOAT32
// element (100). Tests may change limit's shape to exercise rejection paths.
486 class MaximumGraph final : public SimpleGraph
489 loco::Node *insertGraphBody(loco::Node *input) override
491 max = g.nodes()->create<luci::CircleMaximum>();
492 limit = g.nodes()->create<luci::CircleConst>();
494 max->dtype(loco::DataType::FLOAT32);
495 limit->dtype(loco::DataType::FLOAT32);
497 max->shape({1, 16, 4, 4});
500 limit->size<loco::DataType::FLOAT32>(1);
501 limit->at<loco::DataType::FLOAT32>(0) = 100;
507 limit->name("limit");
513 luci::CircleMaximum *max = nullptr;
514 luci::CircleConst *limit = nullptr;
// MaximumNonConstGraph: Maximum built without any const operand. The
// MaximumNonConst test expects a pre-Transpose on both x() and y().
517 class MaximumNonConstGraph final : public SimpleGraph
520 loco::Node *insertGraphBody(loco::Node *input) override
522 max = g.nodes()->create<luci::CircleMaximum>();
523 max->dtype(loco::DataType::FLOAT32);
524 max->shape({1, 16, 4, 4});
535 luci::CircleMaximum *max = nullptr;
// MeanGraph: Mean over the S32 const `rindices` built from _axes
// (default {2, 3}, i.e. H and W in NCHW) with keep_dims = _keep_dims.
538 class MeanGraph final : public SimpleGraph
541 loco::Node *insertGraphBody(loco::Node *input) override
543 mean = g.nodes()->create<luci::CircleMean>();
544 rindices = g.nodes()->create<luci::CircleConst>();
546 mean->dtype(loco::DataType::FLOAT32);
547 rindices->dtype(loco::DataType::S32);
550 rindices->shape({static_cast<uint32_t>(_axes.size())});
552 rindices->size<loco::DataType::S32>(_axes.size());
553 for (uint32_t i = 0; i < _axes.size(); ++i)
555 rindices->at<loco::DataType::S32>(i) = _axes[i];
559 mean->reduction_indices(rindices);
560 mean->keep_dims(_keep_dims);
563 rindices->name("rindices");
// Test knobs. They must be set before the graph body is built, because
// insertGraphBody() reads _axes and _keep_dims.
569 void keep_dims(bool val) { _keep_dims = val; }
570 void axes(std::vector<int32_t> val) { _axes = val; }
571 void shape(std::initializer_list<uint32_t> val) { _shape = val; }
574 luci::CircleMean *mean = nullptr;
575 luci::CircleConst *rindices = nullptr;
577 bool _keep_dims = true;
579 std::vector<int32_t> _axes = {2, 3};
// NOTE(review): std::initializer_list does not own its elements. The array
// behind `val` in shape() (and behind this default) is a temporary, so
// _shape can dangle by the time it is read. Consider std::vector<uint32_t>
// instead; TODO confirm against how _shape is consumed.
580 std::initializer_list<uint32_t> _shape = {1, 16, 1, 1};
// MinimumGraph: Minimum whose const operand `limit` is a single FLOAT32
// element (100).
583 class MinimumGraph final : public SimpleGraph
586 loco::Node *insertGraphBody(loco::Node *input) override
588 min = g.nodes()->create<luci::CircleMinimum>();
589 limit = g.nodes()->create<luci::CircleConst>();
591 min->dtype(loco::DataType::FLOAT32);
592 limit->dtype(loco::DataType::FLOAT32);
594 min->shape({1, 16, 4, 4});
597 limit->size<loco::DataType::FLOAT32>(1);
598 limit->at<loco::DataType::FLOAT32>(0) = 100;
604 limit->name("limit");
610 luci::CircleMinimum *min = nullptr;
611 luci::CircleConst *limit = nullptr;
// MulGraph: Mul of the graph input and a FLOAT32 const `multiplier`.
// By default multiplier is per-channel, shape {1, 16, 1, 1}, holding 0..15.
614 class MulGraph final : public SimpleGraph
617 loco::Node *insertGraphBody(loco::Node *input) override
619 mul = g.nodes()->create<luci::CircleMul>();
620 multiplier = g.nodes()->create<luci::CircleConst>();
622 mul->dtype(loco::DataType::FLOAT32);
623 multiplier->dtype(loco::DataType::FLOAT32);
625 uint32_t channel_size = 16;
626 mul->shape({1, channel_size, 4, 4});
627 multiplier->shape({1, channel_size, 1, 1});
629 multiplier->size<loco::DataType::FLOAT32>(channel_size);
630 for (uint32_t i = 0; i < channel_size; i++)
632 multiplier->at<loco::DataType::FLOAT32>(i) = i;
639 multiplier->name("multiplier");
// Turns multiplier into a full NCHW tensor {1, 16, 4, 4}. Every one of the
// 16 * 4 * 4 elements is written, so the buffer holds defined values.
645 void update_const_shape_to_nchw(void)
647 uint32_t channel_size = 16;
648 multiplier->shape({1, channel_size, 4, 4});
650 multiplier->size<loco::DataType::FLOAT32>(channel_size * 4 * 4);
651 for (uint32_t i = 0; i < channel_size * 4 * 4; i++)
653 multiplier->at<loco::DataType::FLOAT32>(i) = i;
658 luci::CircleMul *mul = nullptr;
659 luci::CircleConst *multiplier = nullptr;
// MulScalarGraph: Mul whose `multiplier` is a one-element FLOAT32 const of
// shape {1} holding 2.
662 class MulScalarGraph final : public SimpleGraph
665 loco::Node *insertGraphBody(loco::Node *input) override
667 mul = g.nodes()->create<luci::CircleMul>();
668 multiplier = g.nodes()->create<luci::CircleConst>();
670 mul->dtype(loco::DataType::FLOAT32);
671 multiplier->dtype(loco::DataType::FLOAT32);
673 uint32_t channel_size = 16;
674 mul->shape({1, channel_size, 4, 4});
675 multiplier->shape({1});
677 multiplier->size<loco::DataType::FLOAT32>(1);
678 multiplier->at<loco::DataType::FLOAT32>(0) = 2;
684 multiplier->name("multiplier");
690 luci::CircleMul *mul = nullptr;
691 luci::CircleConst *multiplier = nullptr;
// MulBothNormGraph: a FLOAT32 Mul shaped {1, 16, 4, 4}. No const operand is
// created here, so both operands are presumably ordinary (non-const) nodes.
// TODO confirm.
694 class MulBothNormGraph final : public SimpleGraph
697 loco::Node *insertGraphBody(loco::Node *input) override
699 mul = g.nodes()->create<luci::CircleMul>();
701 mul->dtype(loco::DataType::FLOAT32);
703 uint32_t channel_size = 16;
704 mul->shape({1, channel_size, 4, 4});
715 luci::CircleMul *mul = nullptr;
// NegGraph: a single Neg op.
718 class NegGraph final : public SimpleGraph
721 loco::Node *insertGraphBody(loco::Node *input) override
723 neg = g.nodes()->create<luci::CircleNeg>();
731 luci::CircleNeg *neg = nullptr;
// PadGraph: Pad with an S32 `paddings` const of shape {4, 2}, stored
// row-major (element [dim][i] at index dim * 2 + i).
// NOTE(review): pad->shape() is set to the unpadded {1, 16, 4, 4}. With the
// paddings below, the padded result would be {1, 16, 6, 8}. Presumably shape
// inference in run_phase corrects this; TODO confirm.
734 class PadGraph final : public SimpleGraph
737 loco::Node *insertGraphBody(loco::Node *input) override
739 pad = g.nodes()->create<luci::CirclePad>();
740 paddings = g.nodes()->create<luci::CircleConst>();
742 pad->dtype(loco::DataType::FLOAT32);
743 paddings->dtype(loco::DataType::S32);
745 uint32_t channel_size = 16;
746 pad->shape({1, channel_size, 4, 4});
747 paddings->shape({4, 2});
749 // paddings data (NCHW)
750 // [[0,0], [0,0], [1,1], [2,2]]
751 paddings->size<loco::DataType::S32>(8);
752 for (uint32_t dim = 0; dim < 4; dim++)
754 for (uint32_t i = 0; i < 2; i++)
763 paddings->at<loco::DataType::S32>(dim * 2 + i) = data;
768 pad->paddings(paddings);
771 paddings->name("paddings");
777 luci::CirclePad *pad = nullptr;
778 luci::CircleConst *paddings = nullptr;
// PadV2Graph: PadV2 with an S32 `paddings` const of shape {4, 2} (NCHW,
// [[0,0], [0,0], [1,1], [2,2]]) and a one-element FLOAT32 pad value
// `const_value` (-3.4). constant_values is wired to const_value. Previously
// it was wired to `paddings`, which left const_value unused and made the S32
// paddings tensor double as the FLOAT32 fill value.
781 class PadV2Graph final : public SimpleGraph
784 loco::Node *insertGraphBody(loco::Node *input) override
786 pad = g.nodes()->create<luci::CirclePadV2>();
787 paddings = g.nodes()->create<luci::CircleConst>();
788 const_value = g.nodes()->create<luci::CircleConst>();
790 pad->dtype(loco::DataType::FLOAT32);
791 paddings->dtype(loco::DataType::S32);
792 const_value->dtype(loco::DataType::FLOAT32);
794 uint32_t channel_size = 16;
795 pad->shape({1, channel_size, 4, 4});
796 paddings->shape({4, 2});
797 const_value->shape({1});
799 // paddings data (NCHW)
800 // [[0,0], [0,0], [1,1], [2,2]]
801 paddings->size<loco::DataType::S32>(8);
802 for (uint32_t dim = 0; dim < 4; dim++)
804 for (uint32_t i = 0; i < 2; i++)
813 paddings->at<loco::DataType::S32>(dim * 2 + i) = data;
817 const_value->size<loco::DataType::FLOAT32>(1);
818 const_value->at<loco::DataType::FLOAT32>(0) = -3.4;
821 pad->paddings(paddings);
822 pad->constant_values(const_value);
825 paddings->name("paddings");
826 const_value->name("constant_values");
832 luci::CirclePadV2 *pad = nullptr;
833 luci::CircleConst *paddings = nullptr;
834 luci::CircleConst *const_value = nullptr;
// ReduceMaxGraph: ReduceMax over the S32 const `rindices` built from _axes
// (default {2, 3}) with keep_dims = _keep_dims.
837 class ReduceMaxGraph final : public SimpleGraph
840 loco::Node *insertGraphBody(loco::Node *input) override
842 rm = g.nodes()->create<luci::CircleReduceMax>();
843 rindices = g.nodes()->create<luci::CircleConst>();
845 rm->dtype(loco::DataType::FLOAT32);
846 rindices->dtype(loco::DataType::S32);
849 rindices->shape({static_cast<uint32_t>(_axes.size())});
851 rindices->size<loco::DataType::S32>(_axes.size());
852 for (uint32_t i = 0; i < _axes.size(); ++i)
854 rindices->at<loco::DataType::S32>(i) = _axes[i];
858 rm->reduction_indices(rindices);
859 rm->keep_dims(_keep_dims);
861 rm->name("reduce_max");
862 rindices->name("rindices");
// Test knobs. They must be set before the graph body is built, because
// insertGraphBody() reads _axes and _keep_dims.
868 void keep_dims(bool val) { _keep_dims = val; }
869 void axes(std::vector<int32_t> val) { _axes = val; }
870 void shape(std::initializer_list<uint32_t> val) { _shape = val; }
873 luci::CircleReduceMax *rm = nullptr;
874 luci::CircleConst *rindices = nullptr;
877 bool _keep_dims = true;
878 std::vector<int32_t> _axes = {2, 3};
// NOTE(review): std::initializer_list does not own its elements, so _shape
// can dangle once the caller's braced list (or this default's temporary
// array) is gone. Consider std::vector<uint32_t>; TODO confirm usage.
879 std::initializer_list<uint32_t> _shape = {1, 16, 1, 1};
// ReduceMinGraph: ReduceMin over the S32 const `rindices` built from _axes
// (default {2, 3}) with keep_dims = _keep_dims. The node is named
// "reduce_min"; it was mislabelled "reduce_max" by copy-paste.
882 class ReduceMinGraph final : public SimpleGraph
885 loco::Node *insertGraphBody(loco::Node *input) override
887 rm = g.nodes()->create<luci::CircleReduceMin>();
888 rindices = g.nodes()->create<luci::CircleConst>();
890 rm->dtype(loco::DataType::FLOAT32);
891 rindices->dtype(loco::DataType::S32);
894 rindices->shape({static_cast<uint32_t>(_axes.size())});
896 rindices->size<loco::DataType::S32>(_axes.size());
897 for (uint32_t i = 0; i < _axes.size(); ++i)
899 rindices->at<loco::DataType::S32>(i) = _axes[i];
903 rm->reduction_indices(rindices);
904 rm->keep_dims(_keep_dims);
906 rm->name("reduce_min");
907 rindices->name("rindices");
// Test knobs. They must be set before the graph body is built, because
// insertGraphBody() reads _axes and _keep_dims.
913 void keep_dims(bool val) { _keep_dims = val; }
914 void axes(std::vector<int32_t> val) { _axes = val; }
915 void shape(std::initializer_list<uint32_t> val) { _shape = val; }
918 luci::CircleReduceMin *rm = nullptr;
919 luci::CircleConst *rindices = nullptr;
922 bool _keep_dims = true;
923 std::vector<int32_t> _axes = {2, 3};
// NOTE(review): std::initializer_list does not own its elements, so _shape
// can dangle once the caller's braced list (or this default's temporary
// array) is gone. Consider std::vector<uint32_t>; TODO confirm usage.
924 std::initializer_list<uint32_t> _shape = {1, 16, 1, 1};
// ReluGraph: a single Relu applied directly to the graph input.
927 class ReluGraph final : public SimpleGraph
930 loco::Node *insertGraphBody(loco::Node *input) override
932 relu = g.nodes()->create<luci::CircleRelu>();
933 relu->features(input);
940 luci::CircleRelu *relu = nullptr;
// Relu6Graph: a single Relu6 applied directly to the graph input.
943 class Relu6Graph final : public SimpleGraph
946 loco::Node *insertGraphBody(loco::Node *input) override
948 relu6 = g.nodes()->create<luci::CircleRelu6>();
949 relu6->features(input);
950 relu6->name("relu6");
956 luci::CircleRelu6 *relu6 = nullptr;
// RsqrtGraph: a single Rsqrt op named "rsqrt".
959 class RsqrtGraph final : public SimpleGraph
962 loco::Node *insertGraphBody(loco::Node *input) override
964 rsqrt = g.nodes()->create<luci::CircleRsqrt>();
966 rsqrt->name("rsqrt");
972 luci::CircleRsqrt *rsqrt = nullptr;
978 SplitVGraphlet() = default;
// Builds a SplitV of a {1, 2, 2, 192} FLOAT32 tensor on axis 3 into three
// pieces of 32, 32 and 128 channels, plus one SplitVOut node per piece.
// The SplitV input itself is connected by the caller.
981 void init(loco::Graph *g)
983 // CircleSplitV (a native SplitV node, not a CircleCustom)
984 _splitv = g->nodes()->create<luci::CircleSplitV>();
985 _splitv->shape({1, 2, 2, 192});
986 _splitv->dtype(loco::DataType::FLOAT32);
987 _splitv->name("splitv");
// size_splits = {32, 32, 128}; the pieces sum to 192, the size of axis 3.
990 auto size_splits = g->nodes()->create<luci::CircleConst>();
991 size_splits->dtype(loco::DataType::S32);
992 size_splits->shape({3});
993 size_splits->size<loco::DataType::S32>(3);
994 size_splits->at<loco::DataType::S32>(0) = 32;
995 size_splits->at<loco::DataType::S32>(1) = 32;
996 size_splits->at<loco::DataType::S32>(2) = 128;
// split_dim: single S32 value 3.
999 auto split_dim = g->nodes()->create<luci::CircleConst>();
1000 split_dim->dtype(loco::DataType::S32);
1002 split_dim->size<loco::DataType::S32>(1);
1003 split_dim->scalar<loco::DataType::S32>() = 3;
1005 _splitv->size_splits(size_splits);
1006 _splitv->split_dim(split_dim);
1007 _splitv->num_split(3);
1010 _splitv_out1 = g->nodes()->create<luci::CircleSplitVOut>();
1011 _splitv_out1->shape({1, 2, 2, 32});
1012 _splitv_out1->dtype(loco::DataType::FLOAT32);
1013 _splitv_out1->index(0);
1014 _splitv_out1->input(_splitv);
1015 _splitv_out1->name("splitv_out1");
1018 _splitv_out2 = g->nodes()->create<luci::CircleSplitVOut>();
1019 _splitv_out2->shape({1, 2, 2, 32});
1020 _splitv_out2->dtype(loco::DataType::FLOAT32);
1021 _splitv_out2->index(1);
1022 _splitv_out2->input(_splitv);
1023 _splitv_out2->name("splitv_out2");
1026 _splitv_out3 = g->nodes()->create<luci::CircleSplitVOut>();
1027 _splitv_out3->shape({1, 2, 2, 128});
1028 _splitv_out3->dtype(loco::DataType::FLOAT32);
1029 _splitv_out3->index(2);
1030 _splitv_out3->input(_splitv);
1031 _splitv_out3->name("splitv_out3");
1035 luci::CircleSplitV *splitv() { return _splitv; }
1038 luci::CircleSplitV *_splitv = nullptr;
1039 luci::CircleSplitVOut *_splitv_out1 = nullptr;
1040 luci::CircleSplitVOut *_splitv_out2 = nullptr;
1041 luci::CircleSplitVOut *_splitv_out3 = nullptr;
// SplitVGraph: one {1, 2, 2, 192} input feeding the SplitVGraphlet, with its
// three SplitVOut nodes wired to graph outputs 0..2
// ({1, 2, 2, 32}, {1, 2, 2, 32}, {1, 2, 2, 128}).
1044 class SplitVGraph : public TestIGraphlet, public TestOsGraphlet<3>, public SplitVGraphlet
1047 SplitVGraph() = default;
1051 TestIGraphlet::init(g(), {1, 2, 2, 192});
1052 TestOsGraphlet<3>::init(g(), {{1, 2, 2, 32}, {1, 2, 2, 32}, {1, 2, 2, 128}});
1053 SplitVGraphlet::init(g());
1056 _splitv->input(input());
1058 output(0)->from(_splitv_out1);
1059 output(1)->from(_splitv_out2);
1060 output(2)->from(_splitv_out3);
// SquaredDifferenceGraph: a single SquaredDifference op named "sqdiff".
1064 class SquaredDifferenceGraph final : public SimpleGraph
1067 loco::Node *insertGraphBody(loco::Node *input) override
1069 sqdiff = g.nodes()->create<luci::CircleSquaredDifference>();
1072 sqdiff->name("sqdiff");
1078 luci::CircleSquaredDifference *sqdiff = nullptr;
// SubGraph: Sub of the graph input and a FLOAT32 const `beta`.
// By default beta is per-channel, shape {1, 16, 1, 1}, holding 0..15.
1081 class SubGraph final : public SimpleGraph
1084 loco::Node *insertGraphBody(loco::Node *input) override
1086 sub = g.nodes()->create<luci::CircleSub>();
1087 beta = g.nodes()->create<luci::CircleConst>();
1089 sub->dtype(loco::DataType::FLOAT32);
1090 beta->dtype(loco::DataType::FLOAT32);
1092 uint32_t channel_size = 16;
1093 sub->shape({1, channel_size, 4, 4});
1094 beta->shape({1, channel_size, 1, 1});
1096 beta->size<loco::DataType::FLOAT32>(channel_size);
1097 for (uint32_t i = 0; i < channel_size; i++)
1099 beta->at<loco::DataType::FLOAT32>(i) = i;
// Turns beta into a full NCHW tensor {1, 16, 4, 4}. Every one of the
// 16 * 4 * 4 elements is written, so the buffer holds defined values.
1112 void update_const_shape_to_nchw(void)
1114 uint32_t channel_size = 16;
1115 beta->shape({1, channel_size, 4, 4});
1117 beta->size<loco::DataType::FLOAT32>(channel_size * 4 * 4);
1118 for (uint32_t i = 0; i < channel_size * 4 * 4; i++)
1120 beta->at<loco::DataType::FLOAT32>(i) = i;
1125 luci::CircleSub *sub = nullptr;
1126 luci::CircleConst *beta = nullptr;
// SubScalarGraph: Sub whose `beta` operand is a one-element FLOAT32 const (5).
1129 class SubScalarGraph final : public SimpleGraph
1132 loco::Node *insertGraphBody(loco::Node *input) override
1134 sub = g.nodes()->create<luci::CircleSub>();
1135 beta = g.nodes()->create<luci::CircleConst>();
1137 sub->dtype(loco::DataType::FLOAT32);
1138 beta->dtype(loco::DataType::FLOAT32);
1140 uint32_t channel_size = 16;
1141 sub->shape({1, channel_size, 4, 4});
1144 beta->size<loco::DataType::FLOAT32>(1);
1145 beta->at<loco::DataType::FLOAT32>(0) = 5;
1157 luci::CircleSub *sub = nullptr;
1158 luci::CircleConst *beta = nullptr;
// Checks that `node` is a Transpose whose perm is the rank-1 S32 const
// {0, 2, 3, 1} (NCHW -> NHWC). The checks that guard later dereferences or
// element accesses use ASSERT_*, so a failure ends this helper with a clean
// test failure instead of dereferencing a null pointer.
1161 void check_pre_trans(loco::Node *node)
1163 auto pre_trans = dynamic_cast<luci::CircleTranspose *>(node);
1164 ASSERT_NE(nullptr, pre_trans);
1165 auto pre_trans_perm = dynamic_cast<luci::CircleConst *>(pre_trans->perm());
1166 ASSERT_NE(nullptr, pre_trans_perm);
1167 ASSERT_EQ(1, pre_trans_perm->rank());
1168 ASSERT_EQ(4, pre_trans_perm->dim(0).value());
1169 EXPECT_EQ(loco::DataType::S32, pre_trans_perm->dtype());
1170 EXPECT_EQ(0, pre_trans_perm->at<loco::DataType::S32>(0));
1171 EXPECT_EQ(2, pre_trans_perm->at<loco::DataType::S32>(1));
1172 EXPECT_EQ(3, pre_trans_perm->at<loco::DataType::S32>(2));
1173 EXPECT_EQ(1, pre_trans_perm->at<loco::DataType::S32>(3));
// Checks that `node` is a Transpose whose perm is the rank-1 S32 const
// {0, 3, 1, 2} (NHWC -> NCHW). The checks that guard later dereferences or
// element accesses use ASSERT_*, so a failure ends this helper with a clean
// test failure instead of dereferencing a null pointer.
1176 void check_post_trans(loco::Node *node)
1178 auto post_trans = dynamic_cast<luci::CircleTranspose *>(node);
1179 ASSERT_NE(nullptr, post_trans);
1180 auto post_trans_perm = dynamic_cast<luci::CircleConst *>(post_trans->perm());
1181 ASSERT_NE(nullptr, post_trans_perm);
1182 ASSERT_EQ(1, post_trans_perm->rank());
1183 ASSERT_EQ(4, post_trans_perm->dim(0).value());
1184 EXPECT_EQ(loco::DataType::S32, post_trans_perm->dtype());
1185 EXPECT_EQ(0, post_trans_perm->at<loco::DataType::S32>(0));
1186 EXPECT_EQ(3, post_trans_perm->at<loco::DataType::S32>(1));
1187 EXPECT_EQ(1, post_trans_perm->at<loco::DataType::S32>(2));
1188 EXPECT_EQ(2, post_trans_perm->at<loco::DataType::S32>(3));
// Runs a phase containing CircleShapeInferencePass and
// ConvertNCHWToNHWCPass(preserve_input, preserve_output) on `g`, using logo's
// Restart strategy. Presumably this reruns from the first pass whenever one
// reports a change, until none does; TODO confirm against logo docs.
1191 void run_phase(loco::Graph *g, bool preserve_input, bool preserve_output)
1196 phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
1200 std::make_unique<luci::ConvertNCHWToNHWCPass>(preserve_input, preserve_output));
1202 logo::PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{g};
1203 phase_runner.run(phase);
// The pass must report a non-null name.
1208 TEST(ConvertNCHWToNHWCPassTest, name)
1210 luci::ConvertNCHWToNHWCPass pass(false, false);
1211 auto const name = pass.name();
1212 ASSERT_NE(nullptr, name);
// Add with preserve_input = preserve_output = false. Expected result:
// - the input is followed by an NHWC -> NCHW Transpose;
// - Add's x() is an NCHW -> NHWC Transpose;
// - Add is followed by an NHWC -> NCHW Transpose;
// - the per-channel beta becomes {1, 1, 1, C}.
// Size and null checks that guard a dereference use ASSERT_* so a failure
// does not dereference an empty set's begin() or a null pointer.
1215 TEST(ConvertNCHWToNHWC, Add)
1220 run_phase(&g.g, false, false);
1222 auto input_succs = loco::succs(g.input);
1223 ASSERT_EQ(1, input_succs.size());
1224 check_post_trans(*input_succs.begin());
1226 check_pre_trans(g.add->x());
1228 auto add_succs = loco::succs(g.add);
1229 ASSERT_EQ(1, add_succs.size());
1230 check_post_trans(*add_succs.begin());
1232 uint32_t channel_size = 16;
1233 auto new_beta = dynamic_cast<luci::CircleConst *>(g.add->y());
1234 ASSERT_NE(nullptr, new_beta);
1235 ASSERT_EQ(4, new_beta->rank());
1236 EXPECT_EQ(1, new_beta->dim(0).value());
1237 EXPECT_EQ(1, new_beta->dim(1).value());
1238 EXPECT_EQ(1, new_beta->dim(2).value());
1239 EXPECT_EQ(channel_size, new_beta->dim(3).value());
1241 check_pre_trans(g.output->from());
// Add whose beta is a full NCHW const {1, C, 4, 4}. After conversion, beta
// must be NHWC {1, 4, 4, C}. Guarding checks use ASSERT_* to avoid
// dereferencing an empty succs set or a null const.
1244 TEST(ConvertNCHWToNHWC, Add_NCHW_const)
1248 g.update_const_shape_to_nchw();
1250 run_phase(&g.g, false, false);
1252 check_pre_trans(g.add->x());
1254 auto add_succs = loco::succs(g.add);
1255 ASSERT_EQ(1, add_succs.size());
1256 check_post_trans(*add_succs.begin());
1258 uint32_t channel_size = 16;
1259 auto new_beta = dynamic_cast<luci::CircleConst *>(g.add->y());
1260 ASSERT_NE(nullptr, new_beta);
1261 ASSERT_EQ(4, new_beta->rank());
1262 EXPECT_EQ(1, new_beta->dim(0).value());
1263 EXPECT_EQ(4, new_beta->dim(1).value());
1264 EXPECT_EQ(4, new_beta->dim(2).value());
1265 EXPECT_EQ(channel_size, new_beta->dim(3).value());
// Relu's producer (pre_reshape) and sole consumer (post_reshape) must be
// unchanged after the pass. The size check is an ASSERT because the next
// line dereferences begin().
1268 TEST(ConvertNCHWToNHWC, NHWC_Relu)
1270 // Relu is already NHWC, so it should not be converted
1271 // i.e., the graph is not changed
1275 run_phase(&g.g, false, false);
1277 EXPECT_EQ(g.pre_reshape, g.relu->features());
1279 auto relu_succs = loco::succs(g.relu);
1280 ASSERT_EQ(1, relu_succs.size());
1281 EXPECT_EQ(g.post_reshape, *relu_succs.begin());
// Add with a one-element const: Transposes are inserted around Add as in the
// Add test, and the scalar beta becomes a rank-4 {1, 1, 1, 1} const.
// Guarding checks use ASSERT_* to avoid invalid dereferences on failure.
1284 TEST(ConvertNCHWToNHWC, AddScalar)
1289 run_phase(&g.g, false, false);
1291 auto input_succs = loco::succs(g.input);
1292 ASSERT_EQ(1, input_succs.size());
1293 check_post_trans(*input_succs.begin());
1295 check_pre_trans(g.add->x());
1297 auto add_succs = loco::succs(g.add);
1298 ASSERT_EQ(1, add_succs.size());
1299 check_post_trans(*add_succs.begin());
1301 auto new_beta = dynamic_cast<luci::CircleConst *>(g.add->y());
1302 ASSERT_NE(nullptr, new_beta);
1303 ASSERT_EQ(4, new_beta->rank());
1304 EXPECT_EQ(1, new_beta->dim(0).value());
1305 EXPECT_EQ(1, new_beta->dim(1).value());
1306 EXPECT_EQ(1, new_beta->dim(2).value());
1307 EXPECT_EQ(1, new_beta->dim(3).value());
1309 check_pre_trans(g.output->from());
// Concatenation with preserved I/O: both inputs get an NCHW -> NHWC Transpose,
// the result is NHWC {1, 4, 4, 32}, and the axis is remapped to 3.
// The size check is an ASSERT because begin() is dereferenced next.
1312 TEST(ConvertNCHWToNHWC, Concatenation)
1314 ConcatenationGraph g;
1317 run_phase(&g.g, true, true);
1319 check_pre_trans(g.concat->values(0));
1320 check_pre_trans(g.concat->values(1));
1322 auto concat_succs = loco::succs(g.concat);
1323 ASSERT_EQ(1, concat_succs.size());
1324 check_post_trans(*concat_succs.begin());
1326 // Check concat shape, axis
1327 EXPECT_EQ(1, g.concat->dim(0).value());
1328 EXPECT_EQ(4, g.concat->dim(1).value());
1329 EXPECT_EQ(4, g.concat->dim(2).value());
1330 EXPECT_EQ(32, g.concat->dim(3).value());
1331 EXPECT_EQ(3, g.concat->axis());
// Elu with preserved I/O: wrapped by NCHW -> NHWC / NHWC -> NCHW Transposes
// and reshaped to NHWC {1, 4, 4, 16}. The size check is an ASSERT because
// begin() is dereferenced next.
1334 TEST(ConvertNCHWToNHWC, Elu)
1339 run_phase(&g.g, true, true);
1341 check_pre_trans(g.elu->features());
1343 auto elu_succs = loco::succs(g.elu);
1344 ASSERT_EQ(1, elu_succs.size());
1345 check_post_trans(*elu_succs.begin());
// Check elu shape
1348 EXPECT_EQ(1, g.elu->dim(0).value());
1349 EXPECT_EQ(4, g.elu->dim(1).value());
1350 EXPECT_EQ(4, g.elu->dim(2).value());
1351 EXPECT_EQ(16, g.elu->dim(3).value());
// LeakyRelu with preserved I/O: wrapped by NCHW -> NHWC / NHWC -> NCHW
// Transposes and reshaped to NHWC {1, 4, 4, 16}. The size check is an
// ASSERT because begin() is dereferenced next.
1354 TEST(ConvertNCHWToNHWC, LeakyRelu)
1359 run_phase(&g.g, true, true);
1361 check_pre_trans(g.leakyrelu->features());
1363 auto leakyrelu_succs = loco::succs(g.leakyrelu);
1364 ASSERT_EQ(1, leakyrelu_succs.size());
1365 check_post_trans(*leakyrelu_succs.begin());
1367 // Check leakyrelu shape
1368 EXPECT_EQ(1, g.leakyrelu->dim(0).value());
1369 EXPECT_EQ(4, g.leakyrelu->dim(1).value());
1370 EXPECT_EQ(4, g.leakyrelu->dim(2).value());
1371 EXPECT_EQ(16, g.leakyrelu->dim(3).value());
// Logistic with preserved I/O: wrapped by NCHW -> NHWC / NHWC -> NCHW
// Transposes and reshaped to NHWC {1, 4, 4, 16}. The size check is an
// ASSERT because begin() is dereferenced next.
1374 TEST(ConvertNCHWToNHWC, Logistic)
1379 run_phase(&g.g, true, true);
1381 check_pre_trans(g.logistic->x());
1383 auto logistic_succs = loco::succs(g.logistic);
1384 ASSERT_EQ(1, logistic_succs.size());
1385 check_post_trans(*logistic_succs.begin());
1387 // Check logistic shape
1388 EXPECT_EQ(1, g.logistic->dim(0).value());
1389 EXPECT_EQ(4, g.logistic->dim(1).value());
1390 EXPECT_EQ(4, g.logistic->dim(2).value());
1391 EXPECT_EQ(16, g.logistic->dim(3).value());
// Maximum with a scalar const and non-preserved I/O: Transposes are inserted
// after the input, before Maximum's x(), after Maximum, and before the
// output. Size checks are ASSERTs because begin() is dereferenced next.
1394 TEST(ConvertNCHWToNHWC, Maximum)
1399 run_phase(&g.g, false, false);
1401 auto input_succs = loco::succs(g.input);
1402 ASSERT_EQ(1, input_succs.size());
1403 check_post_trans(*input_succs.begin());
1405 check_pre_trans(g.max->x());
1407 auto max_succs = loco::succs(g.max);
1408 ASSERT_EQ(1, max_succs.size());
1409 check_post_trans(*max_succs.begin());
1411 check_pre_trans(g.output->from());
// Negative case: when Maximum's const operand is reshaped to {3} (not a
// single element), the pass is expected to report no change.
// NOTE(review): only the shape is changed here; MaximumGraph sized the
// limit buffer for one element.
1414 TEST(ConvertNCHWToNHWC, Maximum_non_scalar_NEG)
1419 g.limit->shape({3});
1421 luci::ConvertNCHWToNHWCPass pass(true, true);
1422 EXPECT_FALSE(pass.run(&g.g));
// Maximum with two non-const operands: each operand gets its own
// NCHW -> NHWC Transpose and Maximum is followed by an NHWC -> NCHW one.
// The size check is an ASSERT because begin() is dereferenced next.
1425 TEST(ConvertNCHWToNHWC, MaximumNonConst)
1427 MaximumNonConstGraph g;
1430 run_phase(&g.g, true, true);
1432 check_pre_trans(g.max->x());
1433 check_pre_trans(g.max->y());
1435 auto max_succs = loco::succs(g.max);
1436 ASSERT_EQ(1, max_succs.size());
1437 check_post_trans(*max_succs.begin());
// Mean over NCHW axes {2, 3} with keep_dims: after conversion the reduction
// indices must be the NHWC axes {1, 2}. Null, size and count checks are
// ASSERTs because the following lines dereference or index into them.
1440 TEST(ConvertNCHWToNHWC, Mean)
1445 run_phase(&g.g, false, false);
1447 check_pre_trans(g.mean->input());
1449 auto mean_succs = loco::succs(g.mean);
1450 ASSERT_EQ(1, mean_succs.size());
1451 check_post_trans(*mean_succs.begin());
1453 auto new_rindices = dynamic_cast<luci::CircleConst *>(g.mean->reduction_indices());
1454 ASSERT_NE(nullptr, new_rindices);
1455 EXPECT_EQ(1, new_rindices->rank());
1456 EXPECT_EQ(2, new_rindices->dim(0).value());
1457 ASSERT_EQ(2, new_rindices->size<loco::DataType::S32>());
1458 EXPECT_EQ(1, new_rindices->at<loco::DataType::S32>(0));
1459 EXPECT_EQ(2, new_rindices->at<loco::DataType::S32>(1));
// Table-driven Mean test with keep_dims = false. For each NCHW reduction-axis
// set it checks the remapped NHWC indices, and whether Mean's single consumer
// is a Transpose (needs_transpose) or the graph Output directly.
// Null, size and count checks are ASSERTs because the following lines
// dereference begin(), the const pointer, or index into the const.
1462 TEST(ConvertNCHWToNHWC, Mean_keep_dims_false)
1466 std::vector<int32_t> nchw_ind;
1467 std::vector<int32_t> nhwc_ind;
// NOTE(review): `shape` is a std::initializer_list, which does not own its
// elements. The arrays behind the braced {c, h, w}-style lists in test_cases
// are temporaries of that declaration, so tc.shape likely dangles inside the
// loop. Consider std::vector<uint32_t>; TODO confirm how tc.shape is passed on.
1468 std::initializer_list<uint32_t> shape;
1469 bool needs_transpose = false;
1477 std::vector<TC> test_cases{{{0}, {0}, {c, h, w}, true}, {{1}, {3}, {n, h, w}, false},
1478 {{2}, {1}, {n, c, w}, true}, {{3}, {2}, {n, c, h}, true},
1479 {{0, 1}, {0, 3}, {h, w}, false}, {{0, 2}, {0, 1}, {c, w}, true},
1480 {{0, 3}, {0, 2}, {c, h}, true}, {{1, 2}, {3, 1}, {n, w}, false},
1481 {{1, 3}, {3, 2}, {n, h}, false}, {{2, 3}, {1, 2}, {n, c}, false},
1482 {{0, 1, 2}, {0, 3, 1}, {w}, false}};
1484 for (auto &tc : test_cases)
1488 g.axes(tc.nchw_ind);
1492 run_phase(&g.g, false, true);
1494 check_pre_trans(g.mean->input());
1496 auto mean_succs = loco::succs(g.mean);
1497 ASSERT_EQ(1, mean_succs.size());
1498 if (tc.needs_transpose)
1500 EXPECT_NE(nullptr, dynamic_cast<luci::CircleTranspose *>(*mean_succs.begin()));
1504 EXPECT_NE(nullptr, dynamic_cast<luci::CircleOutput *>(*mean_succs.begin()));
1507 auto new_rindices = dynamic_cast<luci::CircleConst *>(g.mean->reduction_indices());
1508 ASSERT_NE(nullptr, new_rindices);
1509 EXPECT_EQ(1, new_rindices->rank());
1510 EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->dim(0).value());
1511 ASSERT_EQ(tc.nhwc_ind.size(), new_rindices->size<loco::DataType::S32>());
1512 for (uint32_t i = 0; i < tc.nhwc_ind.size(); ++i)
1514 EXPECT_EQ(tc.nhwc_ind[i], new_rindices->at<loco::DataType::S32>(i));
1519 TEST(ConvertNCHWToNHWC, ConvertNCHWToNHWC_Mean_keep_dims_false_NEG)
1522 auto input = g.nodes()->create<luci::CircleInput>();
1523 auto output = g.nodes()->create<luci::CircleOutput>();
1524 input->name("input");
1525 output->name("output");
1527 auto graph_input = g.inputs()->create();
1528 input->index(graph_input->index());
1529 auto graph_output = g.outputs()->create();
1530 output->index(graph_output->index());
1532 graph_input->dtype(loco::DataType::FLOAT32);
1533 input->dtype(loco::DataType::FLOAT32);
1534 output->dtype(loco::DataType::FLOAT32);
1535 graph_output->dtype(loco::DataType::FLOAT32);
1537 uint32_t channel_size = 16;
1538 graph_input->shape({channel_size, 4, 4});
1539 input->shape({channel_size, 4, 4});
1540 output->shape({channel_size});
1541 graph_output->shape({channel_size});
1543 auto mean = g.nodes()->create<luci::CircleMean>();
1544 auto rindices = g.nodes()->create<luci::CircleConst>();
1546 mean->dtype(loco::DataType::FLOAT32);
1547 rindices->dtype(loco::DataType::S32);
1549 mean->shape({channel_size});
1550 rindices->shape({2});
1552 rindices->size<loco::DataType::S32>(2);
1553 rindices->at<loco::DataType::S32>(0) = 1;
1554 rindices->at<loco::DataType::S32>(1) = 2;
1557 mean->reduction_indices(rindices);
1558 mean->keep_dims(false);
1561 rindices->name("rindices");
1565 run_phase(&g, true, true);
1567 auto new_rindices = dynamic_cast<luci::CircleConst *>(mean->reduction_indices());
1568 EXPECT_NE(nullptr, new_rindices);
1569 EXPECT_EQ(1, new_rindices->rank());
1570 EXPECT_EQ(2, new_rindices->dim(0).value());
1571 EXPECT_EQ(2, new_rindices->size<loco::DataType::S32>());
1572 EXPECT_EQ(1, new_rindices->at<loco::DataType::S32>(0));
1573 EXPECT_EQ(2, new_rindices->at<loco::DataType::S32>(1));
1576 TEST(ConvertNCHWToNHWC, Minimum)
1581 run_phase(&g.g, false, false);
1583 auto input_succs = loco::succs(g.input);
1584 EXPECT_EQ(1, input_succs.size());
1585 check_post_trans(*input_succs.begin());
1587 check_pre_trans(g.min->x());
1589 auto min_succs = loco::succs(g.min);
1590 EXPECT_EQ(1, min_succs.size());
1591 check_post_trans(*min_succs.begin());
1593 check_pre_trans(g.output->from());
1596 TEST(ConvertNCHWToNHWC, Minimum_non_scalar_NEG)
1601 g.limit->shape({3});
1603 luci::ConvertNCHWToNHWCPass pass(true, true);
1604 EXPECT_FALSE(pass.run(&g.g));
1607 TEST(ConvertNCHWToNHWC, Mul)
1612 run_phase(&g.g, false, false);
1614 auto input_succs = loco::succs(g.input);
1615 EXPECT_EQ(1, input_succs.size());
1616 check_post_trans(*input_succs.begin());
1618 check_pre_trans(g.mul->x());
1620 auto mul_succs = loco::succs(g.mul);
1621 EXPECT_EQ(1, mul_succs.size());
1622 check_post_trans(*mul_succs.begin());
1624 uint32_t channel_size = 16;
1625 auto new_multiplier = dynamic_cast<luci::CircleConst *>(g.mul->y());
1626 EXPECT_NE(nullptr, new_multiplier);
1627 EXPECT_EQ(4, new_multiplier->rank());
1628 EXPECT_EQ(1, new_multiplier->dim(0).value());
1629 EXPECT_EQ(1, new_multiplier->dim(1).value());
1630 EXPECT_EQ(1, new_multiplier->dim(2).value());
1631 EXPECT_EQ(channel_size, new_multiplier->dim(3).value());
1633 check_pre_trans(g.output->from());
1636 TEST(ConvertNCHWToNHWC, Mul_NCHW_const)
1640 g.update_const_shape_to_nchw();
1642 run_phase(&g.g, false, false);
1644 check_pre_trans(g.mul->x());
1646 auto mul_succs = loco::succs(g.mul);
1647 EXPECT_EQ(1, mul_succs.size());
1648 check_post_trans(*mul_succs.begin());
1650 uint32_t channel_size = 16;
1651 auto new_multiplier = dynamic_cast<luci::CircleConst *>(g.mul->y());
1652 EXPECT_NE(nullptr, new_multiplier);
1653 EXPECT_EQ(4, new_multiplier->rank());
1654 EXPECT_EQ(1, new_multiplier->dim(0).value());
1655 EXPECT_EQ(4, new_multiplier->dim(1).value());
1656 EXPECT_EQ(4, new_multiplier->dim(2).value());
1657 EXPECT_EQ(channel_size, new_multiplier->dim(3).value());
1660 TEST(ConvertNCHWToNHWC, MulScalar)
1665 run_phase(&g.g, false, false);
1667 auto input_succs = loco::succs(g.input);
1668 EXPECT_EQ(1, input_succs.size());
1669 check_post_trans(*input_succs.begin());
1671 check_pre_trans(g.mul->x());
1673 auto mul_succs = loco::succs(g.mul);
1674 EXPECT_EQ(1, mul_succs.size());
1675 check_post_trans(*mul_succs.begin());
1677 auto new_multiplier = dynamic_cast<luci::CircleConst *>(g.mul->y());
1678 EXPECT_NE(nullptr, new_multiplier);
1679 EXPECT_EQ(4, new_multiplier->rank());
1680 EXPECT_EQ(1, new_multiplier->dim(0).value());
1681 EXPECT_EQ(1, new_multiplier->dim(1).value());
1682 EXPECT_EQ(1, new_multiplier->dim(2).value());
1683 EXPECT_EQ(1, new_multiplier->dim(3).value());
1685 check_pre_trans(g.output->from());
1688 TEST(ConvertNCHWToNHWC, MulBothNorm)
1693 run_phase(&g.g, false, false);
1695 auto input_succs = loco::succs(g.input);
1696 EXPECT_EQ(1, input_succs.size());
1697 check_post_trans(*input_succs.begin());
1699 check_pre_trans(g.mul->x());
1700 check_pre_trans(g.mul->y());
1702 auto mul_succs = loco::succs(g.mul);
1703 EXPECT_EQ(1, mul_succs.size());
1704 check_post_trans(*mul_succs.begin());
1706 check_pre_trans(g.output->from());
1709 TEST(ConvertNCHWToNHWC, Neg)
1714 run_phase(&g.g, true, true);
1716 check_pre_trans(g.neg->x());
1718 auto neg_succs = loco::succs(g.neg);
1719 EXPECT_EQ(1, neg_succs.size());
1720 check_post_trans(*neg_succs.begin());
1722 // Check leakyrelu shape
1723 EXPECT_EQ(1, g.neg->dim(0).value());
1724 EXPECT_EQ(4, g.neg->dim(1).value());
1725 EXPECT_EQ(4, g.neg->dim(2).value());
1726 EXPECT_EQ(16, g.neg->dim(3).value());
1729 TEST(ConvertNCHWToNHWC, Pad)
1734 run_phase(&g.g, false, false);
1736 auto input_succs = loco::succs(g.input);
1737 EXPECT_EQ(1, input_succs.size());
1738 check_post_trans(*input_succs.begin());
1740 check_pre_trans(g.pad->input());
1742 auto pad_succs = loco::succs(g.pad);
1743 EXPECT_EQ(1, pad_succs.size());
1744 check_post_trans(*pad_succs.begin());
1746 auto new_paddings = dynamic_cast<luci::CircleConst *>(g.pad->paddings());
1747 EXPECT_NE(nullptr, new_paddings);
1748 EXPECT_EQ(2, new_paddings->rank());
1749 EXPECT_EQ(4, new_paddings->dim(0).value());
1750 EXPECT_EQ(2, new_paddings->dim(1).value());
1751 EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(0));
1752 EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(1));
1753 EXPECT_EQ(1, new_paddings->at<loco::DataType::S32>(2));
1754 EXPECT_EQ(1, new_paddings->at<loco::DataType::S32>(3));
1755 EXPECT_EQ(2, new_paddings->at<loco::DataType::S32>(4));
1756 EXPECT_EQ(2, new_paddings->at<loco::DataType::S32>(5));
1757 EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(6));
1758 EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(7));
1760 check_pre_trans(g.output->from());
1763 TEST(ConvertNCHWToNHWC, PadV2)
1768 run_phase(&g.g, false, false);
1770 check_pre_trans(g.pad->input());
1772 auto pad_succs = loco::succs(g.pad);
1773 EXPECT_EQ(1, pad_succs.size());
1774 check_post_trans(*pad_succs.begin());
1776 auto new_paddings = dynamic_cast<luci::CircleConst *>(g.pad->paddings());
1777 EXPECT_NE(nullptr, new_paddings);
1778 EXPECT_EQ(2, new_paddings->rank());
1779 EXPECT_EQ(4, new_paddings->dim(0).value());
1780 EXPECT_EQ(2, new_paddings->dim(1).value());
1781 EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(0));
1782 EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(1));
1783 EXPECT_EQ(1, new_paddings->at<loco::DataType::S32>(2));
1784 EXPECT_EQ(1, new_paddings->at<loco::DataType::S32>(3));
1785 EXPECT_EQ(2, new_paddings->at<loco::DataType::S32>(4));
1786 EXPECT_EQ(2, new_paddings->at<loco::DataType::S32>(5));
1787 EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(6));
1788 EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(7));
1791 TEST(ConvertNCHWToNHWC, Unknown_Shape_NEG)
1797 g.input->dim(0).unset();
1798 g.add->dim(0).unset();
1799 g.output->dim(0).unset();
1801 luci::ConvertNCHWToNHWCPass pass(false, false);
1802 EXPECT_EQ(false, pass.run(&g.g));
// Exercises the two run_phase flags on three separate graphs.
// As the checks below show, the first flag keeps the graph input in its
// original NCHW shape {1, 16, 4, 4} and the second keeps the graph output NCHW.
// A side that is not preserved is converted to NHWC {1, 4, 4, 16}.
TEST(ConvertNCHWToNHWC, Preserve_Input_Output)
// Preserve input only: input stays NCHW, output becomes NHWC.
run_phase(&g.g, true, false);
// Check input shape
EXPECT_EQ(1, g.input->dim(0).value());
EXPECT_EQ(16, g.input->dim(1).value());
EXPECT_EQ(4, g.input->dim(2).value());
EXPECT_EQ(4, g.input->dim(3).value());
// Check output shape
EXPECT_EQ(1, g.output->dim(0).value());
EXPECT_EQ(4, g.output->dim(1).value());
EXPECT_EQ(4, g.output->dim(2).value());
EXPECT_EQ(16, g.output->dim(3).value());
// Preserve output only: input becomes NHWC, output stays NCHW.
run_phase(&g.g, false, true);
// Check input shape
EXPECT_EQ(1, g.input->dim(0).value());
EXPECT_EQ(4, g.input->dim(1).value());
EXPECT_EQ(4, g.input->dim(2).value());
EXPECT_EQ(16, g.input->dim(3).value());
// Check output shape
EXPECT_EQ(1, g.output->dim(0).value());
EXPECT_EQ(16, g.output->dim(1).value());
EXPECT_EQ(4, g.output->dim(2).value());
EXPECT_EQ(4, g.output->dim(3).value());
// Preserve both input and output
run_phase(&g.g, true, true);
// Check input shape
EXPECT_EQ(1, g.input->dim(0).value());
EXPECT_EQ(16, g.input->dim(1).value());
EXPECT_EQ(4, g.input->dim(2).value());
EXPECT_EQ(4, g.input->dim(3).value());
// Check output shape
EXPECT_EQ(1, g.output->dim(0).value());
EXPECT_EQ(16, g.output->dim(1).value());
EXPECT_EQ(4, g.output->dim(2).value());
EXPECT_EQ(4, g.output->dim(3).value());
1868 TEST(ConvertNCHWToNHWC, ReduceMax)
1873 run_phase(&g.g, false, false);
1875 check_pre_trans(g.rm->input());
1877 auto rm_succs = loco::succs(g.rm);
1878 EXPECT_EQ(1, rm_succs.size());
1879 check_post_trans(*rm_succs.begin());
1881 auto new_rindices = dynamic_cast<luci::CircleConst *>(g.rm->reduction_indices());
1882 EXPECT_NE(nullptr, new_rindices);
1883 EXPECT_EQ(1, new_rindices->rank());
1884 EXPECT_EQ(2, new_rindices->dim(0).value());
1885 EXPECT_EQ(2, new_rindices->size<loco::DataType::S32>());
1886 EXPECT_EQ(1, new_rindices->at<loco::DataType::S32>(0));
1887 EXPECT_EQ(2, new_rindices->at<loco::DataType::S32>(1));
1890 TEST(ConvertNCHWToNHWC, ReduceMax_keep_dims_false)
1894 std::vector<int32_t> nchw_ind;
1895 std::vector<int32_t> nhwc_ind;
1896 std::initializer_list<uint32_t> shape;
1897 bool needs_transpose = false;
1905 std::vector<TC> test_cases{{{0}, {0}, {c, h, w}, true}, {{1}, {3}, {n, h, w}, false},
1906 {{2}, {1}, {n, c, w}, true}, {{3}, {2}, {n, c, h}, true},
1907 {{0, 1}, {0, 3}, {h, w}, false}, {{0, 2}, {0, 1}, {c, w}, true},
1908 {{0, 3}, {0, 2}, {c, h}, true}, {{1, 2}, {3, 1}, {n, w}, false},
1909 {{1, 3}, {3, 2}, {n, h}, false}, {{2, 3}, {1, 2}, {n, c}, false},
1910 {{0, 1, 2}, {0, 3, 1}, {w}, false}};
1912 for (auto &tc : test_cases)
1916 g.axes(tc.nchw_ind);
1920 run_phase(&g.g, true, true);
1922 check_pre_trans(g.rm->input());
1924 auto rm_succs = loco::succs(g.rm);
1925 EXPECT_EQ(1, rm_succs.size());
1926 if (tc.needs_transpose)
1928 EXPECT_NE(nullptr, dynamic_cast<luci::CircleTranspose *>(*rm_succs.begin()));
1932 EXPECT_NE(nullptr, dynamic_cast<luci::CircleOutput *>(*rm_succs.begin()));
1935 auto new_rindices = dynamic_cast<luci::CircleConst *>(g.rm->reduction_indices());
1936 EXPECT_NE(nullptr, new_rindices);
1937 EXPECT_EQ(1, new_rindices->rank());
1938 EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->dim(0).value());
1939 EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->size<loco::DataType::S32>());
1940 for (uint32_t i = 0; i < tc.nhwc_ind.size(); ++i)
1942 EXPECT_EQ(tc.nhwc_ind[i], new_rindices->at<loco::DataType::S32>(i));
1947 TEST(ConvertNCHWToNHWC, ReduceMin)
1952 run_phase(&g.g, true, true);
1954 check_pre_trans(g.rm->input());
1956 auto rm_succs = loco::succs(g.rm);
1957 EXPECT_EQ(1, rm_succs.size());
1958 check_post_trans(*rm_succs.begin());
1960 auto new_rindices = dynamic_cast<luci::CircleConst *>(g.rm->reduction_indices());
1961 EXPECT_NE(nullptr, new_rindices);
1962 EXPECT_EQ(1, new_rindices->rank());
1963 EXPECT_EQ(2, new_rindices->dim(0).value());
1964 EXPECT_EQ(2, new_rindices->size<loco::DataType::S32>());
1965 EXPECT_EQ(1, new_rindices->at<loco::DataType::S32>(0));
1966 EXPECT_EQ(2, new_rindices->at<loco::DataType::S32>(1));
1969 TEST(ConvertNCHWToNHWC, ReduceMin_keep_dims_false)
1973 std::vector<int32_t> nchw_ind;
1974 std::vector<int32_t> nhwc_ind;
1975 std::initializer_list<uint32_t> shape;
1976 bool needs_transpose = false;
1984 std::vector<TC> test_cases{{{0}, {0}, {c, h, w}, true}, {{1}, {3}, {n, h, w}, false},
1985 {{2}, {1}, {n, c, w}, true}, {{3}, {2}, {n, c, h}, true},
1986 {{0, 1}, {0, 3}, {h, w}, false}, {{0, 2}, {0, 1}, {c, w}, true},
1987 {{0, 3}, {0, 2}, {c, h}, true}, {{1, 2}, {3, 1}, {n, w}, false},
1988 {{1, 3}, {3, 2}, {n, h}, false}, {{2, 3}, {1, 2}, {n, c}, false},
1989 {{0, 1, 2}, {0, 3, 1}, {w}, false}};
1991 for (auto &tc : test_cases)
1995 g.axes(tc.nchw_ind);
1999 run_phase(&g.g, true, true);
2001 check_pre_trans(g.rm->input());
2003 auto rm_succs = loco::succs(g.rm);
2004 EXPECT_EQ(1, rm_succs.size());
2005 if (tc.needs_transpose)
2007 EXPECT_NE(nullptr, dynamic_cast<luci::CircleTranspose *>(*rm_succs.begin()));
2011 EXPECT_NE(nullptr, dynamic_cast<luci::CircleOutput *>(*rm_succs.begin()));
2014 auto new_rindices = dynamic_cast<luci::CircleConst *>(g.rm->reduction_indices());
2015 EXPECT_NE(nullptr, new_rindices);
2016 EXPECT_EQ(1, new_rindices->rank());
2017 EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->dim(0).value());
2018 EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->size<loco::DataType::S32>());
2019 for (uint32_t i = 0; i < tc.nhwc_ind.size(); ++i)
2021 EXPECT_EQ(tc.nhwc_ind[i], new_rindices->at<loco::DataType::S32>(i));
2026 TEST(ConvertNCHWToNHWC, Relu)
2031 run_phase(&g.g, true, true);
2033 check_pre_trans(g.relu->features());
2035 auto relu_succs = loco::succs(g.relu);
2036 EXPECT_EQ(1, relu_succs.size());
2037 check_post_trans(*relu_succs.begin());
2040 EXPECT_EQ(1, g.relu->dim(0).value());
2041 EXPECT_EQ(4, g.relu->dim(1).value());
2042 EXPECT_EQ(4, g.relu->dim(2).value());
2043 EXPECT_EQ(16, g.relu->dim(3).value());
2046 TEST(ConvertNCHWToNHWC, Relu6)
2051 run_phase(&g.g, true, true);
2053 check_pre_trans(g.relu6->features());
2055 auto relu6_succs = loco::succs(g.relu6);
2056 EXPECT_EQ(1, relu6_succs.size());
2057 check_post_trans(*relu6_succs.begin());
2059 // Check relu6 shape
2060 EXPECT_EQ(1, g.relu6->dim(0).value());
2061 EXPECT_EQ(4, g.relu6->dim(1).value());
2062 EXPECT_EQ(4, g.relu6->dim(2).value());
2063 EXPECT_EQ(16, g.relu6->dim(3).value());
2066 TEST(ConvertNCHWToNHWC, Rsqrt)
2071 run_phase(&g.g, true, true);
2073 check_pre_trans(g.rsqrt->x());
2075 auto rsqrt_succs = loco::succs(g.rsqrt);
2076 EXPECT_EQ(1, rsqrt_succs.size());
2077 check_post_trans(*rsqrt_succs.begin());
2079 // Check rsqrt shape
2080 EXPECT_EQ(1, g.rsqrt->dim(0).value());
2081 EXPECT_EQ(4, g.rsqrt->dim(1).value());
2082 EXPECT_EQ(4, g.rsqrt->dim(2).value());
2083 EXPECT_EQ(16, g.rsqrt->dim(3).value());
2086 TEST(ConvertNCHWToNHWC, SplitV)
2091 run_phase(g.g(), true, true);
2093 check_pre_trans(g.splitv()->input());
2095 auto splitv_succs = loco::succs(g.splitv());
2096 for (auto svo : loco::succs(g.splitv()))
2098 for (auto succ : loco::succs(svo))
2100 check_post_trans(succ);
2104 // Check splitv() shape
2105 EXPECT_EQ(1, g.splitv()->dim(0).value());
2106 EXPECT_EQ(2, g.splitv()->dim(1).value());
2107 EXPECT_EQ(192, g.splitv()->dim(2).value());
2108 EXPECT_EQ(2, g.splitv()->dim(3).value());
2111 auto axis = dynamic_cast<luci::CircleConst *>(g.splitv()->split_dim());
2112 EXPECT_NE(nullptr, axis);
2113 EXPECT_EQ(1, axis->size<loco::DataType::S32>());
2114 EXPECT_EQ(2, axis->at<loco::DataType::S32>(0));
2117 TEST(ConvertNCHWToNHWC, SquaredDifference)
2119 SquaredDifferenceGraph g;
2122 run_phase(&g.g, true, true);
2124 check_pre_trans(g.sqdiff->x());
2125 check_pre_trans(g.sqdiff->y());
2127 auto sqdiff_succs = loco::succs(g.sqdiff);
2128 EXPECT_EQ(1, sqdiff_succs.size());
2129 check_post_trans(*sqdiff_succs.begin());
2132 TEST(ConvertNCHWToNHWC, Sub)
2137 run_phase(&g.g, false, false);
2139 auto input_succs = loco::succs(g.input);
2140 EXPECT_EQ(1, input_succs.size());
2141 check_post_trans(*input_succs.begin());
2143 check_pre_trans(g.sub->x());
2145 auto add_succs = loco::succs(g.sub);
2146 EXPECT_EQ(1, add_succs.size());
2147 check_post_trans(*add_succs.begin());
2149 uint32_t channel_size = 16;
2150 auto new_beta = dynamic_cast<luci::CircleConst *>(g.sub->y());
2151 EXPECT_NE(nullptr, new_beta);
2152 EXPECT_EQ(4, new_beta->rank());
2153 EXPECT_EQ(1, new_beta->dim(0).value());
2154 EXPECT_EQ(1, new_beta->dim(1).value());
2155 EXPECT_EQ(1, new_beta->dim(2).value());
2156 EXPECT_EQ(channel_size, new_beta->dim(3).value());
2158 check_pre_trans(g.output->from());
2161 TEST(ConvertNCHWToNHWC, Sub_NCHW_const)
2165 g.update_const_shape_to_nchw();
2167 run_phase(&g.g, false, false);
2169 check_pre_trans(g.sub->x());
2171 auto sub_succs = loco::succs(g.sub);
2172 EXPECT_EQ(1, sub_succs.size());
2173 check_post_trans(*sub_succs.begin());
2175 uint32_t channel_size = 16;
2176 auto new_beta = dynamic_cast<luci::CircleConst *>(g.sub->y());
2177 EXPECT_NE(nullptr, new_beta);
2178 EXPECT_EQ(4, new_beta->rank());
2179 EXPECT_EQ(1, new_beta->dim(0).value());
2180 EXPECT_EQ(4, new_beta->dim(1).value());
2181 EXPECT_EQ(4, new_beta->dim(2).value());
2182 EXPECT_EQ(channel_size, new_beta->dim(3).value());
2185 TEST(ConvertNCHWToNHWC, SubScalar)
2190 run_phase(&g.g, false, false);
2192 auto input_succs = loco::succs(g.input);
2193 EXPECT_EQ(1, input_succs.size());
2194 check_post_trans(*input_succs.begin());
2196 check_pre_trans(g.sub->y());
2198 auto add_succs = loco::succs(g.sub);
2199 EXPECT_EQ(1, add_succs.size());
2200 check_post_trans(*add_succs.begin());
2202 auto new_beta = dynamic_cast<luci::CircleConst *>(g.sub->x());
2203 EXPECT_NE(nullptr, new_beta);
2204 EXPECT_EQ(1, new_beta->rank());
2206 check_pre_trans(g.output->from());
2209 TEST(ConvertNCHWToNHWC, Not_Closed_Case1_NEG)
2211 NoPostReshapeGraph g;
2214 run_phase(&g.g, true, true);
2216 check_pre_trans(g.relu->features());
2218 auto relu_succs = loco::succs(g.relu);
2219 EXPECT_EQ(1, relu_succs.size());
2220 check_post_trans(*relu_succs.begin());
2223 TEST(ConvertNCHWToNHWC, Not_Closed_Case2_NEG)
2225 ReluNotClosedGraph g;
2228 run_phase(&g.g, true, true);
2230 check_pre_trans(g.relu->features());
2232 auto relu_succs = loco::succs(g.relu);
2233 EXPECT_EQ(1, relu_succs.size());
2234 check_post_trans(*relu_succs.begin());