/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <logo/Phase.h>

#include <luci/test/TestIOGraph.h>

#include "luci/Pass/ConvertNCHWToNHWCPass.h"
#include "luci/Pass/CircleShapeInferencePass.h"

#include <luci/IR/CircleNodes.h>

#include <gtest/gtest.h>

#include <initializer_list>
#include <memory>
#include <vector>

using namespace luci::test;
34 * Graph with a single Op (example: Add).
37 * - All Ops including Input/Output are NCHW.
46 * - All Ops including Input/Output are NHWC.
65 SimpleGraph() = default;
70 input = g.nodes()->create<luci::CircleInput>();
71 output = g.nodes()->create<luci::CircleOutput>();
73 output->name("output");
75 auto graph_input = g.inputs()->create();
76 input->index(graph_input->index());
77 auto graph_output = g.outputs()->create();
78 output->index(graph_output->index());
80 graph_input->dtype(loco::DataType::FLOAT32);
81 input->dtype(loco::DataType::FLOAT32);
82 output->dtype(loco::DataType::FLOAT32);
83 graph_output->dtype(loco::DataType::FLOAT32);
85 uint32_t channel_size = 16;
86 graph_input->shape({1, channel_size, 4, 4});
87 input->shape({1, channel_size, 4, 4});
88 output->shape({1, channel_size, 4, 4});
89 graph_output->shape({1, channel_size, 4, 4});
91 auto graph_body = insertGraphBody(input);
92 output->from(graph_body);
95 virtual ~SimpleGraph() = default;
98 virtual loco::Node *insertGraphBody(loco::Node *input) = 0;
102 luci::CircleInput *input = nullptr;
103 luci::CircleOutput *output = nullptr;
106 class AddGraph final : public SimpleGraph
109 loco::Node *insertGraphBody(loco::Node *input) override
111 add = g.nodes()->create<luci::CircleAdd>();
112 beta = g.nodes()->create<luci::CircleConst>();
114 add->dtype(loco::DataType::FLOAT32);
115 beta->dtype(loco::DataType::FLOAT32);
117 uint32_t channel_size = 16;
118 add->shape({1, channel_size, 4, 4});
119 beta->shape({1, channel_size, 1, 1});
121 beta->size<loco::DataType::FLOAT32>(channel_size);
122 for (uint32_t i = 0; i < channel_size; i++)
124 beta->at<loco::DataType::FLOAT32>(i) = i;
137 void update_const_shape_to_nchw(void)
139 uint32_t channel_size = 16;
140 beta->shape({1, channel_size, 4, 4});
142 beta->size<loco::DataType::FLOAT32>(channel_size * 4 * 4);
143 for (uint32_t i = 0; i < channel_size; i++)
145 beta->at<loco::DataType::FLOAT32>(i) = i;
150 luci::CircleAdd *add = nullptr;
151 luci::CircleConst *beta = nullptr;
154 class NHWCReluGraph final : public SimpleGraph
157 loco::Node *insertGraphBody(loco::Node *input) override
159 relu = g.nodes()->create<luci::CircleRelu>();
160 pre_reshape = g.nodes()->create<luci::CircleReshape>();
161 post_reshape = g.nodes()->create<luci::CircleReshape>();
162 pre_shape = g.nodes()->create<luci::CircleConst>();
163 post_shape = g.nodes()->create<luci::CircleConst>();
165 pre_shape->dtype(loco::DataType::S32);
166 post_shape->dtype(loco::DataType::S32);
168 uint32_t channel_size = 16;
169 auto in = loco::must_cast<luci::CircleNode *>(input);
170 in->shape({1, channel_size, 4, 4});
171 pre_shape->shape({4});
172 post_shape->shape({4});
174 pre_shape->size<loco::DataType::S32>(4);
175 pre_shape->at<loco::DataType::S32>(0) = 1;
176 pre_shape->at<loco::DataType::S32>(1) = 4;
177 pre_shape->at<loco::DataType::S32>(2) = 4;
178 pre_shape->at<loco::DataType::S32>(3) = channel_size;
180 post_shape->size<loco::DataType::S32>(4);
181 post_shape->at<loco::DataType::S32>(0) = 1;
182 post_shape->at<loco::DataType::S32>(1) = channel_size;
183 post_shape->at<loco::DataType::S32>(2) = 4;
184 post_shape->at<loco::DataType::S32>(3) = 4;
186 pre_reshape->tensor(input);
187 pre_reshape->shape(pre_shape);
189 relu->features(pre_reshape);
191 post_reshape->tensor(relu);
192 post_reshape->shape(post_shape);
195 pre_reshape->name("pre-reshape");
196 post_reshape->name("post-reshape");
202 luci::CircleRelu *relu = nullptr;
203 luci::CircleReshape *pre_reshape = nullptr;
204 luci::CircleReshape *post_reshape = nullptr;
205 luci::CircleConst *pre_shape = nullptr;
206 luci::CircleConst *post_shape = nullptr;
210 * Graph with pre-Reshape but no post-Transpose/Reshape.
234 class NoPostReshapeGraph final : public SimpleGraph
237 loco::Node *insertGraphBody(loco::Node *input) override
239 relu = g.nodes()->create<luci::CircleRelu>();
240 pre_reshape = g.nodes()->create<luci::CircleReshape>();
241 pre_shape = g.nodes()->create<luci::CircleConst>();
243 pre_shape->dtype(loco::DataType::S32);
245 uint32_t channel_size = 16;
246 auto in = loco::must_cast<luci::CircleNode *>(input);
247 in->shape({1, channel_size, 4, 4});
248 pre_shape->shape({4});
250 pre_shape->size<loco::DataType::S32>(4);
251 pre_shape->at<loco::DataType::S32>(0) = 1;
252 pre_shape->at<loco::DataType::S32>(1) = 4;
253 pre_shape->at<loco::DataType::S32>(2) = 4;
254 pre_shape->at<loco::DataType::S32>(3) = channel_size;
256 pre_reshape->tensor(input);
257 pre_reshape->shape(pre_shape);
258 relu->features(pre_reshape);
261 pre_reshape->name("pre-reshape");
267 luci::CircleRelu *relu = nullptr;
268 luci::CircleReshape *pre_reshape = nullptr;
269 luci::CircleConst *pre_shape = nullptr;
273 * Graph with two pre-Reshapes
305 class ReluNotClosedGraph final : public SimpleGraph
308 loco::Node *insertGraphBody(loco::Node *input) override
310 relu = g.nodes()->create<luci::CircleRelu>();
311 pre_reshape = g.nodes()->create<luci::CircleReshape>();
312 pre_reshape_2 = g.nodes()->create<luci::CircleReshape>();
313 post_reshape = g.nodes()->create<luci::CircleReshape>();
314 pre_shape = g.nodes()->create<luci::CircleConst>();
315 pre_shape_2 = g.nodes()->create<luci::CircleConst>();
316 post_shape = g.nodes()->create<luci::CircleConst>();
318 pre_shape->dtype(loco::DataType::S32);
319 pre_shape_2->dtype(loco::DataType::S32);
320 post_shape->dtype(loco::DataType::S32);
322 uint32_t channel_size = 16;
323 auto in = loco::must_cast<luci::CircleNode *>(input);
324 in->shape({1, channel_size, 4, 4});
325 pre_shape->shape({4});
326 pre_shape_2->shape({4});
327 post_shape->shape({4});
329 pre_shape->size<loco::DataType::S32>(4);
330 pre_shape->at<loco::DataType::S32>(0) = 1;
331 pre_shape->at<loco::DataType::S32>(1) = 4;
332 pre_shape->at<loco::DataType::S32>(2) = 4;
333 pre_shape->at<loco::DataType::S32>(3) = channel_size;
335 pre_shape_2->size<loco::DataType::S32>(4);
336 pre_shape_2->at<loco::DataType::S32>(0) = 1;
337 pre_shape_2->at<loco::DataType::S32>(1) = 4;
338 pre_shape_2->at<loco::DataType::S32>(2) = channel_size;
339 pre_shape_2->at<loco::DataType::S32>(3) = 4;
341 post_shape->size<loco::DataType::S32>(4);
342 post_shape->at<loco::DataType::S32>(0) = 1;
343 post_shape->at<loco::DataType::S32>(1) = 4;
344 post_shape->at<loco::DataType::S32>(2) = 4;
345 post_shape->at<loco::DataType::S32>(3) = channel_size;
347 pre_reshape->tensor(input);
348 pre_reshape->shape(pre_shape);
350 relu->features(pre_reshape);
352 pre_reshape_2->tensor(relu);
353 pre_reshape_2->shape(pre_shape_2);
355 post_reshape->tensor(pre_reshape_2);
356 post_reshape->shape(post_shape);
359 pre_reshape->name("pre-reshape");
360 pre_reshape->name("pre-reshape-2");
361 post_reshape->name("post-reshape");
367 luci::CircleRelu *relu = nullptr;
368 luci::CircleReshape *pre_reshape = nullptr;
369 luci::CircleReshape *pre_reshape_2 = nullptr;
370 luci::CircleReshape *post_reshape = nullptr;
371 luci::CircleConst *pre_shape = nullptr;
372 luci::CircleConst *pre_shape_2 = nullptr;
373 luci::CircleConst *post_shape = nullptr;
376 class AddScalarGraph final : public SimpleGraph
379 loco::Node *insertGraphBody(loco::Node *input) override
381 add = g.nodes()->create<luci::CircleAdd>();
382 beta = g.nodes()->create<luci::CircleConst>();
384 add->dtype(loco::DataType::FLOAT32);
385 beta->dtype(loco::DataType::FLOAT32);
387 uint32_t channel_size = 16;
388 add->shape({1, channel_size, 4, 4});
391 beta->size<loco::DataType::FLOAT32>(1);
392 beta->at<loco::DataType::FLOAT32>(0) = 3.14;
404 luci::CircleAdd *add = nullptr;
405 luci::CircleConst *beta = nullptr;
408 class ConcatenationGraph final : public SimpleGraph
411 loco::Node *insertGraphBody(loco::Node *input) override
413 concat = g.nodes()->create<luci::CircleConcatenation>(2);
414 concat->values(0, input);
417 input2 = g.nodes()->create<luci::CircleConst>();
418 input2->dtype(loco::DataType::FLOAT32);
419 input2->shape({1, 16, 4, 4});
420 input2->size<loco::DataType::FLOAT32>(16 * 4 * 4);
421 for (uint32_t i = 0; i < 16 * 4 * 4; i++)
423 input2->at<loco::DataType::FLOAT32>(i) = i;
425 concat->values(1, input2);
427 concat->name("concat");
428 input2->name("input2");
434 luci::CircleConcatenation *concat = nullptr;
435 luci::CircleConst *input2 = nullptr;
438 class EluGraph final : public SimpleGraph
441 loco::Node *insertGraphBody(loco::Node *input) override
443 elu = g.nodes()->create<luci::CircleElu>();
444 elu->features(input);
451 luci::CircleElu *elu = nullptr;
454 class LeakyReluGraph final : public SimpleGraph
457 loco::Node *insertGraphBody(loco::Node *input) override
459 leakyrelu = g.nodes()->create<luci::CircleLeakyRelu>();
460 leakyrelu->features(input);
461 leakyrelu->name("leakyrelu");
467 luci::CircleLeakyRelu *leakyrelu = nullptr;
470 class LogisticGraph final : public SimpleGraph
473 loco::Node *insertGraphBody(loco::Node *input) override
475 logistic = g.nodes()->create<luci::CircleLogistic>();
477 logistic->name("logistic");
483 luci::CircleLogistic *logistic = nullptr;
486 class MaximumGraph final : public SimpleGraph
489 loco::Node *insertGraphBody(loco::Node *input) override
491 max = g.nodes()->create<luci::CircleMaximum>();
492 limit = g.nodes()->create<luci::CircleConst>();
494 max->dtype(loco::DataType::FLOAT32);
495 limit->dtype(loco::DataType::FLOAT32);
497 max->shape({1, 16, 4, 4});
500 limit->size<loco::DataType::FLOAT32>(1);
501 limit->at<loco::DataType::FLOAT32>(0) = 100;
507 limit->name("limit");
513 luci::CircleMaximum *max = nullptr;
514 luci::CircleConst *limit = nullptr;
517 class MaximumNonConstGraph final : public SimpleGraph
520 loco::Node *insertGraphBody(loco::Node *input) override
522 max = g.nodes()->create<luci::CircleMaximum>();
523 max->dtype(loco::DataType::FLOAT32);
524 max->shape({1, 16, 4, 4});
535 luci::CircleMaximum *max = nullptr;
// Default output shape of the reduction graphs (Mean/ReduceMax/ReduceMin): [1, 16, 1, 1],
// i.e. the [1, 16, 4, 4] input reduced over axes {2, 3} with keep_dims = true.
// The backing array of a namespace-scope constexpr initializer_list has static storage,
// so this object never dangles.
static constexpr std::initializer_list<uint32_t> kDefaultShape = {1, 16, 1, 1};
540 class MeanGraph final : public SimpleGraph
543 loco::Node *insertGraphBody(loco::Node *input) override
545 mean = g.nodes()->create<luci::CircleMean>();
546 rindices = g.nodes()->create<luci::CircleConst>();
548 mean->dtype(loco::DataType::FLOAT32);
549 rindices->dtype(loco::DataType::S32);
552 rindices->shape({static_cast<uint32_t>(_axes.size())});
554 rindices->size<loco::DataType::S32>(_axes.size());
555 for (uint32_t i = 0; i < _axes.size(); ++i)
557 rindices->at<loco::DataType::S32>(i) = _axes[i];
561 mean->reduction_indices(rindices);
562 mean->keep_dims(_keep_dims);
565 rindices->name("rindices");
571 void keep_dims(bool val) { _keep_dims = val; }
572 void axes(std::vector<int32_t> val) { _axes = val; }
573 void shape(std::initializer_list<uint32_t> val) { _shape = val; }
576 luci::CircleMean *mean = nullptr;
577 luci::CircleConst *rindices = nullptr;
580 bool _keep_dims = true;
581 std::vector<int32_t> _axes = {2, 3};
582 std::initializer_list<uint32_t> _shape = kDefaultShape;
585 class MinimumGraph final : public SimpleGraph
588 loco::Node *insertGraphBody(loco::Node *input) override
590 min = g.nodes()->create<luci::CircleMinimum>();
591 limit = g.nodes()->create<luci::CircleConst>();
593 min->dtype(loco::DataType::FLOAT32);
594 limit->dtype(loco::DataType::FLOAT32);
596 min->shape({1, 16, 4, 4});
599 limit->size<loco::DataType::FLOAT32>(1);
600 limit->at<loco::DataType::FLOAT32>(0) = 100;
606 limit->name("limit");
612 luci::CircleMinimum *min = nullptr;
613 luci::CircleConst *limit = nullptr;
616 class MulGraph final : public SimpleGraph
619 loco::Node *insertGraphBody(loco::Node *input) override
621 mul = g.nodes()->create<luci::CircleMul>();
622 multiplier = g.nodes()->create<luci::CircleConst>();
624 mul->dtype(loco::DataType::FLOAT32);
625 multiplier->dtype(loco::DataType::FLOAT32);
627 uint32_t channel_size = 16;
628 mul->shape({1, channel_size, 4, 4});
629 multiplier->shape({1, channel_size, 1, 1});
631 multiplier->size<loco::DataType::FLOAT32>(channel_size);
632 for (uint32_t i = 0; i < channel_size; i++)
634 multiplier->at<loco::DataType::FLOAT32>(i) = i;
641 multiplier->name("multiplier");
647 void update_const_shape_to_nchw(void)
649 uint32_t channel_size = 16;
650 multiplier->shape({1, channel_size, 4, 4});
652 multiplier->size<loco::DataType::FLOAT32>(channel_size * 4 * 4);
653 for (uint32_t i = 0; i < channel_size; i++)
655 multiplier->at<loco::DataType::FLOAT32>(i) = i;
660 luci::CircleMul *mul = nullptr;
661 luci::CircleConst *multiplier = nullptr;
664 class MulScalarGraph final : public SimpleGraph
667 loco::Node *insertGraphBody(loco::Node *input) override
669 mul = g.nodes()->create<luci::CircleMul>();
670 multiplier = g.nodes()->create<luci::CircleConst>();
672 mul->dtype(loco::DataType::FLOAT32);
673 multiplier->dtype(loco::DataType::FLOAT32);
675 uint32_t channel_size = 16;
676 mul->shape({1, channel_size, 4, 4});
677 multiplier->shape({1});
679 multiplier->size<loco::DataType::FLOAT32>(1);
680 multiplier->at<loco::DataType::FLOAT32>(0) = 2;
686 multiplier->name("multiplier");
692 luci::CircleMul *mul = nullptr;
693 luci::CircleConst *multiplier = nullptr;
696 class MulBothNormGraph final : public SimpleGraph
699 loco::Node *insertGraphBody(loco::Node *input) override
701 mul = g.nodes()->create<luci::CircleMul>();
703 mul->dtype(loco::DataType::FLOAT32);
705 uint32_t channel_size = 16;
706 mul->shape({1, channel_size, 4, 4});
717 luci::CircleMul *mul = nullptr;
720 class NegGraph final : public SimpleGraph
723 loco::Node *insertGraphBody(loco::Node *input) override
725 neg = g.nodes()->create<luci::CircleNeg>();
733 luci::CircleNeg *neg = nullptr;
736 class PadGraph final : public SimpleGraph
739 loco::Node *insertGraphBody(loco::Node *input) override
741 pad = g.nodes()->create<luci::CirclePad>();
742 paddings = g.nodes()->create<luci::CircleConst>();
744 pad->dtype(loco::DataType::FLOAT32);
745 paddings->dtype(loco::DataType::S32);
747 uint32_t channel_size = 16;
748 pad->shape({1, channel_size, 4, 4});
749 paddings->shape({4, 2});
751 // paddings data (NCHW)
752 // [[0,0], [0,0], [1,1], [2,2]]
753 paddings->size<loco::DataType::S32>(8);
754 for (uint32_t dim = 0; dim < 4; dim++)
756 for (uint32_t i = 0; i < 2; i++)
765 paddings->at<loco::DataType::S32>(dim * 2 + i) = data;
770 pad->paddings(paddings);
773 paddings->name("paddings");
779 luci::CirclePad *pad = nullptr;
780 luci::CircleConst *paddings = nullptr;
783 class PadV2Graph final : public SimpleGraph
786 loco::Node *insertGraphBody(loco::Node *input) override
788 pad = g.nodes()->create<luci::CirclePadV2>();
789 paddings = g.nodes()->create<luci::CircleConst>();
790 const_value = g.nodes()->create<luci::CircleConst>();
792 pad->dtype(loco::DataType::FLOAT32);
793 paddings->dtype(loco::DataType::S32);
794 const_value->dtype(loco::DataType::FLOAT32);
796 uint32_t channel_size = 16;
797 pad->shape({1, channel_size, 4, 4});
798 paddings->shape({4, 2});
799 const_value->shape({1});
801 // paddings data (NCHW)
802 // [[0,0], [0,0], [1,1], [2,2]]
803 paddings->size<loco::DataType::S32>(8);
804 for (uint32_t dim = 0; dim < 4; dim++)
806 for (uint32_t i = 0; i < 2; i++)
815 paddings->at<loco::DataType::S32>(dim * 2 + i) = data;
819 const_value->size<loco::DataType::FLOAT32>(1);
820 const_value->at<loco::DataType::FLOAT32>(0) = -3.4;
823 pad->paddings(paddings);
824 pad->constant_values(paddings);
827 paddings->name("paddings");
828 const_value->name("constant_values");
834 luci::CirclePadV2 *pad = nullptr;
835 luci::CircleConst *paddings = nullptr;
836 luci::CircleConst *const_value = nullptr;
839 class ReduceMaxGraph final : public SimpleGraph
842 loco::Node *insertGraphBody(loco::Node *input) override
844 rm = g.nodes()->create<luci::CircleReduceMax>();
845 rindices = g.nodes()->create<luci::CircleConst>();
847 rm->dtype(loco::DataType::FLOAT32);
848 rindices->dtype(loco::DataType::S32);
851 rindices->shape({static_cast<uint32_t>(_axes.size())});
853 rindices->size<loco::DataType::S32>(_axes.size());
854 for (uint32_t i = 0; i < _axes.size(); ++i)
856 rindices->at<loco::DataType::S32>(i) = _axes[i];
860 rm->reduction_indices(rindices);
861 rm->keep_dims(_keep_dims);
863 rm->name("reduce_max");
864 rindices->name("rindices");
870 void keep_dims(bool val) { _keep_dims = val; }
871 void axes(std::vector<int32_t> val) { _axes = val; }
872 void shape(std::initializer_list<uint32_t> val) { _shape = val; }
875 luci::CircleReduceMax *rm = nullptr;
876 luci::CircleConst *rindices = nullptr;
879 bool _keep_dims = true;
880 std::vector<int32_t> _axes = {2, 3};
881 std::initializer_list<uint32_t> _shape = kDefaultShape;
884 class ReduceMinGraph final : public SimpleGraph
887 loco::Node *insertGraphBody(loco::Node *input) override
889 rm = g.nodes()->create<luci::CircleReduceMin>();
890 rindices = g.nodes()->create<luci::CircleConst>();
892 rm->dtype(loco::DataType::FLOAT32);
893 rindices->dtype(loco::DataType::S32);
896 rindices->shape({static_cast<uint32_t>(_axes.size())});
898 rindices->size<loco::DataType::S32>(_axes.size());
899 for (uint32_t i = 0; i < _axes.size(); ++i)
901 rindices->at<loco::DataType::S32>(i) = _axes[i];
905 rm->reduction_indices(rindices);
906 rm->keep_dims(_keep_dims);
908 rm->name("reduce_max");
909 rindices->name("rindices");
915 void keep_dims(bool val) { _keep_dims = val; }
916 void axes(std::vector<int32_t> val) { _axes = val; }
917 void shape(std::initializer_list<uint32_t> val) { _shape = val; }
920 luci::CircleReduceMin *rm = nullptr;
921 luci::CircleConst *rindices = nullptr;
924 bool _keep_dims = true;
925 std::vector<int32_t> _axes = {2, 3};
926 std::initializer_list<uint32_t> _shape = kDefaultShape;
929 class ReluGraph final : public SimpleGraph
932 loco::Node *insertGraphBody(loco::Node *input) override
934 relu = g.nodes()->create<luci::CircleRelu>();
935 relu->features(input);
942 luci::CircleRelu *relu = nullptr;
945 class Relu6Graph final : public SimpleGraph
948 loco::Node *insertGraphBody(loco::Node *input) override
950 relu6 = g.nodes()->create<luci::CircleRelu6>();
951 relu6->features(input);
952 relu6->name("relu6");
958 luci::CircleRelu6 *relu6 = nullptr;
961 class RsqrtGraph final : public SimpleGraph
964 loco::Node *insertGraphBody(loco::Node *input) override
966 rsqrt = g.nodes()->create<luci::CircleRsqrt>();
968 rsqrt->name("rsqrt");
974 luci::CircleRsqrt *rsqrt = nullptr;
980 SplitVGraphlet() = default;
983 void init(loco::Graph *g)
985 // CircleCustom(SplitV)
986 _splitv = g->nodes()->create<luci::CircleSplitV>();
987 _splitv->shape({1, 2, 2, 192});
988 _splitv->dtype(loco::DataType::FLOAT32);
989 _splitv->name("splitv");
992 auto size_splits = g->nodes()->create<luci::CircleConst>();
993 size_splits->dtype(loco::DataType::S32);
994 size_splits->shape({3});
995 size_splits->size<loco::DataType::S32>(3);
996 size_splits->at<loco::DataType::S32>(0) = 32;
997 size_splits->at<loco::DataType::S32>(1) = 32;
998 size_splits->at<loco::DataType::S32>(2) = 128;
1001 auto split_dim = g->nodes()->create<luci::CircleConst>();
1002 split_dim->dtype(loco::DataType::S32);
1004 split_dim->size<loco::DataType::S32>(1);
1005 split_dim->scalar<loco::DataType::S32>() = 3;
1007 _splitv->size_splits(size_splits);
1008 _splitv->split_dim(split_dim);
1009 _splitv->num_split(3);
1012 _splitv_out1 = g->nodes()->create<luci::CircleSplitVOut>();
1013 _splitv_out1->shape({1, 2, 2, 32});
1014 _splitv_out1->dtype(loco::DataType::FLOAT32);
1015 _splitv_out1->index(0);
1016 _splitv_out1->input(_splitv);
1017 _splitv_out1->name("splitv_out1");
1020 _splitv_out2 = g->nodes()->create<luci::CircleSplitVOut>();
1021 _splitv_out2->shape({1, 2, 2, 32});
1022 _splitv_out2->dtype(loco::DataType::FLOAT32);
1023 _splitv_out2->index(1);
1024 _splitv_out2->input(_splitv);
1025 _splitv_out2->name("splitv_out2");
1028 _splitv_out3 = g->nodes()->create<luci::CircleSplitVOut>();
1029 _splitv_out3->shape({1, 2, 2, 128});
1030 _splitv_out3->dtype(loco::DataType::FLOAT32);
1031 _splitv_out3->index(2);
1032 _splitv_out3->input(_splitv);
1033 _splitv_out3->name("splitv_out3");
1037 luci::CircleSplitV *splitv() { return _splitv; }
1040 luci::CircleSplitV *_splitv = nullptr;
1041 luci::CircleSplitVOut *_splitv_out1 = nullptr;
1042 luci::CircleSplitVOut *_splitv_out2 = nullptr;
1043 luci::CircleSplitVOut *_splitv_out3 = nullptr;
1046 class SplitVGraph : public TestIGraphlet, public TestOsGraphlet<3>, public SplitVGraphlet
1049 SplitVGraph() = default;
1053 TestIGraphlet::init(g(), {1, 2, 2, 192});
1054 TestOsGraphlet<3>::init(g(), {{1, 2, 2, 32}, {1, 2, 2, 32}, {1, 2, 2, 128}});
1055 SplitVGraphlet::init(g());
1058 _splitv->input(input());
1060 output(0)->from(_splitv_out1);
1061 output(1)->from(_splitv_out2);
1062 output(2)->from(_splitv_out3);
1066 class SquaredDifferenceGraph final : public SimpleGraph
1069 loco::Node *insertGraphBody(loco::Node *input) override
1071 sqdiff = g.nodes()->create<luci::CircleSquaredDifference>();
1074 sqdiff->name("sqdiff");
1080 luci::CircleSquaredDifference *sqdiff = nullptr;
1083 class SubGraph final : public SimpleGraph
1086 loco::Node *insertGraphBody(loco::Node *input) override
1088 sub = g.nodes()->create<luci::CircleSub>();
1089 beta = g.nodes()->create<luci::CircleConst>();
1091 sub->dtype(loco::DataType::FLOAT32);
1092 beta->dtype(loco::DataType::FLOAT32);
1094 uint32_t channel_size = 16;
1095 sub->shape({1, channel_size, 4, 4});
1096 beta->shape({1, channel_size, 1, 1});
1098 beta->size<loco::DataType::FLOAT32>(channel_size);
1099 for (uint32_t i = 0; i < channel_size; i++)
1101 beta->at<loco::DataType::FLOAT32>(i) = i;
1114 void update_const_shape_to_nchw(void)
1116 uint32_t channel_size = 16;
1117 beta->shape({1, channel_size, 4, 4});
1119 beta->size<loco::DataType::FLOAT32>(channel_size * 4 * 4);
1120 for (uint32_t i = 0; i < channel_size; i++)
1122 beta->at<loco::DataType::FLOAT32>(i) = i;
1127 luci::CircleSub *sub = nullptr;
1128 luci::CircleConst *beta = nullptr;
1131 class SubScalarGraph final : public SimpleGraph
1134 loco::Node *insertGraphBody(loco::Node *input) override
1136 sub = g.nodes()->create<luci::CircleSub>();
1137 beta = g.nodes()->create<luci::CircleConst>();
1139 sub->dtype(loco::DataType::FLOAT32);
1140 beta->dtype(loco::DataType::FLOAT32);
1142 uint32_t channel_size = 16;
1143 sub->shape({1, channel_size, 4, 4});
1146 beta->size<loco::DataType::FLOAT32>(1);
1147 beta->at<loco::DataType::FLOAT32>(0) = 5;
1159 luci::CircleSub *sub = nullptr;
1160 luci::CircleConst *beta = nullptr;
1163 void check_pre_trans(loco::Node *node)
1165 auto pre_trans = dynamic_cast<luci::CircleTranspose *>(node);
1166 EXPECT_NE(nullptr, pre_trans);
1167 auto pre_trans_perm = dynamic_cast<luci::CircleConst *>(pre_trans->perm());
1168 EXPECT_NE(nullptr, pre_trans_perm);
1169 EXPECT_EQ(1, pre_trans_perm->rank());
1170 EXPECT_EQ(4, pre_trans_perm->dim(0).value());
1171 EXPECT_EQ(loco::DataType::S32, pre_trans_perm->dtype());
1172 EXPECT_EQ(0, pre_trans_perm->at<loco::DataType::S32>(0));
1173 EXPECT_EQ(2, pre_trans_perm->at<loco::DataType::S32>(1));
1174 EXPECT_EQ(3, pre_trans_perm->at<loco::DataType::S32>(2));
1175 EXPECT_EQ(1, pre_trans_perm->at<loco::DataType::S32>(3));
1178 void check_post_trans(loco::Node *node)
1180 auto post_trans = dynamic_cast<luci::CircleTranspose *>(node);
1181 EXPECT_NE(nullptr, post_trans);
1182 auto post_trans_perm = dynamic_cast<luci::CircleConst *>(post_trans->perm());
1183 EXPECT_NE(nullptr, post_trans_perm);
1184 EXPECT_EQ(1, post_trans_perm->rank());
1185 EXPECT_EQ(4, post_trans_perm->dim(0).value());
1186 EXPECT_EQ(loco::DataType::S32, post_trans_perm->dtype());
1187 EXPECT_EQ(0, post_trans_perm->at<loco::DataType::S32>(0));
1188 EXPECT_EQ(3, post_trans_perm->at<loco::DataType::S32>(1));
1189 EXPECT_EQ(1, post_trans_perm->at<loco::DataType::S32>(2));
1190 EXPECT_EQ(2, post_trans_perm->at<loco::DataType::S32>(3));
1193 void run_phase(loco::Graph *g, bool preserve_input, bool preserve_output)
1198 phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
1202 std::make_unique<luci::ConvertNCHWToNHWCPass>(preserve_input, preserve_output));
1204 logo::PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{g};
1205 phase_runner.run(phase);
1210 TEST(ConvertNCHWToNHWCPassTest, name)
1212 luci::ConvertNCHWToNHWCPass pass(false, false);
1213 auto const name = pass.name();
1214 ASSERT_NE(nullptr, name);
1217 TEST(ConvertNCHWToNHWC, Add)
1222 run_phase(&g.g, false, false);
1224 auto input_succs = loco::succs(g.input);
1225 EXPECT_EQ(1, input_succs.size());
1226 check_post_trans(*input_succs.begin());
1228 check_pre_trans(g.add->x());
1230 auto add_succs = loco::succs(g.add);
1231 EXPECT_EQ(1, add_succs.size());
1232 check_post_trans(*add_succs.begin());
1234 uint32_t channel_size = 16;
1235 auto new_beta = dynamic_cast<luci::CircleConst *>(g.add->y());
1236 EXPECT_NE(nullptr, new_beta);
1237 EXPECT_EQ(4, new_beta->rank());
1238 EXPECT_EQ(1, new_beta->dim(0).value());
1239 EXPECT_EQ(1, new_beta->dim(1).value());
1240 EXPECT_EQ(1, new_beta->dim(2).value());
1241 EXPECT_EQ(channel_size, new_beta->dim(3).value());
1243 check_pre_trans(g.output->from());
1246 TEST(ConvertNCHWToNHWC, Add_NCHW_const)
1250 g.update_const_shape_to_nchw();
1252 run_phase(&g.g, false, false);
1254 check_pre_trans(g.add->x());
1256 auto add_succs = loco::succs(g.add);
1257 EXPECT_EQ(1, add_succs.size());
1258 check_post_trans(*add_succs.begin());
1260 uint32_t channel_size = 16;
1261 auto new_beta = dynamic_cast<luci::CircleConst *>(g.add->y());
1262 EXPECT_NE(nullptr, new_beta);
1263 EXPECT_EQ(4, new_beta->rank());
1264 EXPECT_EQ(1, new_beta->dim(0).value());
1265 EXPECT_EQ(4, new_beta->dim(1).value());
1266 EXPECT_EQ(4, new_beta->dim(2).value());
1267 EXPECT_EQ(channel_size, new_beta->dim(3).value());
1270 TEST(ConvertNCHWToNHWC, NHWC_Relu)
1272 // Relu is already NHWC, so it should not be converted
1273 // i.e., the graph is not changed
1277 run_phase(&g.g, false, false);
1279 EXPECT_EQ(g.pre_reshape, g.relu->features());
1281 auto relu_succs = loco::succs(g.relu);
1282 EXPECT_EQ(1, relu_succs.size());
1283 EXPECT_EQ(g.post_reshape, *relu_succs.begin());
1286 TEST(ConvertNCHWToNHWC, AddScalar)
1291 run_phase(&g.g, false, false);
1293 auto input_succs = loco::succs(g.input);
1294 EXPECT_EQ(1, input_succs.size());
1295 check_post_trans(*input_succs.begin());
1297 check_pre_trans(g.add->x());
1299 auto add_succs = loco::succs(g.add);
1300 EXPECT_EQ(1, add_succs.size());
1301 check_post_trans(*add_succs.begin());
1303 auto new_beta = dynamic_cast<luci::CircleConst *>(g.add->y());
1304 EXPECT_NE(nullptr, new_beta);
1305 EXPECT_EQ(4, new_beta->rank());
1306 EXPECT_EQ(1, new_beta->dim(0).value());
1307 EXPECT_EQ(1, new_beta->dim(1).value());
1308 EXPECT_EQ(1, new_beta->dim(2).value());
1309 EXPECT_EQ(1, new_beta->dim(3).value());
1311 check_pre_trans(g.output->from());
1314 TEST(ConvertNCHWToNHWC, Concatenation)
1316 ConcatenationGraph g;
1319 run_phase(&g.g, true, true);
1321 check_pre_trans(g.concat->values(0));
1322 check_pre_trans(g.concat->values(1));
1324 auto concat_succs = loco::succs(g.concat);
1325 EXPECT_EQ(1, concat_succs.size());
1326 check_post_trans(*concat_succs.begin());
1328 // Check concat shape, axis
1329 EXPECT_EQ(1, g.concat->dim(0).value());
1330 EXPECT_EQ(4, g.concat->dim(1).value());
1331 EXPECT_EQ(4, g.concat->dim(2).value());
1332 EXPECT_EQ(32, g.concat->dim(3).value());
1333 EXPECT_EQ(3, g.concat->axis());
1336 TEST(ConvertNCHWToNHWC, Elu)
1341 run_phase(&g.g, true, true);
1343 check_pre_trans(g.elu->features());
1345 auto elu_succs = loco::succs(g.elu);
1346 EXPECT_EQ(1, elu_succs.size());
1347 check_post_trans(*elu_succs.begin());
1350 EXPECT_EQ(1, g.elu->dim(0).value());
1351 EXPECT_EQ(4, g.elu->dim(1).value());
1352 EXPECT_EQ(4, g.elu->dim(2).value());
1353 EXPECT_EQ(16, g.elu->dim(3).value());
1356 TEST(ConvertNCHWToNHWC, LeakyRelu)
1361 run_phase(&g.g, true, true);
1363 check_pre_trans(g.leakyrelu->features());
1365 auto leakyrelu_succs = loco::succs(g.leakyrelu);
1366 EXPECT_EQ(1, leakyrelu_succs.size());
1367 check_post_trans(*leakyrelu_succs.begin());
1369 // Check leakyrelu shape
1370 EXPECT_EQ(1, g.leakyrelu->dim(0).value());
1371 EXPECT_EQ(4, g.leakyrelu->dim(1).value());
1372 EXPECT_EQ(4, g.leakyrelu->dim(2).value());
1373 EXPECT_EQ(16, g.leakyrelu->dim(3).value());
1376 TEST(ConvertNCHWToNHWC, Logistic)
1381 run_phase(&g.g, true, true);
1383 check_pre_trans(g.logistic->x());
1385 auto logistic_succs = loco::succs(g.logistic);
1386 EXPECT_EQ(1, logistic_succs.size());
1387 check_post_trans(*logistic_succs.begin());
1389 // Check logistic shape
1390 EXPECT_EQ(1, g.logistic->dim(0).value());
1391 EXPECT_EQ(4, g.logistic->dim(1).value());
1392 EXPECT_EQ(4, g.logistic->dim(2).value());
1393 EXPECT_EQ(16, g.logistic->dim(3).value());
1396 TEST(ConvertNCHWToNHWC, Maximum)
1401 run_phase(&g.g, false, false);
1403 auto input_succs = loco::succs(g.input);
1404 EXPECT_EQ(1, input_succs.size());
1405 check_post_trans(*input_succs.begin());
1407 check_pre_trans(g.max->x());
1409 auto max_succs = loco::succs(g.max);
1410 EXPECT_EQ(1, max_succs.size());
1411 check_post_trans(*max_succs.begin());
1413 check_pre_trans(g.output->from());
1416 TEST(ConvertNCHWToNHWC, Maximum_non_scalar_NEG)
1421 g.limit->shape({3});
1423 luci::ConvertNCHWToNHWCPass pass(true, true);
1424 EXPECT_FALSE(pass.run(&g.g));
1427 TEST(ConvertNCHWToNHWC, MaximumNonConst)
1429 MaximumNonConstGraph g;
1432 run_phase(&g.g, true, true);
1434 check_pre_trans(g.max->x());
1435 check_pre_trans(g.max->y());
1437 auto max_succs = loco::succs(g.max);
1438 EXPECT_EQ(1, max_succs.size());
1439 check_post_trans(*max_succs.begin());
1442 TEST(ConvertNCHWToNHWC, Mean)
1447 run_phase(&g.g, false, false);
1449 check_pre_trans(g.mean->input());
1451 auto mean_succs = loco::succs(g.mean);
1452 EXPECT_EQ(1, mean_succs.size());
1453 check_post_trans(*mean_succs.begin());
1455 auto new_rindices = dynamic_cast<luci::CircleConst *>(g.mean->reduction_indices());
1456 EXPECT_NE(nullptr, new_rindices);
1457 EXPECT_EQ(1, new_rindices->rank());
1458 EXPECT_EQ(2, new_rindices->dim(0).value());
1459 EXPECT_EQ(2, new_rindices->size<loco::DataType::S32>());
1460 EXPECT_EQ(1, new_rindices->at<loco::DataType::S32>(0));
1461 EXPECT_EQ(2, new_rindices->at<loco::DataType::S32>(1));
1464 TEST(ConvertNCHWToNHWC, Mean_keep_dims_false)
1468 std::vector<int32_t> nchw_ind;
1469 std::vector<int32_t> nhwc_ind;
1470 std::initializer_list<uint32_t> shape;
1471 bool needs_transpose = false;
1479 std::vector<TC> test_cases{{{0}, {0}, {c, h, w}, true}, {{1}, {3}, {n, h, w}, false},
1480 {{2}, {1}, {n, c, w}, true}, {{3}, {2}, {n, c, h}, true},
1481 {{0, 1}, {0, 3}, {h, w}, false}, {{0, 2}, {0, 1}, {c, w}, true},
1482 {{0, 3}, {0, 2}, {c, h}, true}, {{1, 2}, {3, 1}, {n, w}, false},
1483 {{1, 3}, {3, 2}, {n, h}, false}, {{2, 3}, {1, 2}, {n, c}, false},
1484 {{0, 1, 2}, {0, 3, 1}, {w}, false}};
1486 for (auto &tc : test_cases)
1490 g.axes(tc.nchw_ind);
1494 run_phase(&g.g, false, true);
1496 check_pre_trans(g.mean->input());
1498 auto mean_succs = loco::succs(g.mean);
1499 EXPECT_EQ(1, mean_succs.size());
1500 if (tc.needs_transpose)
1502 EXPECT_NE(nullptr, dynamic_cast<luci::CircleTranspose *>(*mean_succs.begin()));
1506 EXPECT_NE(nullptr, dynamic_cast<luci::CircleOutput *>(*mean_succs.begin()));
1509 auto new_rindices = dynamic_cast<luci::CircleConst *>(g.mean->reduction_indices());
1510 EXPECT_NE(nullptr, new_rindices);
1511 EXPECT_EQ(1, new_rindices->rank());
1512 EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->dim(0).value());
1513 EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->size<loco::DataType::S32>());
1514 for (uint32_t i = 0; i < tc.nhwc_ind.size(); ++i)
1516 EXPECT_EQ(tc.nhwc_ind[i], new_rindices->at<loco::DataType::S32>(i));
1521 TEST(ConvertNCHWToNHWC, ConvertNCHWToNHWC_Mean_keep_dims_false_NEG)
1524 auto input = g.nodes()->create<luci::CircleInput>();
1525 auto output = g.nodes()->create<luci::CircleOutput>();
1526 input->name("input");
1527 output->name("output");
1529 auto graph_input = g.inputs()->create();
1530 input->index(graph_input->index());
1531 auto graph_output = g.outputs()->create();
1532 output->index(graph_output->index());
1534 graph_input->dtype(loco::DataType::FLOAT32);
1535 input->dtype(loco::DataType::FLOAT32);
1536 output->dtype(loco::DataType::FLOAT32);
1537 graph_output->dtype(loco::DataType::FLOAT32);
1539 uint32_t channel_size = 16;
1540 graph_input->shape({channel_size, 4, 4});
1541 input->shape({channel_size, 4, 4});
1542 output->shape({channel_size});
1543 graph_output->shape({channel_size});
1545 auto mean = g.nodes()->create<luci::CircleMean>();
1546 auto rindices = g.nodes()->create<luci::CircleConst>();
1548 mean->dtype(loco::DataType::FLOAT32);
1549 rindices->dtype(loco::DataType::S32);
1551 mean->shape({channel_size});
1552 rindices->shape({2});
1554 rindices->size<loco::DataType::S32>(2);
1555 rindices->at<loco::DataType::S32>(0) = 1;
1556 rindices->at<loco::DataType::S32>(1) = 2;
1559 mean->reduction_indices(rindices);
1560 mean->keep_dims(false);
1563 rindices->name("rindices");
1567 run_phase(&g, true, true);
1569 auto new_rindices = dynamic_cast<luci::CircleConst *>(mean->reduction_indices());
1570 EXPECT_NE(nullptr, new_rindices);
1571 EXPECT_EQ(1, new_rindices->rank());
1572 EXPECT_EQ(2, new_rindices->dim(0).value());
1573 EXPECT_EQ(2, new_rindices->size<loco::DataType::S32>());
1574 EXPECT_EQ(1, new_rindices->at<loco::DataType::S32>(0));
1575 EXPECT_EQ(2, new_rindices->at<loco::DataType::S32>(1));
1578 TEST(ConvertNCHWToNHWC, Minimum)
1583 run_phase(&g.g, false, false);
1585 auto input_succs = loco::succs(g.input);
1586 EXPECT_EQ(1, input_succs.size());
1587 check_post_trans(*input_succs.begin());
1589 check_pre_trans(g.min->x());
1591 auto min_succs = loco::succs(g.min);
1592 EXPECT_EQ(1, min_succs.size());
1593 check_post_trans(*min_succs.begin());
1595 check_pre_trans(g.output->from());
1598 TEST(ConvertNCHWToNHWC, Minimum_non_scalar_NEG)
1603 g.limit->shape({3});
1605 luci::ConvertNCHWToNHWCPass pass(true, true);
1606 EXPECT_FALSE(pass.run(&g.g));
1609 TEST(ConvertNCHWToNHWC, Mul)
1614 run_phase(&g.g, false, false);
1616 auto input_succs = loco::succs(g.input);
1617 EXPECT_EQ(1, input_succs.size());
1618 check_post_trans(*input_succs.begin());
1620 check_pre_trans(g.mul->x());
1622 auto mul_succs = loco::succs(g.mul);
1623 EXPECT_EQ(1, mul_succs.size());
1624 check_post_trans(*mul_succs.begin());
1626 uint32_t channel_size = 16;
1627 auto new_multiplier = dynamic_cast<luci::CircleConst *>(g.mul->y());
1628 EXPECT_NE(nullptr, new_multiplier);
1629 EXPECT_EQ(4, new_multiplier->rank());
1630 EXPECT_EQ(1, new_multiplier->dim(0).value());
1631 EXPECT_EQ(1, new_multiplier->dim(1).value());
1632 EXPECT_EQ(1, new_multiplier->dim(2).value());
1633 EXPECT_EQ(channel_size, new_multiplier->dim(3).value());
1635 check_pre_trans(g.output->from());
1638 TEST(ConvertNCHWToNHWC, Mul_NCHW_const)
1642 g.update_const_shape_to_nchw();
1644 run_phase(&g.g, false, false);
1646 check_pre_trans(g.mul->x());
1648 auto mul_succs = loco::succs(g.mul);
1649 EXPECT_EQ(1, mul_succs.size());
1650 check_post_trans(*mul_succs.begin());
1652 uint32_t channel_size = 16;
1653 auto new_multiplier = dynamic_cast<luci::CircleConst *>(g.mul->y());
1654 EXPECT_NE(nullptr, new_multiplier);
1655 EXPECT_EQ(4, new_multiplier->rank());
1656 EXPECT_EQ(1, new_multiplier->dim(0).value());
1657 EXPECT_EQ(4, new_multiplier->dim(1).value());
1658 EXPECT_EQ(4, new_multiplier->dim(2).value());
1659 EXPECT_EQ(channel_size, new_multiplier->dim(3).value());
1662 TEST(ConvertNCHWToNHWC, MulScalar)
1667 run_phase(&g.g, false, false);
1669 auto input_succs = loco::succs(g.input);
1670 EXPECT_EQ(1, input_succs.size());
1671 check_post_trans(*input_succs.begin());
1673 check_pre_trans(g.mul->x());
1675 auto mul_succs = loco::succs(g.mul);
1676 EXPECT_EQ(1, mul_succs.size());
1677 check_post_trans(*mul_succs.begin());
1679 auto new_multiplier = dynamic_cast<luci::CircleConst *>(g.mul->y());
1680 EXPECT_NE(nullptr, new_multiplier);
1681 EXPECT_EQ(4, new_multiplier->rank());
1682 EXPECT_EQ(1, new_multiplier->dim(0).value());
1683 EXPECT_EQ(1, new_multiplier->dim(1).value());
1684 EXPECT_EQ(1, new_multiplier->dim(2).value());
1685 EXPECT_EQ(1, new_multiplier->dim(3).value());
1687 check_pre_trans(g.output->from());
1690 TEST(ConvertNCHWToNHWC, MulBothNorm)
1695 run_phase(&g.g, false, false);
1697 auto input_succs = loco::succs(g.input);
1698 EXPECT_EQ(1, input_succs.size());
1699 check_post_trans(*input_succs.begin());
1701 check_pre_trans(g.mul->x());
1702 check_pre_trans(g.mul->y());
1704 auto mul_succs = loco::succs(g.mul);
1705 EXPECT_EQ(1, mul_succs.size());
1706 check_post_trans(*mul_succs.begin());
1708 check_pre_trans(g.output->from());
1711 TEST(ConvertNCHWToNHWC, Neg)
1716 run_phase(&g.g, true, true);
1718 check_pre_trans(g.neg->x());
1720 auto neg_succs = loco::succs(g.neg);
1721 EXPECT_EQ(1, neg_succs.size());
1722 check_post_trans(*neg_succs.begin());
1724 // Check leakyrelu shape
1725 EXPECT_EQ(1, g.neg->dim(0).value());
1726 EXPECT_EQ(4, g.neg->dim(1).value());
1727 EXPECT_EQ(4, g.neg->dim(2).value());
1728 EXPECT_EQ(16, g.neg->dim(3).value());
1731 TEST(ConvertNCHWToNHWC, Pad)
1736 run_phase(&g.g, false, false);
1738 auto input_succs = loco::succs(g.input);
1739 EXPECT_EQ(1, input_succs.size());
1740 check_post_trans(*input_succs.begin());
1742 check_pre_trans(g.pad->input());
1744 auto pad_succs = loco::succs(g.pad);
1745 EXPECT_EQ(1, pad_succs.size());
1746 check_post_trans(*pad_succs.begin());
1748 auto new_paddings = dynamic_cast<luci::CircleConst *>(g.pad->paddings());
1749 EXPECT_NE(nullptr, new_paddings);
1750 EXPECT_EQ(2, new_paddings->rank());
1751 EXPECT_EQ(4, new_paddings->dim(0).value());
1752 EXPECT_EQ(2, new_paddings->dim(1).value());
1753 EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(0));
1754 EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(1));
1755 EXPECT_EQ(1, new_paddings->at<loco::DataType::S32>(2));
1756 EXPECT_EQ(1, new_paddings->at<loco::DataType::S32>(3));
1757 EXPECT_EQ(2, new_paddings->at<loco::DataType::S32>(4));
1758 EXPECT_EQ(2, new_paddings->at<loco::DataType::S32>(5));
1759 EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(6));
1760 EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(7));
1762 check_pre_trans(g.output->from());
1765 TEST(ConvertNCHWToNHWC, PadV2)
1770 run_phase(&g.g, false, false);
1772 check_pre_trans(g.pad->input());
1774 auto pad_succs = loco::succs(g.pad);
1775 EXPECT_EQ(1, pad_succs.size());
1776 check_post_trans(*pad_succs.begin());
1778 auto new_paddings = dynamic_cast<luci::CircleConst *>(g.pad->paddings());
1779 EXPECT_NE(nullptr, new_paddings);
1780 EXPECT_EQ(2, new_paddings->rank());
1781 EXPECT_EQ(4, new_paddings->dim(0).value());
1782 EXPECT_EQ(2, new_paddings->dim(1).value());
1783 EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(0));
1784 EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(1));
1785 EXPECT_EQ(1, new_paddings->at<loco::DataType::S32>(2));
1786 EXPECT_EQ(1, new_paddings->at<loco::DataType::S32>(3));
1787 EXPECT_EQ(2, new_paddings->at<loco::DataType::S32>(4));
1788 EXPECT_EQ(2, new_paddings->at<loco::DataType::S32>(5));
1789 EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(6));
1790 EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(7));
1793 TEST(ConvertNCHWToNHWC, Unknown_Shape_NEG)
1799 g.input->dim(0).unset();
1800 g.add->dim(0).unset();
1801 g.output->dim(0).unset();
1803 luci::ConvertNCHWToNHWCPass pass(false, false);
1804 EXPECT_EQ(false, pass.run(&g.g));
1807 TEST(ConvertNCHWToNHWC, Preserve_Input_Output)
1814 run_phase(&g.g, true, false);
1816 // Check input shape
1817 EXPECT_EQ(1, g.input->dim(0).value());
1818 EXPECT_EQ(16, g.input->dim(1).value());
1819 EXPECT_EQ(4, g.input->dim(2).value());
1820 EXPECT_EQ(4, g.input->dim(3).value());
1822 // Check output shape
1823 EXPECT_EQ(1, g.output->dim(0).value());
1824 EXPECT_EQ(4, g.output->dim(1).value());
1825 EXPECT_EQ(4, g.output->dim(2).value());
1826 EXPECT_EQ(16, g.output->dim(3).value());
1834 run_phase(&g.g, false, true);
1836 // Check input shape
1837 EXPECT_EQ(1, g.input->dim(0).value());
1838 EXPECT_EQ(4, g.input->dim(1).value());
1839 EXPECT_EQ(4, g.input->dim(2).value());
1840 EXPECT_EQ(16, g.input->dim(3).value());
1842 // Check output shape
1843 EXPECT_EQ(1, g.output->dim(0).value());
1844 EXPECT_EQ(16, g.output->dim(1).value());
1845 EXPECT_EQ(4, g.output->dim(2).value());
1846 EXPECT_EQ(4, g.output->dim(3).value());
1849 // Preserve both input and output
1854 run_phase(&g.g, true, true);
1856 // Check input shape
1857 EXPECT_EQ(1, g.input->dim(0).value());
1858 EXPECT_EQ(16, g.input->dim(1).value());
1859 EXPECT_EQ(4, g.input->dim(2).value());
1860 EXPECT_EQ(4, g.input->dim(3).value());
1862 // Check output shape
1863 EXPECT_EQ(1, g.output->dim(0).value());
1864 EXPECT_EQ(16, g.output->dim(1).value());
1865 EXPECT_EQ(4, g.output->dim(2).value());
1866 EXPECT_EQ(4, g.output->dim(3).value());
1870 TEST(ConvertNCHWToNHWC, ReduceMax)
1875 run_phase(&g.g, false, false);
1877 check_pre_trans(g.rm->input());
1879 auto rm_succs = loco::succs(g.rm);
1880 EXPECT_EQ(1, rm_succs.size());
1881 check_post_trans(*rm_succs.begin());
1883 auto new_rindices = dynamic_cast<luci::CircleConst *>(g.rm->reduction_indices());
1884 EXPECT_NE(nullptr, new_rindices);
1885 EXPECT_EQ(1, new_rindices->rank());
1886 EXPECT_EQ(2, new_rindices->dim(0).value());
1887 EXPECT_EQ(2, new_rindices->size<loco::DataType::S32>());
1888 EXPECT_EQ(1, new_rindices->at<loco::DataType::S32>(0));
1889 EXPECT_EQ(2, new_rindices->at<loco::DataType::S32>(1));
1892 TEST(ConvertNCHWToNHWC, ReduceMax_keep_dims_false)
1896 std::vector<int32_t> nchw_ind;
1897 std::vector<int32_t> nhwc_ind;
1898 std::initializer_list<uint32_t> shape;
1899 bool needs_transpose = false;
1907 std::vector<TC> test_cases{{{0}, {0}, {c, h, w}, true}, {{1}, {3}, {n, h, w}, false},
1908 {{2}, {1}, {n, c, w}, true}, {{3}, {2}, {n, c, h}, true},
1909 {{0, 1}, {0, 3}, {h, w}, false}, {{0, 2}, {0, 1}, {c, w}, true},
1910 {{0, 3}, {0, 2}, {c, h}, true}, {{1, 2}, {3, 1}, {n, w}, false},
1911 {{1, 3}, {3, 2}, {n, h}, false}, {{2, 3}, {1, 2}, {n, c}, false},
1912 {{0, 1, 2}, {0, 3, 1}, {w}, false}};
1914 for (auto &tc : test_cases)
1918 g.axes(tc.nchw_ind);
1922 run_phase(&g.g, true, true);
1924 check_pre_trans(g.rm->input());
1926 auto rm_succs = loco::succs(g.rm);
1927 EXPECT_EQ(1, rm_succs.size());
1928 if (tc.needs_transpose)
1930 EXPECT_NE(nullptr, dynamic_cast<luci::CircleTranspose *>(*rm_succs.begin()));
1934 EXPECT_NE(nullptr, dynamic_cast<luci::CircleOutput *>(*rm_succs.begin()));
1937 auto new_rindices = dynamic_cast<luci::CircleConst *>(g.rm->reduction_indices());
1938 EXPECT_NE(nullptr, new_rindices);
1939 EXPECT_EQ(1, new_rindices->rank());
1940 EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->dim(0).value());
1941 EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->size<loco::DataType::S32>());
1942 for (uint32_t i = 0; i < tc.nhwc_ind.size(); ++i)
1944 EXPECT_EQ(tc.nhwc_ind[i], new_rindices->at<loco::DataType::S32>(i));
1949 TEST(ConvertNCHWToNHWC, ReduceMin)
1954 run_phase(&g.g, true, true);
1956 check_pre_trans(g.rm->input());
1958 auto rm_succs = loco::succs(g.rm);
1959 EXPECT_EQ(1, rm_succs.size());
1960 check_post_trans(*rm_succs.begin());
1962 auto new_rindices = dynamic_cast<luci::CircleConst *>(g.rm->reduction_indices());
1963 EXPECT_NE(nullptr, new_rindices);
1964 EXPECT_EQ(1, new_rindices->rank());
1965 EXPECT_EQ(2, new_rindices->dim(0).value());
1966 EXPECT_EQ(2, new_rindices->size<loco::DataType::S32>());
1967 EXPECT_EQ(1, new_rindices->at<loco::DataType::S32>(0));
1968 EXPECT_EQ(2, new_rindices->at<loco::DataType::S32>(1));
1971 TEST(ConvertNCHWToNHWC, ReduceMin_keep_dims_false)
1975 std::vector<int32_t> nchw_ind;
1976 std::vector<int32_t> nhwc_ind;
1977 std::initializer_list<uint32_t> shape;
1978 bool needs_transpose = false;
1986 std::vector<TC> test_cases{{{0}, {0}, {c, h, w}, true}, {{1}, {3}, {n, h, w}, false},
1987 {{2}, {1}, {n, c, w}, true}, {{3}, {2}, {n, c, h}, true},
1988 {{0, 1}, {0, 3}, {h, w}, false}, {{0, 2}, {0, 1}, {c, w}, true},
1989 {{0, 3}, {0, 2}, {c, h}, true}, {{1, 2}, {3, 1}, {n, w}, false},
1990 {{1, 3}, {3, 2}, {n, h}, false}, {{2, 3}, {1, 2}, {n, c}, false},
1991 {{0, 1, 2}, {0, 3, 1}, {w}, false}};
1993 for (auto &tc : test_cases)
1997 g.axes(tc.nchw_ind);
2001 run_phase(&g.g, true, true);
2003 check_pre_trans(g.rm->input());
2005 auto rm_succs = loco::succs(g.rm);
2006 EXPECT_EQ(1, rm_succs.size());
2007 if (tc.needs_transpose)
2009 EXPECT_NE(nullptr, dynamic_cast<luci::CircleTranspose *>(*rm_succs.begin()));
2013 EXPECT_NE(nullptr, dynamic_cast<luci::CircleOutput *>(*rm_succs.begin()));
2016 auto new_rindices = dynamic_cast<luci::CircleConst *>(g.rm->reduction_indices());
2017 EXPECT_NE(nullptr, new_rindices);
2018 EXPECT_EQ(1, new_rindices->rank());
2019 EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->dim(0).value());
2020 EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->size<loco::DataType::S32>());
2021 for (uint32_t i = 0; i < tc.nhwc_ind.size(); ++i)
2023 EXPECT_EQ(tc.nhwc_ind[i], new_rindices->at<loco::DataType::S32>(i));
2028 TEST(ConvertNCHWToNHWC, Relu)
2033 run_phase(&g.g, true, true);
2035 check_pre_trans(g.relu->features());
2037 auto relu_succs = loco::succs(g.relu);
2038 EXPECT_EQ(1, relu_succs.size());
2039 check_post_trans(*relu_succs.begin());
2042 EXPECT_EQ(1, g.relu->dim(0).value());
2043 EXPECT_EQ(4, g.relu->dim(1).value());
2044 EXPECT_EQ(4, g.relu->dim(2).value());
2045 EXPECT_EQ(16, g.relu->dim(3).value());
2048 TEST(ConvertNCHWToNHWC, Relu6)
2053 run_phase(&g.g, true, true);
2055 check_pre_trans(g.relu6->features());
2057 auto relu6_succs = loco::succs(g.relu6);
2058 EXPECT_EQ(1, relu6_succs.size());
2059 check_post_trans(*relu6_succs.begin());
2061 // Check relu6 shape
2062 EXPECT_EQ(1, g.relu6->dim(0).value());
2063 EXPECT_EQ(4, g.relu6->dim(1).value());
2064 EXPECT_EQ(4, g.relu6->dim(2).value());
2065 EXPECT_EQ(16, g.relu6->dim(3).value());
2068 TEST(ConvertNCHWToNHWC, Rsqrt)
2073 run_phase(&g.g, true, true);
2075 check_pre_trans(g.rsqrt->x());
2077 auto rsqrt_succs = loco::succs(g.rsqrt);
2078 EXPECT_EQ(1, rsqrt_succs.size());
2079 check_post_trans(*rsqrt_succs.begin());
2081 // Check rsqrt shape
2082 EXPECT_EQ(1, g.rsqrt->dim(0).value());
2083 EXPECT_EQ(4, g.rsqrt->dim(1).value());
2084 EXPECT_EQ(4, g.rsqrt->dim(2).value());
2085 EXPECT_EQ(16, g.rsqrt->dim(3).value());
2088 TEST(ConvertNCHWToNHWC, SplitV)
2093 run_phase(g.g(), true, true);
2095 check_pre_trans(g.splitv()->input());
2097 auto splitv_succs = loco::succs(g.splitv());
2098 for (auto svo : loco::succs(g.splitv()))
2100 for (auto succ : loco::succs(svo))
2102 check_post_trans(succ);
2106 // Check splitv() shape
2107 EXPECT_EQ(1, g.splitv()->dim(0).value());
2108 EXPECT_EQ(2, g.splitv()->dim(1).value());
2109 EXPECT_EQ(192, g.splitv()->dim(2).value());
2110 EXPECT_EQ(2, g.splitv()->dim(3).value());
2113 auto axis = dynamic_cast<luci::CircleConst *>(g.splitv()->split_dim());
2114 EXPECT_NE(nullptr, axis);
2115 EXPECT_EQ(1, axis->size<loco::DataType::S32>());
2116 EXPECT_EQ(2, axis->at<loco::DataType::S32>(0));
2119 TEST(ConvertNCHWToNHWC, SquaredDifference)
2121 SquaredDifferenceGraph g;
2124 run_phase(&g.g, true, true);
2126 check_pre_trans(g.sqdiff->x());
2127 check_pre_trans(g.sqdiff->y());
2129 auto sqdiff_succs = loco::succs(g.sqdiff);
2130 EXPECT_EQ(1, sqdiff_succs.size());
2131 check_post_trans(*sqdiff_succs.begin());
2134 TEST(ConvertNCHWToNHWC, Sub)
2139 run_phase(&g.g, false, false);
2141 auto input_succs = loco::succs(g.input);
2142 EXPECT_EQ(1, input_succs.size());
2143 check_post_trans(*input_succs.begin());
2145 check_pre_trans(g.sub->x());
2147 auto add_succs = loco::succs(g.sub);
2148 EXPECT_EQ(1, add_succs.size());
2149 check_post_trans(*add_succs.begin());
2151 uint32_t channel_size = 16;
2152 auto new_beta = dynamic_cast<luci::CircleConst *>(g.sub->y());
2153 EXPECT_NE(nullptr, new_beta);
2154 EXPECT_EQ(4, new_beta->rank());
2155 EXPECT_EQ(1, new_beta->dim(0).value());
2156 EXPECT_EQ(1, new_beta->dim(1).value());
2157 EXPECT_EQ(1, new_beta->dim(2).value());
2158 EXPECT_EQ(channel_size, new_beta->dim(3).value());
2160 check_pre_trans(g.output->from());
2163 TEST(ConvertNCHWToNHWC, Sub_NCHW_const)
2167 g.update_const_shape_to_nchw();
2169 run_phase(&g.g, false, false);
2171 check_pre_trans(g.sub->x());
2173 auto sub_succs = loco::succs(g.sub);
2174 EXPECT_EQ(1, sub_succs.size());
2175 check_post_trans(*sub_succs.begin());
2177 uint32_t channel_size = 16;
2178 auto new_beta = dynamic_cast<luci::CircleConst *>(g.sub->y());
2179 EXPECT_NE(nullptr, new_beta);
2180 EXPECT_EQ(4, new_beta->rank());
2181 EXPECT_EQ(1, new_beta->dim(0).value());
2182 EXPECT_EQ(4, new_beta->dim(1).value());
2183 EXPECT_EQ(4, new_beta->dim(2).value());
2184 EXPECT_EQ(channel_size, new_beta->dim(3).value());
2187 TEST(ConvertNCHWToNHWC, SubScalar)
2192 run_phase(&g.g, false, false);
2194 auto input_succs = loco::succs(g.input);
2195 EXPECT_EQ(1, input_succs.size());
2196 check_post_trans(*input_succs.begin());
2198 check_pre_trans(g.sub->y());
2200 auto add_succs = loco::succs(g.sub);
2201 EXPECT_EQ(1, add_succs.size());
2202 check_post_trans(*add_succs.begin());
2204 auto new_beta = dynamic_cast<luci::CircleConst *>(g.sub->x());
2205 EXPECT_NE(nullptr, new_beta);
2206 EXPECT_EQ(1, new_beta->rank());
2208 check_pre_trans(g.output->from());
2211 TEST(ConvertNCHWToNHWC, Not_Closed_Case1_NEG)
2213 NoPostReshapeGraph g;
2216 run_phase(&g.g, true, true);
2218 check_pre_trans(g.relu->features());
2220 auto relu_succs = loco::succs(g.relu);
2221 EXPECT_EQ(1, relu_succs.size());
2222 check_post_trans(*relu_succs.begin());
2225 TEST(ConvertNCHWToNHWC, Not_Closed_Case2_NEG)
2227 ReluNotClosedGraph g;
2230 run_phase(&g.g, true, true);
2232 check_pre_trans(g.relu->features());
2234 auto relu_succs = loco::succs(g.relu);
2235 EXPECT_EQ(1, relu_succs.size());
2236 check_post_trans(*relu_succs.begin());