Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / ConvertNCHWToNHWCPass.test.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <logo/Phase.h>
18
19 #include <luci/test/TestIOGraph.h>
20
21 #include "luci/Pass/ConvertNCHWToNHWCPass.h"
22 #include "luci/Pass/CircleShapeInferencePass.h"
23
24 #include <luci/IR/CircleNodes.h>
25
26 #include <gtest/gtest.h>
27
28 using namespace luci::test;
29
30 namespace
31 {
32
33 /**
34  *  Graph with a single Op (example: Add).
35  *
36  *  BEFORE
37  *  - All Ops including Input/Output are NCHW.
38  *
39  *             [Input] [beta]
40  *                |  /
41  *              [Add]
42  *                |
43  *             [Output]
44  *
45  *  AFTER
46  *  - All Ops including Input/Output are NHWC.
47  *
48  *             [Input]
49  *                |
50  *         [Transpose]
51  *                |
52  *        [Transpose] [beta]
53  *                |  /
54  *              [Add]
55  *                |
56  *         [Transpose]
57  *                |
58  *         [Transpose]
59  *                |
60  *             [Output]
61  */
62 class SimpleGraph
63 {
64 public:
65   SimpleGraph() = default;
66
67 public:
68   void init()
69   {
70     input = g.nodes()->create<luci::CircleInput>();
71     output = g.nodes()->create<luci::CircleOutput>();
72     input->name("input");
73     output->name("output");
74
75     auto graph_input = g.inputs()->create();
76     input->index(graph_input->index());
77     auto graph_output = g.outputs()->create();
78     output->index(graph_output->index());
79
80     graph_input->dtype(loco::DataType::FLOAT32);
81     input->dtype(loco::DataType::FLOAT32);
82     output->dtype(loco::DataType::FLOAT32);
83     graph_output->dtype(loco::DataType::FLOAT32);
84
85     uint32_t channel_size = 16;
86     graph_input->shape({1, channel_size, 4, 4});
87     input->shape({1, channel_size, 4, 4});
88     output->shape({1, channel_size, 4, 4});
89     graph_output->shape({1, channel_size, 4, 4});
90
91     auto graph_body = insertGraphBody(input);
92     output->from(graph_body);
93   }
94
95   virtual ~SimpleGraph() = default;
96
97 protected:
98   virtual loco::Node *insertGraphBody(loco::Node *input) = 0;
99
100 public:
101   loco::Graph g;
102   luci::CircleInput *input = nullptr;
103   luci::CircleOutput *output = nullptr;
104 };
105
106 class AddGraph final : public SimpleGraph
107 {
108 protected:
109   loco::Node *insertGraphBody(loco::Node *input) override
110   {
111     add = g.nodes()->create<luci::CircleAdd>();
112     beta = g.nodes()->create<luci::CircleConst>();
113
114     add->dtype(loco::DataType::FLOAT32);
115     beta->dtype(loco::DataType::FLOAT32);
116
117     uint32_t channel_size = 16;
118     add->shape({1, channel_size, 4, 4});
119     beta->shape({1, channel_size, 1, 1});
120
121     beta->size<loco::DataType::FLOAT32>(channel_size);
122     for (uint32_t i = 0; i < channel_size; i++)
123     {
124       beta->at<loco::DataType::FLOAT32>(i) = i;
125     }
126
127     add->x(input);
128     add->y(beta);
129
130     add->name("add");
131     beta->name("beta");
132
133     return add;
134   }
135
136 public:
137   void update_const_shape_to_nchw(void)
138   {
139     uint32_t channel_size = 16;
140     beta->shape({1, channel_size, 4, 4});
141
142     beta->size<loco::DataType::FLOAT32>(channel_size * 4 * 4);
143     for (uint32_t i = 0; i < channel_size; i++)
144     {
145       beta->at<loco::DataType::FLOAT32>(i) = i;
146     }
147   }
148
149 public:
150   luci::CircleAdd *add = nullptr;
151   luci::CircleConst *beta = nullptr;
152 };
153
154 class NHWCReluGraph final : public SimpleGraph
155 {
156 protected:
157   loco::Node *insertGraphBody(loco::Node *input) override
158   {
159     relu = g.nodes()->create<luci::CircleRelu>();
160     pre_reshape = g.nodes()->create<luci::CircleReshape>();
161     post_reshape = g.nodes()->create<luci::CircleReshape>();
162     pre_shape = g.nodes()->create<luci::CircleConst>();
163     post_shape = g.nodes()->create<luci::CircleConst>();
164
165     pre_shape->dtype(loco::DataType::S32);
166     post_shape->dtype(loco::DataType::S32);
167
168     uint32_t channel_size = 16;
169     auto in = loco::must_cast<luci::CircleNode *>(input);
170     in->shape({1, channel_size, 4, 4});
171     pre_shape->shape({4});
172     post_shape->shape({4});
173
174     pre_shape->size<loco::DataType::S32>(4);
175     pre_shape->at<loco::DataType::S32>(0) = 1;
176     pre_shape->at<loco::DataType::S32>(1) = 4;
177     pre_shape->at<loco::DataType::S32>(2) = 4;
178     pre_shape->at<loco::DataType::S32>(3) = channel_size;
179
180     post_shape->size<loco::DataType::S32>(4);
181     post_shape->at<loco::DataType::S32>(0) = 1;
182     post_shape->at<loco::DataType::S32>(1) = channel_size;
183     post_shape->at<loco::DataType::S32>(2) = 4;
184     post_shape->at<loco::DataType::S32>(3) = 4;
185
186     pre_reshape->tensor(input);
187     pre_reshape->shape(pre_shape);
188
189     relu->features(pre_reshape);
190
191     post_reshape->tensor(relu);
192     post_reshape->shape(post_shape);
193
194     relu->name("Relu");
195     pre_reshape->name("pre-reshape");
196     post_reshape->name("post-reshape");
197
198     return post_reshape;
199   }
200
201 public:
202   luci::CircleRelu *relu = nullptr;
203   luci::CircleReshape *pre_reshape = nullptr;
204   luci::CircleReshape *post_reshape = nullptr;
205   luci::CircleConst *pre_shape = nullptr;
206   luci::CircleConst *post_shape = nullptr;
207 };
208
209 /**
210  *  Graph with pre-Reshape but no post-Transpose/Reshape.
211  *
212  *  BEFORE
213  *             [Input]
214  *                |
215  *          [Pre-Reshape]
216  *                |
217  *              [Relu]
218  *                |
219  *             [Output]
220  *
221  *  AFTER
222  *             [Input]
223  *                |
224  *          [Pre-Reshape]
225  *                |
226  *          [Pre-Transpose]
227  *                |
228  *              [Relu]
229  *                |
230  *          [Post-Transpose]
231  *                |
232  *             [Output]
233  */
234 class NoPostReshapeGraph final : public SimpleGraph
235 {
236 protected:
237   loco::Node *insertGraphBody(loco::Node *input) override
238   {
239     relu = g.nodes()->create<luci::CircleRelu>();
240     pre_reshape = g.nodes()->create<luci::CircleReshape>();
241     pre_shape = g.nodes()->create<luci::CircleConst>();
242
243     pre_shape->dtype(loco::DataType::S32);
244
245     uint32_t channel_size = 16;
246     auto in = loco::must_cast<luci::CircleNode *>(input);
247     in->shape({1, channel_size, 4, 4});
248     pre_shape->shape({4});
249
250     pre_shape->size<loco::DataType::S32>(4);
251     pre_shape->at<loco::DataType::S32>(0) = 1;
252     pre_shape->at<loco::DataType::S32>(1) = 4;
253     pre_shape->at<loco::DataType::S32>(2) = 4;
254     pre_shape->at<loco::DataType::S32>(3) = channel_size;
255
256     pre_reshape->tensor(input);
257     pre_reshape->shape(pre_shape);
258     relu->features(pre_reshape);
259
260     relu->name("Relu");
261     pre_reshape->name("pre-reshape");
262
263     return relu;
264   }
265
266 public:
267   luci::CircleRelu *relu = nullptr;
268   luci::CircleReshape *pre_reshape = nullptr;
269   luci::CircleConst *pre_shape = nullptr;
270 };
271
272 /**
273  *  Graph with two pre-Reshapes
274  *
275  *  BEFORE
276  *             [Input]
277  *                |
278  *          [Pre-Reshape]
279  *                |
280  *              [Relu]
281  *                |
282  *          [Pre-Reshape]
283  *                |
284  *          [Post-Reshape]
285  *                |
286  *             [Output]
287  *
288  *  AFTER
289  *             [Input]
290  *                |
291  *          [Pre-Reshape]
292  *                |
293  *          [Pre-Transpose]
294  *                |
295  *              [Relu]
296  *                |
297  *          [Post-Transpose]
298  *                |
299  *          [Pre-Reshape]
300  *                |
301  *          [Post-Reshape]
302  *                |
303  *             [Output]
304  */
305 class ReluNotClosedGraph final : public SimpleGraph
306 {
307 protected:
308   loco::Node *insertGraphBody(loco::Node *input) override
309   {
310     relu = g.nodes()->create<luci::CircleRelu>();
311     pre_reshape = g.nodes()->create<luci::CircleReshape>();
312     pre_reshape_2 = g.nodes()->create<luci::CircleReshape>();
313     post_reshape = g.nodes()->create<luci::CircleReshape>();
314     pre_shape = g.nodes()->create<luci::CircleConst>();
315     pre_shape_2 = g.nodes()->create<luci::CircleConst>();
316     post_shape = g.nodes()->create<luci::CircleConst>();
317
318     pre_shape->dtype(loco::DataType::S32);
319     pre_shape_2->dtype(loco::DataType::S32);
320     post_shape->dtype(loco::DataType::S32);
321
322     uint32_t channel_size = 16;
323     auto in = loco::must_cast<luci::CircleNode *>(input);
324     in->shape({1, channel_size, 4, 4});
325     pre_shape->shape({4});
326     pre_shape_2->shape({4});
327     post_shape->shape({4});
328
329     pre_shape->size<loco::DataType::S32>(4);
330     pre_shape->at<loco::DataType::S32>(0) = 1;
331     pre_shape->at<loco::DataType::S32>(1) = 4;
332     pre_shape->at<loco::DataType::S32>(2) = 4;
333     pre_shape->at<loco::DataType::S32>(3) = channel_size;
334
335     pre_shape_2->size<loco::DataType::S32>(4);
336     pre_shape_2->at<loco::DataType::S32>(0) = 1;
337     pre_shape_2->at<loco::DataType::S32>(1) = 4;
338     pre_shape_2->at<loco::DataType::S32>(2) = channel_size;
339     pre_shape_2->at<loco::DataType::S32>(3) = 4;
340
341     post_shape->size<loco::DataType::S32>(4);
342     post_shape->at<loco::DataType::S32>(0) = 1;
343     post_shape->at<loco::DataType::S32>(1) = 4;
344     post_shape->at<loco::DataType::S32>(2) = 4;
345     post_shape->at<loco::DataType::S32>(3) = channel_size;
346
347     pre_reshape->tensor(input);
348     pre_reshape->shape(pre_shape);
349
350     relu->features(pre_reshape);
351
352     pre_reshape_2->tensor(relu);
353     pre_reshape_2->shape(pre_shape_2);
354
355     post_reshape->tensor(pre_reshape_2);
356     post_reshape->shape(post_shape);
357
358     relu->name("Relu");
359     pre_reshape->name("pre-reshape");
360     pre_reshape->name("pre-reshape-2");
361     post_reshape->name("post-reshape");
362
363     return post_reshape;
364   }
365
366 public:
367   luci::CircleRelu *relu = nullptr;
368   luci::CircleReshape *pre_reshape = nullptr;
369   luci::CircleReshape *pre_reshape_2 = nullptr;
370   luci::CircleReshape *post_reshape = nullptr;
371   luci::CircleConst *pre_shape = nullptr;
372   luci::CircleConst *pre_shape_2 = nullptr;
373   luci::CircleConst *post_shape = nullptr;
374 };
375
376 class AddScalarGraph final : public SimpleGraph
377 {
378 protected:
379   loco::Node *insertGraphBody(loco::Node *input) override
380   {
381     add = g.nodes()->create<luci::CircleAdd>();
382     beta = g.nodes()->create<luci::CircleConst>();
383
384     add->dtype(loco::DataType::FLOAT32);
385     beta->dtype(loco::DataType::FLOAT32);
386
387     uint32_t channel_size = 16;
388     add->shape({1, channel_size, 4, 4});
389     beta->shape({1});
390
391     beta->size<loco::DataType::FLOAT32>(1);
392     beta->at<loco::DataType::FLOAT32>(0) = 3.14;
393
394     add->x(input);
395     add->y(beta);
396
397     add->name("add");
398     beta->name("beta");
399
400     return add;
401   }
402
403 public:
404   luci::CircleAdd *add = nullptr;
405   luci::CircleConst *beta = nullptr;
406 };
407
408 class ConcatenationGraph final : public SimpleGraph
409 {
410 protected:
411   loco::Node *insertGraphBody(loco::Node *input) override
412   {
413     concat = g.nodes()->create<luci::CircleConcatenation>(2);
414     concat->values(0, input);
415     concat->axis(1);
416
417     input2 = g.nodes()->create<luci::CircleConst>();
418     input2->dtype(loco::DataType::FLOAT32);
419     input2->shape({1, 16, 4, 4});
420     input2->size<loco::DataType::FLOAT32>(16 * 4 * 4);
421     for (uint32_t i = 0; i < 16 * 4 * 4; i++)
422     {
423       input2->at<loco::DataType::FLOAT32>(i) = i;
424     }
425     concat->values(1, input2);
426
427     concat->name("concat");
428     input2->name("input2");
429
430     return concat;
431   }
432
433 public:
434   luci::CircleConcatenation *concat = nullptr;
435   luci::CircleConst *input2 = nullptr;
436 };
437
438 class EluGraph final : public SimpleGraph
439 {
440 protected:
441   loco::Node *insertGraphBody(loco::Node *input) override
442   {
443     elu = g.nodes()->create<luci::CircleElu>();
444     elu->features(input);
445     elu->name("elu");
446
447     return elu;
448   }
449
450 public:
451   luci::CircleElu *elu = nullptr;
452 };
453
454 class LeakyReluGraph final : public SimpleGraph
455 {
456 protected:
457   loco::Node *insertGraphBody(loco::Node *input) override
458   {
459     leakyrelu = g.nodes()->create<luci::CircleLeakyRelu>();
460     leakyrelu->features(input);
461     leakyrelu->name("leakyrelu");
462
463     return leakyrelu;
464   }
465
466 public:
467   luci::CircleLeakyRelu *leakyrelu = nullptr;
468 };
469
470 class LogisticGraph final : public SimpleGraph
471 {
472 protected:
473   loco::Node *insertGraphBody(loco::Node *input) override
474   {
475     logistic = g.nodes()->create<luci::CircleLogistic>();
476     logistic->x(input);
477     logistic->name("logistic");
478
479     return logistic;
480   }
481
482 public:
483   luci::CircleLogistic *logistic = nullptr;
484 };
485
486 class MaximumGraph final : public SimpleGraph
487 {
488 protected:
489   loco::Node *insertGraphBody(loco::Node *input) override
490   {
491     max = g.nodes()->create<luci::CircleMaximum>();
492     limit = g.nodes()->create<luci::CircleConst>();
493
494     max->dtype(loco::DataType::FLOAT32);
495     limit->dtype(loco::DataType::FLOAT32);
496
497     max->shape({1, 16, 4, 4});
498     limit->shape({});
499
500     limit->size<loco::DataType::FLOAT32>(1);
501     limit->at<loco::DataType::FLOAT32>(0) = 100;
502
503     max->x(input);
504     max->y(limit);
505
506     max->name("max");
507     limit->name("limit");
508
509     return max;
510   }
511
512 public:
513   luci::CircleMaximum *max = nullptr;
514   luci::CircleConst *limit = nullptr;
515 };
516
517 class MaximumNonConstGraph final : public SimpleGraph
518 {
519 protected:
520   loco::Node *insertGraphBody(loco::Node *input) override
521   {
522     max = g.nodes()->create<luci::CircleMaximum>();
523     max->dtype(loco::DataType::FLOAT32);
524     max->shape({1, 16, 4, 4});
525
526     max->x(input);
527     max->y(input);
528
529     max->name("max");
530
531     return max;
532   }
533
534 public:
535   luci::CircleMaximum *max = nullptr;
536 };
537
538 static constexpr std::initializer_list<uint32_t> kDefaultShape = {1, 16, 1, 1};
539
540 class MeanGraph final : public SimpleGraph
541 {
542 protected:
543   loco::Node *insertGraphBody(loco::Node *input) override
544   {
545     mean = g.nodes()->create<luci::CircleMean>();
546     rindices = g.nodes()->create<luci::CircleConst>();
547
548     mean->dtype(loco::DataType::FLOAT32);
549     rindices->dtype(loco::DataType::S32);
550
551     mean->shape(_shape);
552     rindices->shape({static_cast<uint32_t>(_axes.size())});
553
554     rindices->size<loco::DataType::S32>(_axes.size());
555     for (uint32_t i = 0; i < _axes.size(); ++i)
556     {
557       rindices->at<loco::DataType::S32>(i) = _axes[i];
558     }
559
560     mean->input(input);
561     mean->reduction_indices(rindices);
562     mean->keep_dims(_keep_dims);
563
564     mean->name("mean");
565     rindices->name("rindices");
566
567     return mean;
568   }
569
570 public:
571   void keep_dims(bool val) { _keep_dims = val; }
572   void axes(std::vector<int32_t> val) { _axes = val; }
573   void shape(std::initializer_list<uint32_t> val) { _shape = val; }
574
575 public:
576   luci::CircleMean *mean = nullptr;
577   luci::CircleConst *rindices = nullptr;
578
579 private:
580   bool _keep_dims = true;
581   std::vector<int32_t> _axes = {2, 3};
582   std::initializer_list<uint32_t> _shape = kDefaultShape;
583 };
584
585 class MinimumGraph final : public SimpleGraph
586 {
587 protected:
588   loco::Node *insertGraphBody(loco::Node *input) override
589   {
590     min = g.nodes()->create<luci::CircleMinimum>();
591     limit = g.nodes()->create<luci::CircleConst>();
592
593     min->dtype(loco::DataType::FLOAT32);
594     limit->dtype(loco::DataType::FLOAT32);
595
596     min->shape({1, 16, 4, 4});
597     limit->shape({});
598
599     limit->size<loco::DataType::FLOAT32>(1);
600     limit->at<loco::DataType::FLOAT32>(0) = 100;
601
602     min->x(input);
603     min->y(limit);
604
605     min->name("min");
606     limit->name("limit");
607
608     return min;
609   }
610
611 public:
612   luci::CircleMinimum *min = nullptr;
613   luci::CircleConst *limit = nullptr;
614 };
615
616 class MulGraph final : public SimpleGraph
617 {
618 protected:
619   loco::Node *insertGraphBody(loco::Node *input) override
620   {
621     mul = g.nodes()->create<luci::CircleMul>();
622     multiplier = g.nodes()->create<luci::CircleConst>();
623
624     mul->dtype(loco::DataType::FLOAT32);
625     multiplier->dtype(loco::DataType::FLOAT32);
626
627     uint32_t channel_size = 16;
628     mul->shape({1, channel_size, 4, 4});
629     multiplier->shape({1, channel_size, 1, 1});
630
631     multiplier->size<loco::DataType::FLOAT32>(channel_size);
632     for (uint32_t i = 0; i < channel_size; i++)
633     {
634       multiplier->at<loco::DataType::FLOAT32>(i) = i;
635     }
636
637     mul->x(input);
638     mul->y(multiplier);
639
640     mul->name("mul");
641     multiplier->name("multiplier");
642
643     return mul;
644   }
645
646 public:
647   void update_const_shape_to_nchw(void)
648   {
649     uint32_t channel_size = 16;
650     multiplier->shape({1, channel_size, 4, 4});
651
652     multiplier->size<loco::DataType::FLOAT32>(channel_size * 4 * 4);
653     for (uint32_t i = 0; i < channel_size; i++)
654     {
655       multiplier->at<loco::DataType::FLOAT32>(i) = i;
656     }
657   }
658
659 public:
660   luci::CircleMul *mul = nullptr;
661   luci::CircleConst *multiplier = nullptr;
662 };
663
664 class MulScalarGraph final : public SimpleGraph
665 {
666 protected:
667   loco::Node *insertGraphBody(loco::Node *input) override
668   {
669     mul = g.nodes()->create<luci::CircleMul>();
670     multiplier = g.nodes()->create<luci::CircleConst>();
671
672     mul->dtype(loco::DataType::FLOAT32);
673     multiplier->dtype(loco::DataType::FLOAT32);
674
675     uint32_t channel_size = 16;
676     mul->shape({1, channel_size, 4, 4});
677     multiplier->shape({1});
678
679     multiplier->size<loco::DataType::FLOAT32>(1);
680     multiplier->at<loco::DataType::FLOAT32>(0) = 2;
681
682     mul->x(input);
683     mul->y(multiplier);
684
685     mul->name("mul");
686     multiplier->name("multiplier");
687
688     return mul;
689   }
690
691 public:
692   luci::CircleMul *mul = nullptr;
693   luci::CircleConst *multiplier = nullptr;
694 };
695
696 class MulBothNormGraph final : public SimpleGraph
697 {
698 protected:
699   loco::Node *insertGraphBody(loco::Node *input) override
700   {
701     mul = g.nodes()->create<luci::CircleMul>();
702
703     mul->dtype(loco::DataType::FLOAT32);
704
705     uint32_t channel_size = 16;
706     mul->shape({1, channel_size, 4, 4});
707
708     mul->x(input);
709     mul->y(input);
710
711     mul->name("mul");
712
713     return mul;
714   }
715
716 public:
717   luci::CircleMul *mul = nullptr;
718 };
719
720 class NegGraph final : public SimpleGraph
721 {
722 protected:
723   loco::Node *insertGraphBody(loco::Node *input) override
724   {
725     neg = g.nodes()->create<luci::CircleNeg>();
726     neg->x(input);
727     neg->name("neg");
728
729     return neg;
730   }
731
732 public:
733   luci::CircleNeg *neg = nullptr;
734 };
735
736 class PadGraph final : public SimpleGraph
737 {
738 protected:
739   loco::Node *insertGraphBody(loco::Node *input) override
740   {
741     pad = g.nodes()->create<luci::CirclePad>();
742     paddings = g.nodes()->create<luci::CircleConst>();
743
744     pad->dtype(loco::DataType::FLOAT32);
745     paddings->dtype(loco::DataType::S32);
746
747     uint32_t channel_size = 16;
748     pad->shape({1, channel_size, 4, 4});
749     paddings->shape({4, 2});
750
751     // paddings data (NCHW)
752     // [[0,0], [0,0], [1,1], [2,2]]
753     paddings->size<loco::DataType::S32>(8);
754     for (uint32_t dim = 0; dim < 4; dim++)
755     {
756       for (uint32_t i = 0; i < 2; i++)
757       {
758         int32_t data = 0;
759
760         if (dim == 2)
761           data = 1;
762         else if (dim == 3)
763           data = 2;
764
765         paddings->at<loco::DataType::S32>(dim * 2 + i) = data;
766       }
767     }
768
769     pad->input(input);
770     pad->paddings(paddings);
771
772     pad->name("pad");
773     paddings->name("paddings");
774
775     return pad;
776   }
777
778 public:
779   luci::CirclePad *pad = nullptr;
780   luci::CircleConst *paddings = nullptr;
781 };
782
783 class PadV2Graph final : public SimpleGraph
784 {
785 protected:
786   loco::Node *insertGraphBody(loco::Node *input) override
787   {
788     pad = g.nodes()->create<luci::CirclePadV2>();
789     paddings = g.nodes()->create<luci::CircleConst>();
790     const_value = g.nodes()->create<luci::CircleConst>();
791
792     pad->dtype(loco::DataType::FLOAT32);
793     paddings->dtype(loco::DataType::S32);
794     const_value->dtype(loco::DataType::FLOAT32);
795
796     uint32_t channel_size = 16;
797     pad->shape({1, channel_size, 4, 4});
798     paddings->shape({4, 2});
799     const_value->shape({1});
800
801     // paddings data (NCHW)
802     // [[0,0], [0,0], [1,1], [2,2]]
803     paddings->size<loco::DataType::S32>(8);
804     for (uint32_t dim = 0; dim < 4; dim++)
805     {
806       for (uint32_t i = 0; i < 2; i++)
807       {
808         int32_t data = 0;
809
810         if (dim == 2)
811           data = 1;
812         else if (dim == 3)
813           data = 2;
814
815         paddings->at<loco::DataType::S32>(dim * 2 + i) = data;
816       }
817     }
818
819     const_value->size<loco::DataType::FLOAT32>(1);
820     const_value->at<loco::DataType::FLOAT32>(0) = -3.4;
821
822     pad->input(input);
823     pad->paddings(paddings);
824     pad->constant_values(paddings);
825
826     pad->name("padV2");
827     paddings->name("paddings");
828     const_value->name("constant_values");
829
830     return pad;
831   }
832
833 public:
834   luci::CirclePadV2 *pad = nullptr;
835   luci::CircleConst *paddings = nullptr;
836   luci::CircleConst *const_value = nullptr;
837 };
838
839 class ReduceMaxGraph final : public SimpleGraph
840 {
841 protected:
842   loco::Node *insertGraphBody(loco::Node *input) override
843   {
844     rm = g.nodes()->create<luci::CircleReduceMax>();
845     rindices = g.nodes()->create<luci::CircleConst>();
846
847     rm->dtype(loco::DataType::FLOAT32);
848     rindices->dtype(loco::DataType::S32);
849
850     rm->shape(_shape);
851     rindices->shape({static_cast<uint32_t>(_axes.size())});
852
853     rindices->size<loco::DataType::S32>(_axes.size());
854     for (uint32_t i = 0; i < _axes.size(); ++i)
855     {
856       rindices->at<loco::DataType::S32>(i) = _axes[i];
857     }
858
859     rm->input(input);
860     rm->reduction_indices(rindices);
861     rm->keep_dims(_keep_dims);
862
863     rm->name("reduce_max");
864     rindices->name("rindices");
865
866     return rm;
867   }
868
869 public:
870   void keep_dims(bool val) { _keep_dims = val; }
871   void axes(std::vector<int32_t> val) { _axes = val; }
872   void shape(std::initializer_list<uint32_t> val) { _shape = val; }
873
874 public:
875   luci::CircleReduceMax *rm = nullptr;
876   luci::CircleConst *rindices = nullptr;
877
878 private:
879   bool _keep_dims = true;
880   std::vector<int32_t> _axes = {2, 3};
881   std::initializer_list<uint32_t> _shape = kDefaultShape;
882 };
883
884 class ReduceMinGraph final : public SimpleGraph
885 {
886 protected:
887   loco::Node *insertGraphBody(loco::Node *input) override
888   {
889     rm = g.nodes()->create<luci::CircleReduceMin>();
890     rindices = g.nodes()->create<luci::CircleConst>();
891
892     rm->dtype(loco::DataType::FLOAT32);
893     rindices->dtype(loco::DataType::S32);
894
895     rm->shape(_shape);
896     rindices->shape({static_cast<uint32_t>(_axes.size())});
897
898     rindices->size<loco::DataType::S32>(_axes.size());
899     for (uint32_t i = 0; i < _axes.size(); ++i)
900     {
901       rindices->at<loco::DataType::S32>(i) = _axes[i];
902     }
903
904     rm->input(input);
905     rm->reduction_indices(rindices);
906     rm->keep_dims(_keep_dims);
907
908     rm->name("reduce_max");
909     rindices->name("rindices");
910
911     return rm;
912   }
913
914 public:
915   void keep_dims(bool val) { _keep_dims = val; }
916   void axes(std::vector<int32_t> val) { _axes = val; }
917   void shape(std::initializer_list<uint32_t> val) { _shape = val; }
918
919 public:
920   luci::CircleReduceMin *rm = nullptr;
921   luci::CircleConst *rindices = nullptr;
922
923 private:
924   bool _keep_dims = true;
925   std::vector<int32_t> _axes = {2, 3};
926   std::initializer_list<uint32_t> _shape = kDefaultShape;
927 };
928
929 class ReluGraph final : public SimpleGraph
930 {
931 protected:
932   loco::Node *insertGraphBody(loco::Node *input) override
933   {
934     relu = g.nodes()->create<luci::CircleRelu>();
935     relu->features(input);
936     relu->name("Relu");
937
938     return relu;
939   }
940
941 public:
942   luci::CircleRelu *relu = nullptr;
943 };
944
945 class Relu6Graph final : public SimpleGraph
946 {
947 protected:
948   loco::Node *insertGraphBody(loco::Node *input) override
949   {
950     relu6 = g.nodes()->create<luci::CircleRelu6>();
951     relu6->features(input);
952     relu6->name("relu6");
953
954     return relu6;
955   }
956
957 public:
958   luci::CircleRelu6 *relu6 = nullptr;
959 };
960
961 class RsqrtGraph final : public SimpleGraph
962 {
963 protected:
964   loco::Node *insertGraphBody(loco::Node *input) override
965   {
966     rsqrt = g.nodes()->create<luci::CircleRsqrt>();
967     rsqrt->x(input);
968     rsqrt->name("rsqrt");
969
970     return rsqrt;
971   }
972
973 public:
974   luci::CircleRsqrt *rsqrt = nullptr;
975 };
976
977 class SplitVGraphlet
978 {
979 public:
980   SplitVGraphlet() = default;
981
982 public:
983   void init(loco::Graph *g)
984   {
985     // CircleCustom(SplitV)
986     _splitv = g->nodes()->create<luci::CircleSplitV>();
987     _splitv->shape({1, 2, 2, 192});
988     _splitv->dtype(loco::DataType::FLOAT32);
989     _splitv->name("splitv");
990
991     // CircleConst
992     auto size_splits = g->nodes()->create<luci::CircleConst>();
993     size_splits->dtype(loco::DataType::S32);
994     size_splits->shape({3});
995     size_splits->size<loco::DataType::S32>(3);
996     size_splits->at<loco::DataType::S32>(0) = 32;
997     size_splits->at<loco::DataType::S32>(1) = 32;
998     size_splits->at<loco::DataType::S32>(2) = 128;
999
1000     // CircleConst
1001     auto split_dim = g->nodes()->create<luci::CircleConst>();
1002     split_dim->dtype(loco::DataType::S32);
1003     split_dim->rank(0);
1004     split_dim->size<loco::DataType::S32>(1);
1005     split_dim->scalar<loco::DataType::S32>() = 3;
1006
1007     _splitv->size_splits(size_splits);
1008     _splitv->split_dim(split_dim);
1009     _splitv->num_split(3);
1010
1011     // CircleSplitVOut
1012     _splitv_out1 = g->nodes()->create<luci::CircleSplitVOut>();
1013     _splitv_out1->shape({1, 2, 2, 32});
1014     _splitv_out1->dtype(loco::DataType::FLOAT32);
1015     _splitv_out1->index(0);
1016     _splitv_out1->input(_splitv);
1017     _splitv_out1->name("splitv_out1");
1018
1019     // CircleSplitVOut
1020     _splitv_out2 = g->nodes()->create<luci::CircleSplitVOut>();
1021     _splitv_out2->shape({1, 2, 2, 32});
1022     _splitv_out2->dtype(loco::DataType::FLOAT32);
1023     _splitv_out2->index(1);
1024     _splitv_out2->input(_splitv);
1025     _splitv_out2->name("splitv_out2");
1026
1027     // CircleSplitVOut
1028     _splitv_out3 = g->nodes()->create<luci::CircleSplitVOut>();
1029     _splitv_out3->shape({1, 2, 2, 128});
1030     _splitv_out3->dtype(loco::DataType::FLOAT32);
1031     _splitv_out3->index(2);
1032     _splitv_out3->input(_splitv);
1033     _splitv_out3->name("splitv_out3");
1034   }
1035
1036 public:
1037   luci::CircleSplitV *splitv() { return _splitv; }
1038
1039 protected:
1040   luci::CircleSplitV *_splitv = nullptr;
1041   luci::CircleSplitVOut *_splitv_out1 = nullptr;
1042   luci::CircleSplitVOut *_splitv_out2 = nullptr;
1043   luci::CircleSplitVOut *_splitv_out3 = nullptr;
1044 };
1045
1046 class SplitVGraph : public TestIGraphlet, public TestOsGraphlet<3>, public SplitVGraphlet
1047 {
1048 public:
1049   SplitVGraph() = default;
1050
1051   void init(void)
1052   {
1053     TestIGraphlet::init(g(), {1, 2, 2, 192});
1054     TestOsGraphlet<3>::init(g(), {{1, 2, 2, 32}, {1, 2, 2, 32}, {1, 2, 2, 128}});
1055     SplitVGraphlet::init(g());
1056
1057     // connect graph
1058     _splitv->input(input());
1059
1060     output(0)->from(_splitv_out1);
1061     output(1)->from(_splitv_out2);
1062     output(2)->from(_splitv_out3);
1063   }
1064 };
1065
1066 class SquaredDifferenceGraph final : public SimpleGraph
1067 {
1068 protected:
1069   loco::Node *insertGraphBody(loco::Node *input) override
1070   {
1071     sqdiff = g.nodes()->create<luci::CircleSquaredDifference>();
1072     sqdiff->x(input);
1073     sqdiff->y(input);
1074     sqdiff->name("sqdiff");
1075
1076     return sqdiff;
1077   }
1078
1079 public:
1080   luci::CircleSquaredDifference *sqdiff = nullptr;
1081 };
1082
1083 class SubGraph final : public SimpleGraph
1084 {
1085 protected:
1086   loco::Node *insertGraphBody(loco::Node *input) override
1087   {
1088     sub = g.nodes()->create<luci::CircleSub>();
1089     beta = g.nodes()->create<luci::CircleConst>();
1090
1091     sub->dtype(loco::DataType::FLOAT32);
1092     beta->dtype(loco::DataType::FLOAT32);
1093
1094     uint32_t channel_size = 16;
1095     sub->shape({1, channel_size, 4, 4});
1096     beta->shape({1, channel_size, 1, 1});
1097
1098     beta->size<loco::DataType::FLOAT32>(channel_size);
1099     for (uint32_t i = 0; i < channel_size; i++)
1100     {
1101       beta->at<loco::DataType::FLOAT32>(i) = i;
1102     }
1103
1104     sub->x(input);
1105     sub->y(beta);
1106
1107     sub->name("sub");
1108     beta->name("beta");
1109
1110     return sub;
1111   }
1112
1113 public:
1114   void update_const_shape_to_nchw(void)
1115   {
1116     uint32_t channel_size = 16;
1117     beta->shape({1, channel_size, 4, 4});
1118
1119     beta->size<loco::DataType::FLOAT32>(channel_size * 4 * 4);
1120     for (uint32_t i = 0; i < channel_size; i++)
1121     {
1122       beta->at<loco::DataType::FLOAT32>(i) = i;
1123     }
1124   }
1125
1126 public:
1127   luci::CircleSub *sub = nullptr;
1128   luci::CircleConst *beta = nullptr;
1129 };
1130
1131 class SubScalarGraph final : public SimpleGraph
1132 {
1133 protected:
1134   loco::Node *insertGraphBody(loco::Node *input) override
1135   {
1136     sub = g.nodes()->create<luci::CircleSub>();
1137     beta = g.nodes()->create<luci::CircleConst>();
1138
1139     sub->dtype(loco::DataType::FLOAT32);
1140     beta->dtype(loco::DataType::FLOAT32);
1141
1142     uint32_t channel_size = 16;
1143     sub->shape({1, channel_size, 4, 4});
1144     beta->shape({1});
1145
1146     beta->size<loco::DataType::FLOAT32>(1);
1147     beta->at<loco::DataType::FLOAT32>(0) = 5;
1148
1149     sub->x(beta);
1150     sub->y(input);
1151
1152     sub->name("sub");
1153     beta->name("beta");
1154
1155     return sub;
1156   }
1157
1158 public:
1159   luci::CircleSub *sub = nullptr;
1160   luci::CircleConst *beta = nullptr;
1161 };
1162
1163 void check_pre_trans(loco::Node *node)
1164 {
1165   auto pre_trans = dynamic_cast<luci::CircleTranspose *>(node);
1166   EXPECT_NE(nullptr, pre_trans);
1167   auto pre_trans_perm = dynamic_cast<luci::CircleConst *>(pre_trans->perm());
1168   EXPECT_NE(nullptr, pre_trans_perm);
1169   EXPECT_EQ(1, pre_trans_perm->rank());
1170   EXPECT_EQ(4, pre_trans_perm->dim(0).value());
1171   EXPECT_EQ(loco::DataType::S32, pre_trans_perm->dtype());
1172   EXPECT_EQ(0, pre_trans_perm->at<loco::DataType::S32>(0));
1173   EXPECT_EQ(2, pre_trans_perm->at<loco::DataType::S32>(1));
1174   EXPECT_EQ(3, pre_trans_perm->at<loco::DataType::S32>(2));
1175   EXPECT_EQ(1, pre_trans_perm->at<loco::DataType::S32>(3));
1176 }
1177
1178 void check_post_trans(loco::Node *node)
1179 {
1180   auto post_trans = dynamic_cast<luci::CircleTranspose *>(node);
1181   EXPECT_NE(nullptr, post_trans);
1182   auto post_trans_perm = dynamic_cast<luci::CircleConst *>(post_trans->perm());
1183   EXPECT_NE(nullptr, post_trans_perm);
1184   EXPECT_EQ(1, post_trans_perm->rank());
1185   EXPECT_EQ(4, post_trans_perm->dim(0).value());
1186   EXPECT_EQ(loco::DataType::S32, post_trans_perm->dtype());
1187   EXPECT_EQ(0, post_trans_perm->at<loco::DataType::S32>(0));
1188   EXPECT_EQ(3, post_trans_perm->at<loco::DataType::S32>(1));
1189   EXPECT_EQ(1, post_trans_perm->at<loco::DataType::S32>(2));
1190   EXPECT_EQ(2, post_trans_perm->at<loco::DataType::S32>(3));
1191 }
1192
1193 void run_phase(loco::Graph *g, bool preserve_input, bool preserve_output)
1194 {
1195   logo::Phase phase;
1196
1197   // Default passes.
1198   phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
1199
1200   // Pass to test
1201   phase.emplace_back(
1202     std::make_unique<luci::ConvertNCHWToNHWCPass>(preserve_input, preserve_output));
1203
1204   logo::PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{g};
1205   phase_runner.run(phase);
1206 }
1207
1208 } // namespace
1209
1210 TEST(ConvertNCHWToNHWCPassTest, name)
1211 {
1212   luci::ConvertNCHWToNHWCPass pass(false, false);
1213   auto const name = pass.name();
1214   ASSERT_NE(nullptr, name);
1215 }
1216
1217 TEST(ConvertNCHWToNHWC, Add)
1218 {
1219   AddGraph g;
1220   g.init();
1221
1222   run_phase(&g.g, false, false);
1223
1224   auto input_succs = loco::succs(g.input);
1225   EXPECT_EQ(1, input_succs.size());
1226   check_post_trans(*input_succs.begin());
1227
1228   check_pre_trans(g.add->x());
1229
1230   auto add_succs = loco::succs(g.add);
1231   EXPECT_EQ(1, add_succs.size());
1232   check_post_trans(*add_succs.begin());
1233
1234   uint32_t channel_size = 16;
1235   auto new_beta = dynamic_cast<luci::CircleConst *>(g.add->y());
1236   EXPECT_NE(nullptr, new_beta);
1237   EXPECT_EQ(4, new_beta->rank());
1238   EXPECT_EQ(1, new_beta->dim(0).value());
1239   EXPECT_EQ(1, new_beta->dim(1).value());
1240   EXPECT_EQ(1, new_beta->dim(2).value());
1241   EXPECT_EQ(channel_size, new_beta->dim(3).value());
1242
1243   check_pre_trans(g.output->from());
1244 }
1245
1246 TEST(ConvertNCHWToNHWC, Add_NCHW_const)
1247 {
1248   AddGraph g;
1249   g.init();
1250   g.update_const_shape_to_nchw();
1251
1252   run_phase(&g.g, false, false);
1253
1254   check_pre_trans(g.add->x());
1255
1256   auto add_succs = loco::succs(g.add);
1257   EXPECT_EQ(1, add_succs.size());
1258   check_post_trans(*add_succs.begin());
1259
1260   uint32_t channel_size = 16;
1261   auto new_beta = dynamic_cast<luci::CircleConst *>(g.add->y());
1262   EXPECT_NE(nullptr, new_beta);
1263   EXPECT_EQ(4, new_beta->rank());
1264   EXPECT_EQ(1, new_beta->dim(0).value());
1265   EXPECT_EQ(4, new_beta->dim(1).value());
1266   EXPECT_EQ(4, new_beta->dim(2).value());
1267   EXPECT_EQ(channel_size, new_beta->dim(3).value());
1268 }
1269
1270 TEST(ConvertNCHWToNHWC, NHWC_Relu)
1271 {
1272   // Relu is already NHWC, so it should not be converted
1273   // i.e., the graph is not changed
1274   NHWCReluGraph g;
1275   g.init();
1276
1277   run_phase(&g.g, false, false);
1278
1279   EXPECT_EQ(g.pre_reshape, g.relu->features());
1280
1281   auto relu_succs = loco::succs(g.relu);
1282   EXPECT_EQ(1, relu_succs.size());
1283   EXPECT_EQ(g.post_reshape, *relu_succs.begin());
1284 }
1285
1286 TEST(ConvertNCHWToNHWC, AddScalar)
1287 {
1288   AddScalarGraph g;
1289   g.init();
1290
1291   run_phase(&g.g, false, false);
1292
1293   auto input_succs = loco::succs(g.input);
1294   EXPECT_EQ(1, input_succs.size());
1295   check_post_trans(*input_succs.begin());
1296
1297   check_pre_trans(g.add->x());
1298
1299   auto add_succs = loco::succs(g.add);
1300   EXPECT_EQ(1, add_succs.size());
1301   check_post_trans(*add_succs.begin());
1302
1303   auto new_beta = dynamic_cast<luci::CircleConst *>(g.add->y());
1304   EXPECT_NE(nullptr, new_beta);
1305   EXPECT_EQ(4, new_beta->rank());
1306   EXPECT_EQ(1, new_beta->dim(0).value());
1307   EXPECT_EQ(1, new_beta->dim(1).value());
1308   EXPECT_EQ(1, new_beta->dim(2).value());
1309   EXPECT_EQ(1, new_beta->dim(3).value());
1310
1311   check_pre_trans(g.output->from());
1312 }
1313
1314 TEST(ConvertNCHWToNHWC, Concatenation)
1315 {
1316   ConcatenationGraph g;
1317   g.init();
1318
1319   run_phase(&g.g, true, true);
1320
1321   check_pre_trans(g.concat->values(0));
1322   check_pre_trans(g.concat->values(1));
1323
1324   auto concat_succs = loco::succs(g.concat);
1325   EXPECT_EQ(1, concat_succs.size());
1326   check_post_trans(*concat_succs.begin());
1327
1328   // Check concat shape, axis
1329   EXPECT_EQ(1, g.concat->dim(0).value());
1330   EXPECT_EQ(4, g.concat->dim(1).value());
1331   EXPECT_EQ(4, g.concat->dim(2).value());
1332   EXPECT_EQ(32, g.concat->dim(3).value());
1333   EXPECT_EQ(3, g.concat->axis());
1334 }
1335
1336 TEST(ConvertNCHWToNHWC, Elu)
1337 {
1338   EluGraph g;
1339   g.init();
1340
1341   run_phase(&g.g, true, true);
1342
1343   check_pre_trans(g.elu->features());
1344
1345   auto elu_succs = loco::succs(g.elu);
1346   EXPECT_EQ(1, elu_succs.size());
1347   check_post_trans(*elu_succs.begin());
1348
1349   // Check elu shape
1350   EXPECT_EQ(1, g.elu->dim(0).value());
1351   EXPECT_EQ(4, g.elu->dim(1).value());
1352   EXPECT_EQ(4, g.elu->dim(2).value());
1353   EXPECT_EQ(16, g.elu->dim(3).value());
1354 }
1355
1356 TEST(ConvertNCHWToNHWC, LeakyRelu)
1357 {
1358   LeakyReluGraph g;
1359   g.init();
1360
1361   run_phase(&g.g, true, true);
1362
1363   check_pre_trans(g.leakyrelu->features());
1364
1365   auto leakyrelu_succs = loco::succs(g.leakyrelu);
1366   EXPECT_EQ(1, leakyrelu_succs.size());
1367   check_post_trans(*leakyrelu_succs.begin());
1368
1369   // Check leakyrelu shape
1370   EXPECT_EQ(1, g.leakyrelu->dim(0).value());
1371   EXPECT_EQ(4, g.leakyrelu->dim(1).value());
1372   EXPECT_EQ(4, g.leakyrelu->dim(2).value());
1373   EXPECT_EQ(16, g.leakyrelu->dim(3).value());
1374 }
1375
1376 TEST(ConvertNCHWToNHWC, Logistic)
1377 {
1378   LogisticGraph g;
1379   g.init();
1380
1381   run_phase(&g.g, true, true);
1382
1383   check_pre_trans(g.logistic->x());
1384
1385   auto logistic_succs = loco::succs(g.logistic);
1386   EXPECT_EQ(1, logistic_succs.size());
1387   check_post_trans(*logistic_succs.begin());
1388
1389   // Check logistic shape
1390   EXPECT_EQ(1, g.logistic->dim(0).value());
1391   EXPECT_EQ(4, g.logistic->dim(1).value());
1392   EXPECT_EQ(4, g.logistic->dim(2).value());
1393   EXPECT_EQ(16, g.logistic->dim(3).value());
1394 }
1395
1396 TEST(ConvertNCHWToNHWC, Maximum)
1397 {
1398   MaximumGraph g;
1399   g.init();
1400
1401   run_phase(&g.g, false, false);
1402
1403   auto input_succs = loco::succs(g.input);
1404   EXPECT_EQ(1, input_succs.size());
1405   check_post_trans(*input_succs.begin());
1406
1407   check_pre_trans(g.max->x());
1408
1409   auto max_succs = loco::succs(g.max);
1410   EXPECT_EQ(1, max_succs.size());
1411   check_post_trans(*max_succs.begin());
1412
1413   check_pre_trans(g.output->from());
1414 }
1415
1416 TEST(ConvertNCHWToNHWC, Maximum_non_scalar_NEG)
1417 {
1418   MaximumGraph g;
1419   g.init();
1420
1421   g.limit->shape({3});
1422
1423   luci::ConvertNCHWToNHWCPass pass(true, true);
1424   EXPECT_FALSE(pass.run(&g.g));
1425 }
1426
1427 TEST(ConvertNCHWToNHWC, MaximumNonConst)
1428 {
1429   MaximumNonConstGraph g;
1430   g.init();
1431
1432   run_phase(&g.g, true, true);
1433
1434   check_pre_trans(g.max->x());
1435   check_pre_trans(g.max->y());
1436
1437   auto max_succs = loco::succs(g.max);
1438   EXPECT_EQ(1, max_succs.size());
1439   check_post_trans(*max_succs.begin());
1440 }
1441
1442 TEST(ConvertNCHWToNHWC, Mean)
1443 {
1444   MeanGraph g;
1445   g.init();
1446
1447   run_phase(&g.g, false, false);
1448
1449   check_pre_trans(g.mean->input());
1450
1451   auto mean_succs = loco::succs(g.mean);
1452   EXPECT_EQ(1, mean_succs.size());
1453   check_post_trans(*mean_succs.begin());
1454
1455   auto new_rindices = dynamic_cast<luci::CircleConst *>(g.mean->reduction_indices());
1456   EXPECT_NE(nullptr, new_rindices);
1457   EXPECT_EQ(1, new_rindices->rank());
1458   EXPECT_EQ(2, new_rindices->dim(0).value());
1459   EXPECT_EQ(2, new_rindices->size<loco::DataType::S32>());
1460   EXPECT_EQ(1, new_rindices->at<loco::DataType::S32>(0));
1461   EXPECT_EQ(2, new_rindices->at<loco::DataType::S32>(1));
1462 }
1463
1464 TEST(ConvertNCHWToNHWC, Mean_keep_dims_false)
1465 {
1466   struct TC
1467   {
1468     std::vector<int32_t> nchw_ind;
1469     std::vector<int32_t> nhwc_ind;
1470     std::initializer_list<uint32_t> shape;
1471     bool needs_transpose = false;
1472   };
1473
1474   uint32_t n = 1;
1475   uint32_t c = 16;
1476   uint32_t h = 4;
1477   uint32_t w = 4;
1478
1479   std::vector<TC> test_cases{{{0}, {0}, {c, h, w}, true},       {{1}, {3}, {n, h, w}, false},
1480                              {{2}, {1}, {n, c, w}, true},       {{3}, {2}, {n, c, h}, true},
1481                              {{0, 1}, {0, 3}, {h, w}, false},   {{0, 2}, {0, 1}, {c, w}, true},
1482                              {{0, 3}, {0, 2}, {c, h}, true},    {{1, 2}, {3, 1}, {n, w}, false},
1483                              {{1, 3}, {3, 2}, {n, h}, false},   {{2, 3}, {1, 2}, {n, c}, false},
1484                              {{0, 1, 2}, {0, 3, 1}, {w}, false}};
1485
1486   for (auto &tc : test_cases)
1487   {
1488     MeanGraph g;
1489     g.keep_dims(false);
1490     g.axes(tc.nchw_ind);
1491     g.shape(tc.shape);
1492     g.init();
1493
1494     run_phase(&g.g, false, true);
1495
1496     check_pre_trans(g.mean->input());
1497
1498     auto mean_succs = loco::succs(g.mean);
1499     EXPECT_EQ(1, mean_succs.size());
1500     if (tc.needs_transpose)
1501     {
1502       EXPECT_NE(nullptr, dynamic_cast<luci::CircleTranspose *>(*mean_succs.begin()));
1503     }
1504     else
1505     {
1506       EXPECT_NE(nullptr, dynamic_cast<luci::CircleOutput *>(*mean_succs.begin()));
1507     }
1508
1509     auto new_rindices = dynamic_cast<luci::CircleConst *>(g.mean->reduction_indices());
1510     EXPECT_NE(nullptr, new_rindices);
1511     EXPECT_EQ(1, new_rindices->rank());
1512     EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->dim(0).value());
1513     EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->size<loco::DataType::S32>());
1514     for (uint32_t i = 0; i < tc.nhwc_ind.size(); ++i)
1515     {
1516       EXPECT_EQ(tc.nhwc_ind[i], new_rindices->at<loco::DataType::S32>(i));
1517     }
1518   }
1519 }
1520
1521 TEST(ConvertNCHWToNHWC, ConvertNCHWToNHWC_Mean_keep_dims_false_NEG)
1522 {
1523   loco::Graph g;
1524   auto input = g.nodes()->create<luci::CircleInput>();
1525   auto output = g.nodes()->create<luci::CircleOutput>();
1526   input->name("input");
1527   output->name("output");
1528
1529   auto graph_input = g.inputs()->create();
1530   input->index(graph_input->index());
1531   auto graph_output = g.outputs()->create();
1532   output->index(graph_output->index());
1533
1534   graph_input->dtype(loco::DataType::FLOAT32);
1535   input->dtype(loco::DataType::FLOAT32);
1536   output->dtype(loco::DataType::FLOAT32);
1537   graph_output->dtype(loco::DataType::FLOAT32);
1538
1539   uint32_t channel_size = 16;
1540   graph_input->shape({channel_size, 4, 4});
1541   input->shape({channel_size, 4, 4});
1542   output->shape({channel_size});
1543   graph_output->shape({channel_size});
1544
1545   auto mean = g.nodes()->create<luci::CircleMean>();
1546   auto rindices = g.nodes()->create<luci::CircleConst>();
1547
1548   mean->dtype(loco::DataType::FLOAT32);
1549   rindices->dtype(loco::DataType::S32);
1550
1551   mean->shape({channel_size});
1552   rindices->shape({2});
1553
1554   rindices->size<loco::DataType::S32>(2);
1555   rindices->at<loco::DataType::S32>(0) = 1;
1556   rindices->at<loco::DataType::S32>(1) = 2;
1557
1558   mean->input(input);
1559   mean->reduction_indices(rindices);
1560   mean->keep_dims(false);
1561
1562   mean->name("mean");
1563   rindices->name("rindices");
1564
1565   output->from(mean);
1566
1567   run_phase(&g, true, true);
1568
1569   auto new_rindices = dynamic_cast<luci::CircleConst *>(mean->reduction_indices());
1570   EXPECT_NE(nullptr, new_rindices);
1571   EXPECT_EQ(1, new_rindices->rank());
1572   EXPECT_EQ(2, new_rindices->dim(0).value());
1573   EXPECT_EQ(2, new_rindices->size<loco::DataType::S32>());
1574   EXPECT_EQ(1, new_rindices->at<loco::DataType::S32>(0));
1575   EXPECT_EQ(2, new_rindices->at<loco::DataType::S32>(1));
1576 }
1577
1578 TEST(ConvertNCHWToNHWC, Minimum)
1579 {
1580   MinimumGraph g;
1581   g.init();
1582
1583   run_phase(&g.g, false, false);
1584
1585   auto input_succs = loco::succs(g.input);
1586   EXPECT_EQ(1, input_succs.size());
1587   check_post_trans(*input_succs.begin());
1588
1589   check_pre_trans(g.min->x());
1590
1591   auto min_succs = loco::succs(g.min);
1592   EXPECT_EQ(1, min_succs.size());
1593   check_post_trans(*min_succs.begin());
1594
1595   check_pre_trans(g.output->from());
1596 }
1597
1598 TEST(ConvertNCHWToNHWC, Minimum_non_scalar_NEG)
1599 {
1600   MinimumGraph g;
1601   g.init();
1602
1603   g.limit->shape({3});
1604
1605   luci::ConvertNCHWToNHWCPass pass(true, true);
1606   EXPECT_FALSE(pass.run(&g.g));
1607 }
1608
1609 TEST(ConvertNCHWToNHWC, Mul)
1610 {
1611   MulGraph g;
1612   g.init();
1613
1614   run_phase(&g.g, false, false);
1615
1616   auto input_succs = loco::succs(g.input);
1617   EXPECT_EQ(1, input_succs.size());
1618   check_post_trans(*input_succs.begin());
1619
1620   check_pre_trans(g.mul->x());
1621
1622   auto mul_succs = loco::succs(g.mul);
1623   EXPECT_EQ(1, mul_succs.size());
1624   check_post_trans(*mul_succs.begin());
1625
1626   uint32_t channel_size = 16;
1627   auto new_multiplier = dynamic_cast<luci::CircleConst *>(g.mul->y());
1628   EXPECT_NE(nullptr, new_multiplier);
1629   EXPECT_EQ(4, new_multiplier->rank());
1630   EXPECT_EQ(1, new_multiplier->dim(0).value());
1631   EXPECT_EQ(1, new_multiplier->dim(1).value());
1632   EXPECT_EQ(1, new_multiplier->dim(2).value());
1633   EXPECT_EQ(channel_size, new_multiplier->dim(3).value());
1634
1635   check_pre_trans(g.output->from());
1636 }
1637
1638 TEST(ConvertNCHWToNHWC, Mul_NCHW_const)
1639 {
1640   MulGraph g;
1641   g.init();
1642   g.update_const_shape_to_nchw();
1643
1644   run_phase(&g.g, false, false);
1645
1646   check_pre_trans(g.mul->x());
1647
1648   auto mul_succs = loco::succs(g.mul);
1649   EXPECT_EQ(1, mul_succs.size());
1650   check_post_trans(*mul_succs.begin());
1651
1652   uint32_t channel_size = 16;
1653   auto new_multiplier = dynamic_cast<luci::CircleConst *>(g.mul->y());
1654   EXPECT_NE(nullptr, new_multiplier);
1655   EXPECT_EQ(4, new_multiplier->rank());
1656   EXPECT_EQ(1, new_multiplier->dim(0).value());
1657   EXPECT_EQ(4, new_multiplier->dim(1).value());
1658   EXPECT_EQ(4, new_multiplier->dim(2).value());
1659   EXPECT_EQ(channel_size, new_multiplier->dim(3).value());
1660 }
1661
1662 TEST(ConvertNCHWToNHWC, MulScalar)
1663 {
1664   MulScalarGraph g;
1665   g.init();
1666
1667   run_phase(&g.g, false, false);
1668
1669   auto input_succs = loco::succs(g.input);
1670   EXPECT_EQ(1, input_succs.size());
1671   check_post_trans(*input_succs.begin());
1672
1673   check_pre_trans(g.mul->x());
1674
1675   auto mul_succs = loco::succs(g.mul);
1676   EXPECT_EQ(1, mul_succs.size());
1677   check_post_trans(*mul_succs.begin());
1678
1679   auto new_multiplier = dynamic_cast<luci::CircleConst *>(g.mul->y());
1680   EXPECT_NE(nullptr, new_multiplier);
1681   EXPECT_EQ(4, new_multiplier->rank());
1682   EXPECT_EQ(1, new_multiplier->dim(0).value());
1683   EXPECT_EQ(1, new_multiplier->dim(1).value());
1684   EXPECT_EQ(1, new_multiplier->dim(2).value());
1685   EXPECT_EQ(1, new_multiplier->dim(3).value());
1686
1687   check_pre_trans(g.output->from());
1688 }
1689
1690 TEST(ConvertNCHWToNHWC, MulBothNorm)
1691 {
1692   MulBothNormGraph g;
1693   g.init();
1694
1695   run_phase(&g.g, false, false);
1696
1697   auto input_succs = loco::succs(g.input);
1698   EXPECT_EQ(1, input_succs.size());
1699   check_post_trans(*input_succs.begin());
1700
1701   check_pre_trans(g.mul->x());
1702   check_pre_trans(g.mul->y());
1703
1704   auto mul_succs = loco::succs(g.mul);
1705   EXPECT_EQ(1, mul_succs.size());
1706   check_post_trans(*mul_succs.begin());
1707
1708   check_pre_trans(g.output->from());
1709 }
1710
1711 TEST(ConvertNCHWToNHWC, Neg)
1712 {
1713   NegGraph g;
1714   g.init();
1715
1716   run_phase(&g.g, true, true);
1717
1718   check_pre_trans(g.neg->x());
1719
1720   auto neg_succs = loco::succs(g.neg);
1721   EXPECT_EQ(1, neg_succs.size());
1722   check_post_trans(*neg_succs.begin());
1723
1724   // Check leakyrelu shape
1725   EXPECT_EQ(1, g.neg->dim(0).value());
1726   EXPECT_EQ(4, g.neg->dim(1).value());
1727   EXPECT_EQ(4, g.neg->dim(2).value());
1728   EXPECT_EQ(16, g.neg->dim(3).value());
1729 }
1730
1731 TEST(ConvertNCHWToNHWC, Pad)
1732 {
1733   PadGraph g;
1734   g.init();
1735
1736   run_phase(&g.g, false, false);
1737
1738   auto input_succs = loco::succs(g.input);
1739   EXPECT_EQ(1, input_succs.size());
1740   check_post_trans(*input_succs.begin());
1741
1742   check_pre_trans(g.pad->input());
1743
1744   auto pad_succs = loco::succs(g.pad);
1745   EXPECT_EQ(1, pad_succs.size());
1746   check_post_trans(*pad_succs.begin());
1747
1748   auto new_paddings = dynamic_cast<luci::CircleConst *>(g.pad->paddings());
1749   EXPECT_NE(nullptr, new_paddings);
1750   EXPECT_EQ(2, new_paddings->rank());
1751   EXPECT_EQ(4, new_paddings->dim(0).value());
1752   EXPECT_EQ(2, new_paddings->dim(1).value());
1753   EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(0));
1754   EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(1));
1755   EXPECT_EQ(1, new_paddings->at<loco::DataType::S32>(2));
1756   EXPECT_EQ(1, new_paddings->at<loco::DataType::S32>(3));
1757   EXPECT_EQ(2, new_paddings->at<loco::DataType::S32>(4));
1758   EXPECT_EQ(2, new_paddings->at<loco::DataType::S32>(5));
1759   EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(6));
1760   EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(7));
1761
1762   check_pre_trans(g.output->from());
1763 }
1764
1765 TEST(ConvertNCHWToNHWC, PadV2)
1766 {
1767   PadV2Graph g;
1768   g.init();
1769
1770   run_phase(&g.g, false, false);
1771
1772   check_pre_trans(g.pad->input());
1773
1774   auto pad_succs = loco::succs(g.pad);
1775   EXPECT_EQ(1, pad_succs.size());
1776   check_post_trans(*pad_succs.begin());
1777
1778   auto new_paddings = dynamic_cast<luci::CircleConst *>(g.pad->paddings());
1779   EXPECT_NE(nullptr, new_paddings);
1780   EXPECT_EQ(2, new_paddings->rank());
1781   EXPECT_EQ(4, new_paddings->dim(0).value());
1782   EXPECT_EQ(2, new_paddings->dim(1).value());
1783   EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(0));
1784   EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(1));
1785   EXPECT_EQ(1, new_paddings->at<loco::DataType::S32>(2));
1786   EXPECT_EQ(1, new_paddings->at<loco::DataType::S32>(3));
1787   EXPECT_EQ(2, new_paddings->at<loco::DataType::S32>(4));
1788   EXPECT_EQ(2, new_paddings->at<loco::DataType::S32>(5));
1789   EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(6));
1790   EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(7));
1791 }
1792
1793 TEST(ConvertNCHWToNHWC, Unknown_Shape_NEG)
1794 {
1795   AddGraph g;
1796   g.init();
1797
1798   // Unknown shape
1799   g.input->dim(0).unset();
1800   g.add->dim(0).unset();
1801   g.output->dim(0).unset();
1802
1803   luci::ConvertNCHWToNHWCPass pass(false, false);
1804   EXPECT_EQ(false, pass.run(&g.g));
1805 }
1806
1807 TEST(ConvertNCHWToNHWC, Preserve_Input_Output)
1808 {
1809   // Preserve input
1810   {
1811     AddGraph g;
1812     g.init();
1813
1814     run_phase(&g.g, true, false);
1815
1816     // Check input shape
1817     EXPECT_EQ(1, g.input->dim(0).value());
1818     EXPECT_EQ(16, g.input->dim(1).value());
1819     EXPECT_EQ(4, g.input->dim(2).value());
1820     EXPECT_EQ(4, g.input->dim(3).value());
1821
1822     // Check output shape
1823     EXPECT_EQ(1, g.output->dim(0).value());
1824     EXPECT_EQ(4, g.output->dim(1).value());
1825     EXPECT_EQ(4, g.output->dim(2).value());
1826     EXPECT_EQ(16, g.output->dim(3).value());
1827   }
1828
1829   // Preserve output
1830   {
1831     AddGraph g;
1832     g.init();
1833
1834     run_phase(&g.g, false, true);
1835
1836     // Check input shape
1837     EXPECT_EQ(1, g.input->dim(0).value());
1838     EXPECT_EQ(4, g.input->dim(1).value());
1839     EXPECT_EQ(4, g.input->dim(2).value());
1840     EXPECT_EQ(16, g.input->dim(3).value());
1841
1842     // Check output shape
1843     EXPECT_EQ(1, g.output->dim(0).value());
1844     EXPECT_EQ(16, g.output->dim(1).value());
1845     EXPECT_EQ(4, g.output->dim(2).value());
1846     EXPECT_EQ(4, g.output->dim(3).value());
1847   }
1848
1849   // Preserve both input and output
1850   {
1851     AddGraph g;
1852     g.init();
1853
1854     run_phase(&g.g, true, true);
1855
1856     // Check input shape
1857     EXPECT_EQ(1, g.input->dim(0).value());
1858     EXPECT_EQ(16, g.input->dim(1).value());
1859     EXPECT_EQ(4, g.input->dim(2).value());
1860     EXPECT_EQ(4, g.input->dim(3).value());
1861
1862     // Check output shape
1863     EXPECT_EQ(1, g.output->dim(0).value());
1864     EXPECT_EQ(16, g.output->dim(1).value());
1865     EXPECT_EQ(4, g.output->dim(2).value());
1866     EXPECT_EQ(4, g.output->dim(3).value());
1867   }
1868 }
1869
1870 TEST(ConvertNCHWToNHWC, ReduceMax)
1871 {
1872   ReduceMaxGraph g;
1873   g.init();
1874
1875   run_phase(&g.g, false, false);
1876
1877   check_pre_trans(g.rm->input());
1878
1879   auto rm_succs = loco::succs(g.rm);
1880   EXPECT_EQ(1, rm_succs.size());
1881   check_post_trans(*rm_succs.begin());
1882
1883   auto new_rindices = dynamic_cast<luci::CircleConst *>(g.rm->reduction_indices());
1884   EXPECT_NE(nullptr, new_rindices);
1885   EXPECT_EQ(1, new_rindices->rank());
1886   EXPECT_EQ(2, new_rindices->dim(0).value());
1887   EXPECT_EQ(2, new_rindices->size<loco::DataType::S32>());
1888   EXPECT_EQ(1, new_rindices->at<loco::DataType::S32>(0));
1889   EXPECT_EQ(2, new_rindices->at<loco::DataType::S32>(1));
1890 }
1891
1892 TEST(ConvertNCHWToNHWC, ReduceMax_keep_dims_false)
1893 {
1894   struct TC
1895   {
1896     std::vector<int32_t> nchw_ind;
1897     std::vector<int32_t> nhwc_ind;
1898     std::initializer_list<uint32_t> shape;
1899     bool needs_transpose = false;
1900   };
1901
1902   uint32_t n = 1;
1903   uint32_t c = 16;
1904   uint32_t h = 4;
1905   uint32_t w = 4;
1906
1907   std::vector<TC> test_cases{{{0}, {0}, {c, h, w}, true},       {{1}, {3}, {n, h, w}, false},
1908                              {{2}, {1}, {n, c, w}, true},       {{3}, {2}, {n, c, h}, true},
1909                              {{0, 1}, {0, 3}, {h, w}, false},   {{0, 2}, {0, 1}, {c, w}, true},
1910                              {{0, 3}, {0, 2}, {c, h}, true},    {{1, 2}, {3, 1}, {n, w}, false},
1911                              {{1, 3}, {3, 2}, {n, h}, false},   {{2, 3}, {1, 2}, {n, c}, false},
1912                              {{0, 1, 2}, {0, 3, 1}, {w}, false}};
1913
1914   for (auto &tc : test_cases)
1915   {
1916     ReduceMaxGraph g;
1917     g.keep_dims(false);
1918     g.axes(tc.nchw_ind);
1919     g.shape(tc.shape);
1920     g.init();
1921
1922     run_phase(&g.g, true, true);
1923
1924     check_pre_trans(g.rm->input());
1925
1926     auto rm_succs = loco::succs(g.rm);
1927     EXPECT_EQ(1, rm_succs.size());
1928     if (tc.needs_transpose)
1929     {
1930       EXPECT_NE(nullptr, dynamic_cast<luci::CircleTranspose *>(*rm_succs.begin()));
1931     }
1932     else
1933     {
1934       EXPECT_NE(nullptr, dynamic_cast<luci::CircleOutput *>(*rm_succs.begin()));
1935     }
1936
1937     auto new_rindices = dynamic_cast<luci::CircleConst *>(g.rm->reduction_indices());
1938     EXPECT_NE(nullptr, new_rindices);
1939     EXPECT_EQ(1, new_rindices->rank());
1940     EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->dim(0).value());
1941     EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->size<loco::DataType::S32>());
1942     for (uint32_t i = 0; i < tc.nhwc_ind.size(); ++i)
1943     {
1944       EXPECT_EQ(tc.nhwc_ind[i], new_rindices->at<loco::DataType::S32>(i));
1945     }
1946   }
1947 }
1948
1949 TEST(ConvertNCHWToNHWC, ReduceMin)
1950 {
1951   ReduceMinGraph g;
1952   g.init();
1953
1954   run_phase(&g.g, true, true);
1955
1956   check_pre_trans(g.rm->input());
1957
1958   auto rm_succs = loco::succs(g.rm);
1959   EXPECT_EQ(1, rm_succs.size());
1960   check_post_trans(*rm_succs.begin());
1961
1962   auto new_rindices = dynamic_cast<luci::CircleConst *>(g.rm->reduction_indices());
1963   EXPECT_NE(nullptr, new_rindices);
1964   EXPECT_EQ(1, new_rindices->rank());
1965   EXPECT_EQ(2, new_rindices->dim(0).value());
1966   EXPECT_EQ(2, new_rindices->size<loco::DataType::S32>());
1967   EXPECT_EQ(1, new_rindices->at<loco::DataType::S32>(0));
1968   EXPECT_EQ(2, new_rindices->at<loco::DataType::S32>(1));
1969 }
1970
1971 TEST(ConvertNCHWToNHWC, ReduceMin_keep_dims_false)
1972 {
1973   struct TC
1974   {
1975     std::vector<int32_t> nchw_ind;
1976     std::vector<int32_t> nhwc_ind;
1977     std::initializer_list<uint32_t> shape;
1978     bool needs_transpose = false;
1979   };
1980
1981   uint32_t n = 1;
1982   uint32_t c = 16;
1983   uint32_t h = 4;
1984   uint32_t w = 4;
1985
1986   std::vector<TC> test_cases{{{0}, {0}, {c, h, w}, true},       {{1}, {3}, {n, h, w}, false},
1987                              {{2}, {1}, {n, c, w}, true},       {{3}, {2}, {n, c, h}, true},
1988                              {{0, 1}, {0, 3}, {h, w}, false},   {{0, 2}, {0, 1}, {c, w}, true},
1989                              {{0, 3}, {0, 2}, {c, h}, true},    {{1, 2}, {3, 1}, {n, w}, false},
1990                              {{1, 3}, {3, 2}, {n, h}, false},   {{2, 3}, {1, 2}, {n, c}, false},
1991                              {{0, 1, 2}, {0, 3, 1}, {w}, false}};
1992
1993   for (auto &tc : test_cases)
1994   {
1995     ReduceMinGraph g;
1996     g.keep_dims(false);
1997     g.axes(tc.nchw_ind);
1998     g.shape(tc.shape);
1999     g.init();
2000
2001     run_phase(&g.g, true, true);
2002
2003     check_pre_trans(g.rm->input());
2004
2005     auto rm_succs = loco::succs(g.rm);
2006     EXPECT_EQ(1, rm_succs.size());
2007     if (tc.needs_transpose)
2008     {
2009       EXPECT_NE(nullptr, dynamic_cast<luci::CircleTranspose *>(*rm_succs.begin()));
2010     }
2011     else
2012     {
2013       EXPECT_NE(nullptr, dynamic_cast<luci::CircleOutput *>(*rm_succs.begin()));
2014     }
2015
2016     auto new_rindices = dynamic_cast<luci::CircleConst *>(g.rm->reduction_indices());
2017     EXPECT_NE(nullptr, new_rindices);
2018     EXPECT_EQ(1, new_rindices->rank());
2019     EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->dim(0).value());
2020     EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->size<loco::DataType::S32>());
2021     for (uint32_t i = 0; i < tc.nhwc_ind.size(); ++i)
2022     {
2023       EXPECT_EQ(tc.nhwc_ind[i], new_rindices->at<loco::DataType::S32>(i));
2024     }
2025   }
2026 }
2027
2028 TEST(ConvertNCHWToNHWC, Relu)
2029 {
2030   ReluGraph g;
2031   g.init();
2032
2033   run_phase(&g.g, true, true);
2034
2035   check_pre_trans(g.relu->features());
2036
2037   auto relu_succs = loco::succs(g.relu);
2038   EXPECT_EQ(1, relu_succs.size());
2039   check_post_trans(*relu_succs.begin());
2040
2041   // Check relu shape
2042   EXPECT_EQ(1, g.relu->dim(0).value());
2043   EXPECT_EQ(4, g.relu->dim(1).value());
2044   EXPECT_EQ(4, g.relu->dim(2).value());
2045   EXPECT_EQ(16, g.relu->dim(3).value());
2046 }
2047
2048 TEST(ConvertNCHWToNHWC, Relu6)
2049 {
2050   Relu6Graph g;
2051   g.init();
2052
2053   run_phase(&g.g, true, true);
2054
2055   check_pre_trans(g.relu6->features());
2056
2057   auto relu6_succs = loco::succs(g.relu6);
2058   EXPECT_EQ(1, relu6_succs.size());
2059   check_post_trans(*relu6_succs.begin());
2060
2061   // Check relu6 shape
2062   EXPECT_EQ(1, g.relu6->dim(0).value());
2063   EXPECT_EQ(4, g.relu6->dim(1).value());
2064   EXPECT_EQ(4, g.relu6->dim(2).value());
2065   EXPECT_EQ(16, g.relu6->dim(3).value());
2066 }
2067
2068 TEST(ConvertNCHWToNHWC, Rsqrt)
2069 {
2070   RsqrtGraph g;
2071   g.init();
2072
2073   run_phase(&g.g, true, true);
2074
2075   check_pre_trans(g.rsqrt->x());
2076
2077   auto rsqrt_succs = loco::succs(g.rsqrt);
2078   EXPECT_EQ(1, rsqrt_succs.size());
2079   check_post_trans(*rsqrt_succs.begin());
2080
2081   // Check rsqrt shape
2082   EXPECT_EQ(1, g.rsqrt->dim(0).value());
2083   EXPECT_EQ(4, g.rsqrt->dim(1).value());
2084   EXPECT_EQ(4, g.rsqrt->dim(2).value());
2085   EXPECT_EQ(16, g.rsqrt->dim(3).value());
2086 }
2087
2088 TEST(ConvertNCHWToNHWC, SplitV)
2089 {
2090   SplitVGraph g;
2091   g.init();
2092
2093   run_phase(g.g(), true, true);
2094
2095   check_pre_trans(g.splitv()->input());
2096
2097   auto splitv_succs = loco::succs(g.splitv());
2098   for (auto svo : loco::succs(g.splitv()))
2099   {
2100     for (auto succ : loco::succs(svo))
2101     {
2102       check_post_trans(succ);
2103     }
2104   }
2105
2106   // Check splitv() shape
2107   EXPECT_EQ(1, g.splitv()->dim(0).value());
2108   EXPECT_EQ(2, g.splitv()->dim(1).value());
2109   EXPECT_EQ(192, g.splitv()->dim(2).value());
2110   EXPECT_EQ(2, g.splitv()->dim(3).value());
2111
2112   // Check axis
2113   auto axis = dynamic_cast<luci::CircleConst *>(g.splitv()->split_dim());
2114   EXPECT_NE(nullptr, axis);
2115   EXPECT_EQ(1, axis->size<loco::DataType::S32>());
2116   EXPECT_EQ(2, axis->at<loco::DataType::S32>(0));
2117 }
2118
2119 TEST(ConvertNCHWToNHWC, SquaredDifference)
2120 {
2121   SquaredDifferenceGraph g;
2122   g.init();
2123
2124   run_phase(&g.g, true, true);
2125
2126   check_pre_trans(g.sqdiff->x());
2127   check_pre_trans(g.sqdiff->y());
2128
2129   auto sqdiff_succs = loco::succs(g.sqdiff);
2130   EXPECT_EQ(1, sqdiff_succs.size());
2131   check_post_trans(*sqdiff_succs.begin());
2132 }
2133
2134 TEST(ConvertNCHWToNHWC, Sub)
2135 {
2136   SubGraph g;
2137   g.init();
2138
2139   run_phase(&g.g, false, false);
2140
2141   auto input_succs = loco::succs(g.input);
2142   EXPECT_EQ(1, input_succs.size());
2143   check_post_trans(*input_succs.begin());
2144
2145   check_pre_trans(g.sub->x());
2146
2147   auto add_succs = loco::succs(g.sub);
2148   EXPECT_EQ(1, add_succs.size());
2149   check_post_trans(*add_succs.begin());
2150
2151   uint32_t channel_size = 16;
2152   auto new_beta = dynamic_cast<luci::CircleConst *>(g.sub->y());
2153   EXPECT_NE(nullptr, new_beta);
2154   EXPECT_EQ(4, new_beta->rank());
2155   EXPECT_EQ(1, new_beta->dim(0).value());
2156   EXPECT_EQ(1, new_beta->dim(1).value());
2157   EXPECT_EQ(1, new_beta->dim(2).value());
2158   EXPECT_EQ(channel_size, new_beta->dim(3).value());
2159
2160   check_pre_trans(g.output->from());
2161 }
2162
2163 TEST(ConvertNCHWToNHWC, Sub_NCHW_const)
2164 {
2165   SubGraph g;
2166   g.init();
2167   g.update_const_shape_to_nchw();
2168
2169   run_phase(&g.g, false, false);
2170
2171   check_pre_trans(g.sub->x());
2172
2173   auto sub_succs = loco::succs(g.sub);
2174   EXPECT_EQ(1, sub_succs.size());
2175   check_post_trans(*sub_succs.begin());
2176
2177   uint32_t channel_size = 16;
2178   auto new_beta = dynamic_cast<luci::CircleConst *>(g.sub->y());
2179   EXPECT_NE(nullptr, new_beta);
2180   EXPECT_EQ(4, new_beta->rank());
2181   EXPECT_EQ(1, new_beta->dim(0).value());
2182   EXPECT_EQ(4, new_beta->dim(1).value());
2183   EXPECT_EQ(4, new_beta->dim(2).value());
2184   EXPECT_EQ(channel_size, new_beta->dim(3).value());
2185 }
2186
2187 TEST(ConvertNCHWToNHWC, SubScalar)
2188 {
2189   SubScalarGraph g;
2190   g.init();
2191
2192   run_phase(&g.g, false, false);
2193
2194   auto input_succs = loco::succs(g.input);
2195   EXPECT_EQ(1, input_succs.size());
2196   check_post_trans(*input_succs.begin());
2197
2198   check_pre_trans(g.sub->y());
2199
2200   auto add_succs = loco::succs(g.sub);
2201   EXPECT_EQ(1, add_succs.size());
2202   check_post_trans(*add_succs.begin());
2203
2204   auto new_beta = dynamic_cast<luci::CircleConst *>(g.sub->x());
2205   EXPECT_NE(nullptr, new_beta);
2206   EXPECT_EQ(1, new_beta->rank());
2207
2208   check_pre_trans(g.output->from());
2209 }
2210
2211 TEST(ConvertNCHWToNHWC, Not_Closed_Case1_NEG)
2212 {
2213   NoPostReshapeGraph g;
2214   g.init();
2215
2216   run_phase(&g.g, true, true);
2217
2218   check_pre_trans(g.relu->features());
2219
2220   auto relu_succs = loco::succs(g.relu);
2221   EXPECT_EQ(1, relu_succs.size());
2222   check_post_trans(*relu_succs.begin());
2223 }
2224
2225 TEST(ConvertNCHWToNHWC, Not_Closed_Case2_NEG)
2226 {
2227   ReluNotClosedGraph g;
2228   g.init();
2229
2230   run_phase(&g.g, true, true);
2231
2232   check_pre_trans(g.relu->features());
2233
2234   auto relu_succs = loco::succs(g.relu);
2235   EXPECT_EQ(1, relu_succs.size());
2236   check_post_trans(*relu_succs.begin());
2237 }