fd326518e4c951a24976bc0a3369709f6ed95471
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / ConvertNCHWToNHWCPass.test.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <logo/Phase.h>
18
19 #include <luci/test/TestIOGraph.h>
20
21 #include "luci/Pass/ConvertNCHWToNHWCPass.h"
22 #include "luci/Pass/CircleShapeInferencePass.h"
23
24 #include <luci/IR/CircleNodes.h>
25
26 #include <gtest/gtest.h>
27
28 using namespace luci::test;
29
30 namespace
31 {
32
33 /**
34  *  Graph with a single Op (example: Add).
35  *
36  *  BEFORE
37  *  - All Ops including Input/Output are NCHW.
38  *
39  *             [Input] [beta]
40  *                |  /
41  *              [Add]
42  *                |
43  *             [Output]
44  *
45  *  AFTER
46  *  - All Ops including Input/Output are NHWC.
47  *
48  *             [Input]
49  *                |
50  *         [Transpose]
51  *                |
52  *        [Transpose] [beta]
53  *                |  /
54  *              [Add]
55  *                |
56  *         [Transpose]
57  *                |
58  *         [Transpose]
59  *                |
60  *             [Output]
61  */
62 class SimpleGraph
63 {
64 public:
65   SimpleGraph() = default;
66
67 public:
68   void init()
69   {
70     input = g.nodes()->create<luci::CircleInput>();
71     output = g.nodes()->create<luci::CircleOutput>();
72     input->name("input");
73     output->name("output");
74
75     auto graph_input = g.inputs()->create();
76     input->index(graph_input->index());
77     auto graph_output = g.outputs()->create();
78     output->index(graph_output->index());
79
80     graph_input->dtype(loco::DataType::FLOAT32);
81     input->dtype(loco::DataType::FLOAT32);
82     output->dtype(loco::DataType::FLOAT32);
83     graph_output->dtype(loco::DataType::FLOAT32);
84
85     uint32_t channel_size = 16;
86     graph_input->shape({1, channel_size, 4, 4});
87     input->shape({1, channel_size, 4, 4});
88     output->shape({1, channel_size, 4, 4});
89     graph_output->shape({1, channel_size, 4, 4});
90
91     auto graph_body = insertGraphBody(input);
92     output->from(graph_body);
93   }
94
95   virtual ~SimpleGraph() = default;
96
97 protected:
98   virtual loco::Node *insertGraphBody(loco::Node *input) = 0;
99
100 public:
101   loco::Graph g;
102   luci::CircleInput *input = nullptr;
103   luci::CircleOutput *output = nullptr;
104 };
105
106 class AddGraph final : public SimpleGraph
107 {
108 protected:
109   loco::Node *insertGraphBody(loco::Node *input) override
110   {
111     add = g.nodes()->create<luci::CircleAdd>();
112     beta = g.nodes()->create<luci::CircleConst>();
113
114     add->dtype(loco::DataType::FLOAT32);
115     beta->dtype(loco::DataType::FLOAT32);
116
117     uint32_t channel_size = 16;
118     add->shape({1, channel_size, 4, 4});
119     beta->shape({1, channel_size, 1, 1});
120
121     beta->size<loco::DataType::FLOAT32>(channel_size);
122     for (uint32_t i = 0; i < channel_size; i++)
123     {
124       beta->at<loco::DataType::FLOAT32>(i) = i;
125     }
126
127     add->x(input);
128     add->y(beta);
129
130     add->name("add");
131     beta->name("beta");
132
133     return add;
134   }
135
136 public:
137   void update_const_shape_to_nchw(void)
138   {
139     uint32_t channel_size = 16;
140     beta->shape({1, channel_size, 4, 4});
141
142     beta->size<loco::DataType::FLOAT32>(channel_size * 4 * 4);
143     for (uint32_t i = 0; i < channel_size; i++)
144     {
145       beta->at<loco::DataType::FLOAT32>(i) = i;
146     }
147   }
148
149 public:
150   luci::CircleAdd *add = nullptr;
151   luci::CircleConst *beta = nullptr;
152 };
153
154 class NHWCReluGraph final : public SimpleGraph
155 {
156 protected:
157   loco::Node *insertGraphBody(loco::Node *input) override
158   {
159     relu = g.nodes()->create<luci::CircleRelu>();
160     pre_reshape = g.nodes()->create<luci::CircleReshape>();
161     post_reshape = g.nodes()->create<luci::CircleReshape>();
162     pre_shape = g.nodes()->create<luci::CircleConst>();
163     post_shape = g.nodes()->create<luci::CircleConst>();
164
165     pre_shape->dtype(loco::DataType::S32);
166     post_shape->dtype(loco::DataType::S32);
167
168     uint32_t channel_size = 16;
169     auto in = loco::must_cast<luci::CircleNode *>(input);
170     in->shape({1, channel_size, 4, 4});
171     pre_shape->shape({4});
172     post_shape->shape({4});
173
174     pre_shape->size<loco::DataType::S32>(4);
175     pre_shape->at<loco::DataType::S32>(0) = 1;
176     pre_shape->at<loco::DataType::S32>(1) = 4;
177     pre_shape->at<loco::DataType::S32>(2) = 4;
178     pre_shape->at<loco::DataType::S32>(3) = channel_size;
179
180     post_shape->size<loco::DataType::S32>(4);
181     post_shape->at<loco::DataType::S32>(0) = 1;
182     post_shape->at<loco::DataType::S32>(1) = channel_size;
183     post_shape->at<loco::DataType::S32>(2) = 4;
184     post_shape->at<loco::DataType::S32>(3) = 4;
185
186     pre_reshape->tensor(input);
187     pre_reshape->shape(pre_shape);
188
189     relu->features(pre_reshape);
190
191     post_reshape->tensor(relu);
192     post_reshape->shape(post_shape);
193
194     relu->name("Relu");
195     pre_reshape->name("pre-reshape");
196     post_reshape->name("post-reshape");
197
198     return post_reshape;
199   }
200
201 public:
202   luci::CircleRelu *relu = nullptr;
203   luci::CircleReshape *pre_reshape = nullptr;
204   luci::CircleReshape *post_reshape = nullptr;
205   luci::CircleConst *pre_shape = nullptr;
206   luci::CircleConst *post_shape = nullptr;
207 };
208
209 /**
210  *  Graph with pre-Reshape but no post-Transpose/Reshape.
211  *
212  *  BEFORE
213  *             [Input]
214  *                |
215  *          [Pre-Reshape]
216  *                |
217  *              [Relu]
218  *                |
219  *             [Output]
220  *
221  *  AFTER
222  *             [Input]
223  *                |
224  *          [Pre-Reshape]
225  *                |
226  *          [Pre-Transpose]
227  *                |
228  *              [Relu]
229  *                |
230  *          [Post-Transpose]
231  *                |
232  *             [Output]
233  */
234 class NoPostReshapeGraph final : public SimpleGraph
235 {
236 protected:
237   loco::Node *insertGraphBody(loco::Node *input) override
238   {
239     relu = g.nodes()->create<luci::CircleRelu>();
240     pre_reshape = g.nodes()->create<luci::CircleReshape>();
241     pre_shape = g.nodes()->create<luci::CircleConst>();
242
243     pre_shape->dtype(loco::DataType::S32);
244
245     uint32_t channel_size = 16;
246     auto in = loco::must_cast<luci::CircleNode *>(input);
247     in->shape({1, channel_size, 4, 4});
248     pre_shape->shape({4});
249
250     pre_shape->size<loco::DataType::S32>(4);
251     pre_shape->at<loco::DataType::S32>(0) = 1;
252     pre_shape->at<loco::DataType::S32>(1) = 4;
253     pre_shape->at<loco::DataType::S32>(2) = 4;
254     pre_shape->at<loco::DataType::S32>(3) = channel_size;
255
256     pre_reshape->tensor(input);
257     pre_reshape->shape(pre_shape);
258     relu->features(pre_reshape);
259
260     relu->name("Relu");
261     pre_reshape->name("pre-reshape");
262
263     return relu;
264   }
265
266 public:
267   luci::CircleRelu *relu = nullptr;
268   luci::CircleReshape *pre_reshape = nullptr;
269   luci::CircleConst *pre_shape = nullptr;
270 };
271
272 /**
273  *  Graph with two pre-Reshapes
274  *
275  *  BEFORE
276  *             [Input]
277  *                |
278  *          [Pre-Reshape]
279  *                |
280  *              [Relu]
281  *                |
282  *          [Pre-Reshape]
283  *                |
284  *          [Post-Reshape]
285  *                |
286  *             [Output]
287  *
288  *  AFTER
289  *             [Input]
290  *                |
291  *          [Pre-Reshape]
292  *                |
293  *          [Pre-Transpose]
294  *                |
295  *              [Relu]
296  *                |
297  *          [Post-Transpose]
298  *                |
299  *          [Pre-Reshape]
300  *                |
301  *          [Post-Reshape]
302  *                |
303  *             [Output]
304  */
305 class ReluNotClosedGraph final : public SimpleGraph
306 {
307 protected:
308   loco::Node *insertGraphBody(loco::Node *input) override
309   {
310     relu = g.nodes()->create<luci::CircleRelu>();
311     pre_reshape = g.nodes()->create<luci::CircleReshape>();
312     pre_reshape_2 = g.nodes()->create<luci::CircleReshape>();
313     post_reshape = g.nodes()->create<luci::CircleReshape>();
314     pre_shape = g.nodes()->create<luci::CircleConst>();
315     pre_shape_2 = g.nodes()->create<luci::CircleConst>();
316     post_shape = g.nodes()->create<luci::CircleConst>();
317
318     pre_shape->dtype(loco::DataType::S32);
319     pre_shape_2->dtype(loco::DataType::S32);
320     post_shape->dtype(loco::DataType::S32);
321
322     uint32_t channel_size = 16;
323     auto in = loco::must_cast<luci::CircleNode *>(input);
324     in->shape({1, channel_size, 4, 4});
325     pre_shape->shape({4});
326     pre_shape_2->shape({4});
327     post_shape->shape({4});
328
329     pre_shape->size<loco::DataType::S32>(4);
330     pre_shape->at<loco::DataType::S32>(0) = 1;
331     pre_shape->at<loco::DataType::S32>(1) = 4;
332     pre_shape->at<loco::DataType::S32>(2) = 4;
333     pre_shape->at<loco::DataType::S32>(3) = channel_size;
334
335     pre_shape_2->size<loco::DataType::S32>(4);
336     pre_shape_2->at<loco::DataType::S32>(0) = 1;
337     pre_shape_2->at<loco::DataType::S32>(1) = 4;
338     pre_shape_2->at<loco::DataType::S32>(2) = channel_size;
339     pre_shape_2->at<loco::DataType::S32>(3) = 4;
340
341     post_shape->size<loco::DataType::S32>(4);
342     post_shape->at<loco::DataType::S32>(0) = 1;
343     post_shape->at<loco::DataType::S32>(1) = 4;
344     post_shape->at<loco::DataType::S32>(2) = 4;
345     post_shape->at<loco::DataType::S32>(3) = channel_size;
346
347     pre_reshape->tensor(input);
348     pre_reshape->shape(pre_shape);
349
350     relu->features(pre_reshape);
351
352     pre_reshape_2->tensor(relu);
353     pre_reshape_2->shape(pre_shape_2);
354
355     post_reshape->tensor(pre_reshape_2);
356     post_reshape->shape(post_shape);
357
358     relu->name("Relu");
359     pre_reshape->name("pre-reshape");
360     pre_reshape->name("pre-reshape-2");
361     post_reshape->name("post-reshape");
362
363     return post_reshape;
364   }
365
366 public:
367   luci::CircleRelu *relu = nullptr;
368   luci::CircleReshape *pre_reshape = nullptr;
369   luci::CircleReshape *pre_reshape_2 = nullptr;
370   luci::CircleReshape *post_reshape = nullptr;
371   luci::CircleConst *pre_shape = nullptr;
372   luci::CircleConst *pre_shape_2 = nullptr;
373   luci::CircleConst *post_shape = nullptr;
374 };
375
376 class AddScalarGraph final : public SimpleGraph
377 {
378 protected:
379   loco::Node *insertGraphBody(loco::Node *input) override
380   {
381     add = g.nodes()->create<luci::CircleAdd>();
382     beta = g.nodes()->create<luci::CircleConst>();
383
384     add->dtype(loco::DataType::FLOAT32);
385     beta->dtype(loco::DataType::FLOAT32);
386
387     uint32_t channel_size = 16;
388     add->shape({1, channel_size, 4, 4});
389     beta->shape({1});
390
391     beta->size<loco::DataType::FLOAT32>(1);
392     beta->at<loco::DataType::FLOAT32>(0) = 3.14;
393
394     add->x(input);
395     add->y(beta);
396
397     add->name("add");
398     beta->name("beta");
399
400     return add;
401   }
402
403 public:
404   luci::CircleAdd *add = nullptr;
405   luci::CircleConst *beta = nullptr;
406 };
407
408 class ConcatenationGraph final : public SimpleGraph
409 {
410 protected:
411   loco::Node *insertGraphBody(loco::Node *input) override
412   {
413     concat = g.nodes()->create<luci::CircleConcatenation>(2);
414     concat->values(0, input);
415     concat->axis(1);
416
417     input2 = g.nodes()->create<luci::CircleConst>();
418     input2->dtype(loco::DataType::FLOAT32);
419     input2->shape({1, 16, 4, 4});
420     input2->size<loco::DataType::FLOAT32>(16 * 4 * 4);
421     for (uint32_t i = 0; i < 16 * 4 * 4; i++)
422     {
423       input2->at<loco::DataType::FLOAT32>(i) = i;
424     }
425     concat->values(1, input2);
426
427     concat->name("concat");
428     input2->name("input2");
429
430     return concat;
431   }
432
433 public:
434   luci::CircleConcatenation *concat = nullptr;
435   luci::CircleConst *input2 = nullptr;
436 };
437
438 class EluGraph final : public SimpleGraph
439 {
440 protected:
441   loco::Node *insertGraphBody(loco::Node *input) override
442   {
443     elu = g.nodes()->create<luci::CircleElu>();
444     elu->features(input);
445     elu->name("elu");
446
447     return elu;
448   }
449
450 public:
451   luci::CircleElu *elu = nullptr;
452 };
453
454 class LeakyReluGraph final : public SimpleGraph
455 {
456 protected:
457   loco::Node *insertGraphBody(loco::Node *input) override
458   {
459     leakyrelu = g.nodes()->create<luci::CircleLeakyRelu>();
460     leakyrelu->features(input);
461     leakyrelu->name("leakyrelu");
462
463     return leakyrelu;
464   }
465
466 public:
467   luci::CircleLeakyRelu *leakyrelu = nullptr;
468 };
469
470 class LogisticGraph final : public SimpleGraph
471 {
472 protected:
473   loco::Node *insertGraphBody(loco::Node *input) override
474   {
475     logistic = g.nodes()->create<luci::CircleLogistic>();
476     logistic->x(input);
477     logistic->name("logistic");
478
479     return logistic;
480   }
481
482 public:
483   luci::CircleLogistic *logistic = nullptr;
484 };
485
486 class MaximumGraph final : public SimpleGraph
487 {
488 protected:
489   loco::Node *insertGraphBody(loco::Node *input) override
490   {
491     max = g.nodes()->create<luci::CircleMaximum>();
492     limit = g.nodes()->create<luci::CircleConst>();
493
494     max->dtype(loco::DataType::FLOAT32);
495     limit->dtype(loco::DataType::FLOAT32);
496
497     max->shape({1, 16, 4, 4});
498     limit->shape({});
499
500     limit->size<loco::DataType::FLOAT32>(1);
501     limit->at<loco::DataType::FLOAT32>(0) = 100;
502
503     max->x(input);
504     max->y(limit);
505
506     max->name("max");
507     limit->name("limit");
508
509     return max;
510   }
511
512 public:
513   luci::CircleMaximum *max = nullptr;
514   luci::CircleConst *limit = nullptr;
515 };
516
517 class MaximumNonConstGraph final : public SimpleGraph
518 {
519 protected:
520   loco::Node *insertGraphBody(loco::Node *input) override
521   {
522     max = g.nodes()->create<luci::CircleMaximum>();
523     max->dtype(loco::DataType::FLOAT32);
524     max->shape({1, 16, 4, 4});
525
526     max->x(input);
527     max->y(input);
528
529     max->name("max");
530
531     return max;
532   }
533
534 public:
535   luci::CircleMaximum *max = nullptr;
536 };
537
538 class MeanGraph final : public SimpleGraph
539 {
540 protected:
541   loco::Node *insertGraphBody(loco::Node *input) override
542   {
543     mean = g.nodes()->create<luci::CircleMean>();
544     rindices = g.nodes()->create<luci::CircleConst>();
545
546     mean->dtype(loco::DataType::FLOAT32);
547     rindices->dtype(loco::DataType::S32);
548
549     mean->shape(_shape);
550     rindices->shape({static_cast<uint32_t>(_axes.size())});
551
552     rindices->size<loco::DataType::S32>(_axes.size());
553     for (uint32_t i = 0; i < _axes.size(); ++i)
554     {
555       rindices->at<loco::DataType::S32>(i) = _axes[i];
556     }
557
558     mean->input(input);
559     mean->reduction_indices(rindices);
560     mean->keep_dims(_keep_dims);
561
562     mean->name("mean");
563     rindices->name("rindices");
564
565     return mean;
566   }
567
568 public:
569   void keep_dims(bool val) { _keep_dims = val; }
570   void axes(std::vector<int32_t> val) { _axes = val; }
571   void shape(std::initializer_list<uint32_t> val) { _shape = val; }
572
573 public:
574   luci::CircleMean *mean = nullptr;
575   luci::CircleConst *rindices = nullptr;
576
577 private:
578   bool _keep_dims = true;
579   std::vector<int32_t> _axes = {2, 3};
580   std::initializer_list<uint32_t> _shape = {1, 16, 1, 1};
581 };
582
583 class MinimumGraph final : public SimpleGraph
584 {
585 protected:
586   loco::Node *insertGraphBody(loco::Node *input) override
587   {
588     min = g.nodes()->create<luci::CircleMinimum>();
589     limit = g.nodes()->create<luci::CircleConst>();
590
591     min->dtype(loco::DataType::FLOAT32);
592     limit->dtype(loco::DataType::FLOAT32);
593
594     min->shape({1, 16, 4, 4});
595     limit->shape({});
596
597     limit->size<loco::DataType::FLOAT32>(1);
598     limit->at<loco::DataType::FLOAT32>(0) = 100;
599
600     min->x(input);
601     min->y(limit);
602
603     min->name("min");
604     limit->name("limit");
605
606     return min;
607   }
608
609 public:
610   luci::CircleMinimum *min = nullptr;
611   luci::CircleConst *limit = nullptr;
612 };
613
614 class MulGraph final : public SimpleGraph
615 {
616 protected:
617   loco::Node *insertGraphBody(loco::Node *input) override
618   {
619     mul = g.nodes()->create<luci::CircleMul>();
620     multiplier = g.nodes()->create<luci::CircleConst>();
621
622     mul->dtype(loco::DataType::FLOAT32);
623     multiplier->dtype(loco::DataType::FLOAT32);
624
625     uint32_t channel_size = 16;
626     mul->shape({1, channel_size, 4, 4});
627     multiplier->shape({1, channel_size, 1, 1});
628
629     multiplier->size<loco::DataType::FLOAT32>(channel_size);
630     for (uint32_t i = 0; i < channel_size; i++)
631     {
632       multiplier->at<loco::DataType::FLOAT32>(i) = i;
633     }
634
635     mul->x(input);
636     mul->y(multiplier);
637
638     mul->name("mul");
639     multiplier->name("multiplier");
640
641     return mul;
642   }
643
644 public:
645   void update_const_shape_to_nchw(void)
646   {
647     uint32_t channel_size = 16;
648     multiplier->shape({1, channel_size, 4, 4});
649
650     multiplier->size<loco::DataType::FLOAT32>(channel_size * 4 * 4);
651     for (uint32_t i = 0; i < channel_size; i++)
652     {
653       multiplier->at<loco::DataType::FLOAT32>(i) = i;
654     }
655   }
656
657 public:
658   luci::CircleMul *mul = nullptr;
659   luci::CircleConst *multiplier = nullptr;
660 };
661
662 class MulScalarGraph final : public SimpleGraph
663 {
664 protected:
665   loco::Node *insertGraphBody(loco::Node *input) override
666   {
667     mul = g.nodes()->create<luci::CircleMul>();
668     multiplier = g.nodes()->create<luci::CircleConst>();
669
670     mul->dtype(loco::DataType::FLOAT32);
671     multiplier->dtype(loco::DataType::FLOAT32);
672
673     uint32_t channel_size = 16;
674     mul->shape({1, channel_size, 4, 4});
675     multiplier->shape({1});
676
677     multiplier->size<loco::DataType::FLOAT32>(1);
678     multiplier->at<loco::DataType::FLOAT32>(0) = 2;
679
680     mul->x(input);
681     mul->y(multiplier);
682
683     mul->name("mul");
684     multiplier->name("multiplier");
685
686     return mul;
687   }
688
689 public:
690   luci::CircleMul *mul = nullptr;
691   luci::CircleConst *multiplier = nullptr;
692 };
693
694 class MulBothNormGraph final : public SimpleGraph
695 {
696 protected:
697   loco::Node *insertGraphBody(loco::Node *input) override
698   {
699     mul = g.nodes()->create<luci::CircleMul>();
700
701     mul->dtype(loco::DataType::FLOAT32);
702
703     uint32_t channel_size = 16;
704     mul->shape({1, channel_size, 4, 4});
705
706     mul->x(input);
707     mul->y(input);
708
709     mul->name("mul");
710
711     return mul;
712   }
713
714 public:
715   luci::CircleMul *mul = nullptr;
716 };
717
718 class NegGraph final : public SimpleGraph
719 {
720 protected:
721   loco::Node *insertGraphBody(loco::Node *input) override
722   {
723     neg = g.nodes()->create<luci::CircleNeg>();
724     neg->x(input);
725     neg->name("neg");
726
727     return neg;
728   }
729
730 public:
731   luci::CircleNeg *neg = nullptr;
732 };
733
734 class PadGraph final : public SimpleGraph
735 {
736 protected:
737   loco::Node *insertGraphBody(loco::Node *input) override
738   {
739     pad = g.nodes()->create<luci::CirclePad>();
740     paddings = g.nodes()->create<luci::CircleConst>();
741
742     pad->dtype(loco::DataType::FLOAT32);
743     paddings->dtype(loco::DataType::S32);
744
745     uint32_t channel_size = 16;
746     pad->shape({1, channel_size, 4, 4});
747     paddings->shape({4, 2});
748
749     // paddings data (NCHW)
750     // [[0,0], [0,0], [1,1], [2,2]]
751     paddings->size<loco::DataType::S32>(8);
752     for (uint32_t dim = 0; dim < 4; dim++)
753     {
754       for (uint32_t i = 0; i < 2; i++)
755       {
756         int32_t data = 0;
757
758         if (dim == 2)
759           data = 1;
760         else if (dim == 3)
761           data = 2;
762
763         paddings->at<loco::DataType::S32>(dim * 2 + i) = data;
764       }
765     }
766
767     pad->input(input);
768     pad->paddings(paddings);
769
770     pad->name("pad");
771     paddings->name("paddings");
772
773     return pad;
774   }
775
776 public:
777   luci::CirclePad *pad = nullptr;
778   luci::CircleConst *paddings = nullptr;
779 };
780
781 class PadV2Graph final : public SimpleGraph
782 {
783 protected:
784   loco::Node *insertGraphBody(loco::Node *input) override
785   {
786     pad = g.nodes()->create<luci::CirclePadV2>();
787     paddings = g.nodes()->create<luci::CircleConst>();
788     const_value = g.nodes()->create<luci::CircleConst>();
789
790     pad->dtype(loco::DataType::FLOAT32);
791     paddings->dtype(loco::DataType::S32);
792     const_value->dtype(loco::DataType::FLOAT32);
793
794     uint32_t channel_size = 16;
795     pad->shape({1, channel_size, 4, 4});
796     paddings->shape({4, 2});
797     const_value->shape({1});
798
799     // paddings data (NCHW)
800     // [[0,0], [0,0], [1,1], [2,2]]
801     paddings->size<loco::DataType::S32>(8);
802     for (uint32_t dim = 0; dim < 4; dim++)
803     {
804       for (uint32_t i = 0; i < 2; i++)
805       {
806         int32_t data = 0;
807
808         if (dim == 2)
809           data = 1;
810         else if (dim == 3)
811           data = 2;
812
813         paddings->at<loco::DataType::S32>(dim * 2 + i) = data;
814       }
815     }
816
817     const_value->size<loco::DataType::FLOAT32>(1);
818     const_value->at<loco::DataType::FLOAT32>(0) = -3.4;
819
820     pad->input(input);
821     pad->paddings(paddings);
822     pad->constant_values(paddings);
823
824     pad->name("padV2");
825     paddings->name("paddings");
826     const_value->name("constant_values");
827
828     return pad;
829   }
830
831 public:
832   luci::CirclePadV2 *pad = nullptr;
833   luci::CircleConst *paddings = nullptr;
834   luci::CircleConst *const_value = nullptr;
835 };
836
837 class ReduceMaxGraph final : public SimpleGraph
838 {
839 protected:
840   loco::Node *insertGraphBody(loco::Node *input) override
841   {
842     rm = g.nodes()->create<luci::CircleReduceMax>();
843     rindices = g.nodes()->create<luci::CircleConst>();
844
845     rm->dtype(loco::DataType::FLOAT32);
846     rindices->dtype(loco::DataType::S32);
847
848     rm->shape(_shape);
849     rindices->shape({static_cast<uint32_t>(_axes.size())});
850
851     rindices->size<loco::DataType::S32>(_axes.size());
852     for (uint32_t i = 0; i < _axes.size(); ++i)
853     {
854       rindices->at<loco::DataType::S32>(i) = _axes[i];
855     }
856
857     rm->input(input);
858     rm->reduction_indices(rindices);
859     rm->keep_dims(_keep_dims);
860
861     rm->name("reduce_max");
862     rindices->name("rindices");
863
864     return rm;
865   }
866
867 public:
868   void keep_dims(bool val) { _keep_dims = val; }
869   void axes(std::vector<int32_t> val) { _axes = val; }
870   void shape(std::initializer_list<uint32_t> val) { _shape = val; }
871
872 public:
873   luci::CircleReduceMax *rm = nullptr;
874   luci::CircleConst *rindices = nullptr;
875
876 private:
877   bool _keep_dims = true;
878   std::vector<int32_t> _axes = {2, 3};
879   std::initializer_list<uint32_t> _shape = {1, 16, 1, 1};
880 };
881
882 class ReduceMinGraph final : public SimpleGraph
883 {
884 protected:
885   loco::Node *insertGraphBody(loco::Node *input) override
886   {
887     rm = g.nodes()->create<luci::CircleReduceMin>();
888     rindices = g.nodes()->create<luci::CircleConst>();
889
890     rm->dtype(loco::DataType::FLOAT32);
891     rindices->dtype(loco::DataType::S32);
892
893     rm->shape(_shape);
894     rindices->shape({static_cast<uint32_t>(_axes.size())});
895
896     rindices->size<loco::DataType::S32>(_axes.size());
897     for (uint32_t i = 0; i < _axes.size(); ++i)
898     {
899       rindices->at<loco::DataType::S32>(i) = _axes[i];
900     }
901
902     rm->input(input);
903     rm->reduction_indices(rindices);
904     rm->keep_dims(_keep_dims);
905
906     rm->name("reduce_max");
907     rindices->name("rindices");
908
909     return rm;
910   }
911
912 public:
913   void keep_dims(bool val) { _keep_dims = val; }
914   void axes(std::vector<int32_t> val) { _axes = val; }
915   void shape(std::initializer_list<uint32_t> val) { _shape = val; }
916
917 public:
918   luci::CircleReduceMin *rm = nullptr;
919   luci::CircleConst *rindices = nullptr;
920
921 private:
922   bool _keep_dims = true;
923   std::vector<int32_t> _axes = {2, 3};
924   std::initializer_list<uint32_t> _shape = {1, 16, 1, 1};
925 };
926
927 class ReluGraph final : public SimpleGraph
928 {
929 protected:
930   loco::Node *insertGraphBody(loco::Node *input) override
931   {
932     relu = g.nodes()->create<luci::CircleRelu>();
933     relu->features(input);
934     relu->name("Relu");
935
936     return relu;
937   }
938
939 public:
940   luci::CircleRelu *relu = nullptr;
941 };
942
943 class Relu6Graph final : public SimpleGraph
944 {
945 protected:
946   loco::Node *insertGraphBody(loco::Node *input) override
947   {
948     relu6 = g.nodes()->create<luci::CircleRelu6>();
949     relu6->features(input);
950     relu6->name("relu6");
951
952     return relu6;
953   }
954
955 public:
956   luci::CircleRelu6 *relu6 = nullptr;
957 };
958
959 class RsqrtGraph final : public SimpleGraph
960 {
961 protected:
962   loco::Node *insertGraphBody(loco::Node *input) override
963   {
964     rsqrt = g.nodes()->create<luci::CircleRsqrt>();
965     rsqrt->x(input);
966     rsqrt->name("rsqrt");
967
968     return rsqrt;
969   }
970
971 public:
972   luci::CircleRsqrt *rsqrt = nullptr;
973 };
974
975 class SplitVGraphlet
976 {
977 public:
978   SplitVGraphlet() = default;
979
980 public:
981   void init(loco::Graph *g)
982   {
983     // CircleCustom(SplitV)
984     _splitv = g->nodes()->create<luci::CircleSplitV>();
985     _splitv->shape({1, 2, 2, 192});
986     _splitv->dtype(loco::DataType::FLOAT32);
987     _splitv->name("splitv");
988
989     // CircleConst
990     auto size_splits = g->nodes()->create<luci::CircleConst>();
991     size_splits->dtype(loco::DataType::S32);
992     size_splits->shape({3});
993     size_splits->size<loco::DataType::S32>(3);
994     size_splits->at<loco::DataType::S32>(0) = 32;
995     size_splits->at<loco::DataType::S32>(1) = 32;
996     size_splits->at<loco::DataType::S32>(2) = 128;
997
998     // CircleConst
999     auto split_dim = g->nodes()->create<luci::CircleConst>();
1000     split_dim->dtype(loco::DataType::S32);
1001     split_dim->rank(0);
1002     split_dim->size<loco::DataType::S32>(1);
1003     split_dim->scalar<loco::DataType::S32>() = 3;
1004
1005     _splitv->size_splits(size_splits);
1006     _splitv->split_dim(split_dim);
1007     _splitv->num_split(3);
1008
1009     // CircleSplitVOut
1010     _splitv_out1 = g->nodes()->create<luci::CircleSplitVOut>();
1011     _splitv_out1->shape({1, 2, 2, 32});
1012     _splitv_out1->dtype(loco::DataType::FLOAT32);
1013     _splitv_out1->index(0);
1014     _splitv_out1->input(_splitv);
1015     _splitv_out1->name("splitv_out1");
1016
1017     // CircleSplitVOut
1018     _splitv_out2 = g->nodes()->create<luci::CircleSplitVOut>();
1019     _splitv_out2->shape({1, 2, 2, 32});
1020     _splitv_out2->dtype(loco::DataType::FLOAT32);
1021     _splitv_out2->index(1);
1022     _splitv_out2->input(_splitv);
1023     _splitv_out2->name("splitv_out2");
1024
1025     // CircleSplitVOut
1026     _splitv_out3 = g->nodes()->create<luci::CircleSplitVOut>();
1027     _splitv_out3->shape({1, 2, 2, 128});
1028     _splitv_out3->dtype(loco::DataType::FLOAT32);
1029     _splitv_out3->index(2);
1030     _splitv_out3->input(_splitv);
1031     _splitv_out3->name("splitv_out3");
1032   }
1033
1034 public:
1035   luci::CircleSplitV *splitv() { return _splitv; }
1036
1037 protected:
1038   luci::CircleSplitV *_splitv = nullptr;
1039   luci::CircleSplitVOut *_splitv_out1 = nullptr;
1040   luci::CircleSplitVOut *_splitv_out2 = nullptr;
1041   luci::CircleSplitVOut *_splitv_out3 = nullptr;
1042 };
1043
1044 class SplitVGraph : public TestIGraphlet, public TestOsGraphlet<3>, public SplitVGraphlet
1045 {
1046 public:
1047   SplitVGraph() = default;
1048
1049   void init(void)
1050   {
1051     TestIGraphlet::init(g(), {1, 2, 2, 192});
1052     TestOsGraphlet<3>::init(g(), {{1, 2, 2, 32}, {1, 2, 2, 32}, {1, 2, 2, 128}});
1053     SplitVGraphlet::init(g());
1054
1055     // connect graph
1056     _splitv->input(input());
1057
1058     output(0)->from(_splitv_out1);
1059     output(1)->from(_splitv_out2);
1060     output(2)->from(_splitv_out3);
1061   }
1062 };
1063
1064 class SquaredDifferenceGraph final : public SimpleGraph
1065 {
1066 protected:
1067   loco::Node *insertGraphBody(loco::Node *input) override
1068   {
1069     sqdiff = g.nodes()->create<luci::CircleSquaredDifference>();
1070     sqdiff->x(input);
1071     sqdiff->y(input);
1072     sqdiff->name("sqdiff");
1073
1074     return sqdiff;
1075   }
1076
1077 public:
1078   luci::CircleSquaredDifference *sqdiff = nullptr;
1079 };
1080
1081 class SubGraph final : public SimpleGraph
1082 {
1083 protected:
1084   loco::Node *insertGraphBody(loco::Node *input) override
1085   {
1086     sub = g.nodes()->create<luci::CircleSub>();
1087     beta = g.nodes()->create<luci::CircleConst>();
1088
1089     sub->dtype(loco::DataType::FLOAT32);
1090     beta->dtype(loco::DataType::FLOAT32);
1091
1092     uint32_t channel_size = 16;
1093     sub->shape({1, channel_size, 4, 4});
1094     beta->shape({1, channel_size, 1, 1});
1095
1096     beta->size<loco::DataType::FLOAT32>(channel_size);
1097     for (uint32_t i = 0; i < channel_size; i++)
1098     {
1099       beta->at<loco::DataType::FLOAT32>(i) = i;
1100     }
1101
1102     sub->x(input);
1103     sub->y(beta);
1104
1105     sub->name("sub");
1106     beta->name("beta");
1107
1108     return sub;
1109   }
1110
1111 public:
1112   void update_const_shape_to_nchw(void)
1113   {
1114     uint32_t channel_size = 16;
1115     beta->shape({1, channel_size, 4, 4});
1116
1117     beta->size<loco::DataType::FLOAT32>(channel_size * 4 * 4);
1118     for (uint32_t i = 0; i < channel_size; i++)
1119     {
1120       beta->at<loco::DataType::FLOAT32>(i) = i;
1121     }
1122   }
1123
1124 public:
1125   luci::CircleSub *sub = nullptr;
1126   luci::CircleConst *beta = nullptr;
1127 };
1128
1129 class SubScalarGraph final : public SimpleGraph
1130 {
1131 protected:
1132   loco::Node *insertGraphBody(loco::Node *input) override
1133   {
1134     sub = g.nodes()->create<luci::CircleSub>();
1135     beta = g.nodes()->create<luci::CircleConst>();
1136
1137     sub->dtype(loco::DataType::FLOAT32);
1138     beta->dtype(loco::DataType::FLOAT32);
1139
1140     uint32_t channel_size = 16;
1141     sub->shape({1, channel_size, 4, 4});
1142     beta->shape({1});
1143
1144     beta->size<loco::DataType::FLOAT32>(1);
1145     beta->at<loco::DataType::FLOAT32>(0) = 5;
1146
1147     sub->x(beta);
1148     sub->y(input);
1149
1150     sub->name("sub");
1151     beta->name("beta");
1152
1153     return sub;
1154   }
1155
1156 public:
1157   luci::CircleSub *sub = nullptr;
1158   luci::CircleConst *beta = nullptr;
1159 };
1160
1161 void check_pre_trans(loco::Node *node)
1162 {
1163   auto pre_trans = dynamic_cast<luci::CircleTranspose *>(node);
1164   EXPECT_NE(nullptr, pre_trans);
1165   auto pre_trans_perm = dynamic_cast<luci::CircleConst *>(pre_trans->perm());
1166   EXPECT_NE(nullptr, pre_trans_perm);
1167   EXPECT_EQ(1, pre_trans_perm->rank());
1168   EXPECT_EQ(4, pre_trans_perm->dim(0).value());
1169   EXPECT_EQ(loco::DataType::S32, pre_trans_perm->dtype());
1170   EXPECT_EQ(0, pre_trans_perm->at<loco::DataType::S32>(0));
1171   EXPECT_EQ(2, pre_trans_perm->at<loco::DataType::S32>(1));
1172   EXPECT_EQ(3, pre_trans_perm->at<loco::DataType::S32>(2));
1173   EXPECT_EQ(1, pre_trans_perm->at<loco::DataType::S32>(3));
1174 }
1175
1176 void check_post_trans(loco::Node *node)
1177 {
1178   auto post_trans = dynamic_cast<luci::CircleTranspose *>(node);
1179   EXPECT_NE(nullptr, post_trans);
1180   auto post_trans_perm = dynamic_cast<luci::CircleConst *>(post_trans->perm());
1181   EXPECT_NE(nullptr, post_trans_perm);
1182   EXPECT_EQ(1, post_trans_perm->rank());
1183   EXPECT_EQ(4, post_trans_perm->dim(0).value());
1184   EXPECT_EQ(loco::DataType::S32, post_trans_perm->dtype());
1185   EXPECT_EQ(0, post_trans_perm->at<loco::DataType::S32>(0));
1186   EXPECT_EQ(3, post_trans_perm->at<loco::DataType::S32>(1));
1187   EXPECT_EQ(1, post_trans_perm->at<loco::DataType::S32>(2));
1188   EXPECT_EQ(2, post_trans_perm->at<loco::DataType::S32>(3));
1189 }
1190
1191 void run_phase(loco::Graph *g, bool preserve_input, bool preserve_output)
1192 {
1193   logo::Phase phase;
1194
1195   // Default passes.
1196   phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
1197
1198   // Pass to test
1199   phase.emplace_back(
1200     std::make_unique<luci::ConvertNCHWToNHWCPass>(preserve_input, preserve_output));
1201
1202   logo::PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{g};
1203   phase_runner.run(phase);
1204 }
1205
1206 } // namespace
1207
1208 TEST(ConvertNCHWToNHWCPassTest, name)
1209 {
1210   luci::ConvertNCHWToNHWCPass pass(false, false);
1211   auto const name = pass.name();
1212   ASSERT_NE(nullptr, name);
1213 }
1214
1215 TEST(ConvertNCHWToNHWC, Add)
1216 {
1217   AddGraph g;
1218   g.init();
1219
1220   run_phase(&g.g, false, false);
1221
1222   auto input_succs = loco::succs(g.input);
1223   EXPECT_EQ(1, input_succs.size());
1224   check_post_trans(*input_succs.begin());
1225
1226   check_pre_trans(g.add->x());
1227
1228   auto add_succs = loco::succs(g.add);
1229   EXPECT_EQ(1, add_succs.size());
1230   check_post_trans(*add_succs.begin());
1231
1232   uint32_t channel_size = 16;
1233   auto new_beta = dynamic_cast<luci::CircleConst *>(g.add->y());
1234   EXPECT_NE(nullptr, new_beta);
1235   EXPECT_EQ(4, new_beta->rank());
1236   EXPECT_EQ(1, new_beta->dim(0).value());
1237   EXPECT_EQ(1, new_beta->dim(1).value());
1238   EXPECT_EQ(1, new_beta->dim(2).value());
1239   EXPECT_EQ(channel_size, new_beta->dim(3).value());
1240
1241   check_pre_trans(g.output->from());
1242 }
1243
1244 TEST(ConvertNCHWToNHWC, Add_NCHW_const)
1245 {
1246   AddGraph g;
1247   g.init();
1248   g.update_const_shape_to_nchw();
1249
1250   run_phase(&g.g, false, false);
1251
1252   check_pre_trans(g.add->x());
1253
1254   auto add_succs = loco::succs(g.add);
1255   EXPECT_EQ(1, add_succs.size());
1256   check_post_trans(*add_succs.begin());
1257
1258   uint32_t channel_size = 16;
1259   auto new_beta = dynamic_cast<luci::CircleConst *>(g.add->y());
1260   EXPECT_NE(nullptr, new_beta);
1261   EXPECT_EQ(4, new_beta->rank());
1262   EXPECT_EQ(1, new_beta->dim(0).value());
1263   EXPECT_EQ(4, new_beta->dim(1).value());
1264   EXPECT_EQ(4, new_beta->dim(2).value());
1265   EXPECT_EQ(channel_size, new_beta->dim(3).value());
1266 }
1267
1268 TEST(ConvertNCHWToNHWC, NHWC_Relu)
1269 {
1270   // Relu is already NHWC, so it should not be converted
1271   // i.e., the graph is not changed
1272   NHWCReluGraph g;
1273   g.init();
1274
1275   run_phase(&g.g, false, false);
1276
1277   EXPECT_EQ(g.pre_reshape, g.relu->features());
1278
1279   auto relu_succs = loco::succs(g.relu);
1280   EXPECT_EQ(1, relu_succs.size());
1281   EXPECT_EQ(g.post_reshape, *relu_succs.begin());
1282 }
1283
1284 TEST(ConvertNCHWToNHWC, AddScalar)
1285 {
1286   AddScalarGraph g;
1287   g.init();
1288
1289   run_phase(&g.g, false, false);
1290
1291   auto input_succs = loco::succs(g.input);
1292   EXPECT_EQ(1, input_succs.size());
1293   check_post_trans(*input_succs.begin());
1294
1295   check_pre_trans(g.add->x());
1296
1297   auto add_succs = loco::succs(g.add);
1298   EXPECT_EQ(1, add_succs.size());
1299   check_post_trans(*add_succs.begin());
1300
1301   auto new_beta = dynamic_cast<luci::CircleConst *>(g.add->y());
1302   EXPECT_NE(nullptr, new_beta);
1303   EXPECT_EQ(4, new_beta->rank());
1304   EXPECT_EQ(1, new_beta->dim(0).value());
1305   EXPECT_EQ(1, new_beta->dim(1).value());
1306   EXPECT_EQ(1, new_beta->dim(2).value());
1307   EXPECT_EQ(1, new_beta->dim(3).value());
1308
1309   check_pre_trans(g.output->from());
1310 }
1311
1312 TEST(ConvertNCHWToNHWC, Concatenation)
1313 {
1314   ConcatenationGraph g;
1315   g.init();
1316
1317   run_phase(&g.g, true, true);
1318
1319   check_pre_trans(g.concat->values(0));
1320   check_pre_trans(g.concat->values(1));
1321
1322   auto concat_succs = loco::succs(g.concat);
1323   EXPECT_EQ(1, concat_succs.size());
1324   check_post_trans(*concat_succs.begin());
1325
1326   // Check concat shape, axis
1327   EXPECT_EQ(1, g.concat->dim(0).value());
1328   EXPECT_EQ(4, g.concat->dim(1).value());
1329   EXPECT_EQ(4, g.concat->dim(2).value());
1330   EXPECT_EQ(32, g.concat->dim(3).value());
1331   EXPECT_EQ(3, g.concat->axis());
1332 }
1333
1334 TEST(ConvertNCHWToNHWC, Elu)
1335 {
1336   EluGraph g;
1337   g.init();
1338
1339   run_phase(&g.g, true, true);
1340
1341   check_pre_trans(g.elu->features());
1342
1343   auto elu_succs = loco::succs(g.elu);
1344   EXPECT_EQ(1, elu_succs.size());
1345   check_post_trans(*elu_succs.begin());
1346
1347   // Check elu shape
1348   EXPECT_EQ(1, g.elu->dim(0).value());
1349   EXPECT_EQ(4, g.elu->dim(1).value());
1350   EXPECT_EQ(4, g.elu->dim(2).value());
1351   EXPECT_EQ(16, g.elu->dim(3).value());
1352 }
1353
1354 TEST(ConvertNCHWToNHWC, LeakyRelu)
1355 {
1356   LeakyReluGraph g;
1357   g.init();
1358
1359   run_phase(&g.g, true, true);
1360
1361   check_pre_trans(g.leakyrelu->features());
1362
1363   auto leakyrelu_succs = loco::succs(g.leakyrelu);
1364   EXPECT_EQ(1, leakyrelu_succs.size());
1365   check_post_trans(*leakyrelu_succs.begin());
1366
1367   // Check leakyrelu shape
1368   EXPECT_EQ(1, g.leakyrelu->dim(0).value());
1369   EXPECT_EQ(4, g.leakyrelu->dim(1).value());
1370   EXPECT_EQ(4, g.leakyrelu->dim(2).value());
1371   EXPECT_EQ(16, g.leakyrelu->dim(3).value());
1372 }
1373
1374 TEST(ConvertNCHWToNHWC, Logistic)
1375 {
1376   LogisticGraph g;
1377   g.init();
1378
1379   run_phase(&g.g, true, true);
1380
1381   check_pre_trans(g.logistic->x());
1382
1383   auto logistic_succs = loco::succs(g.logistic);
1384   EXPECT_EQ(1, logistic_succs.size());
1385   check_post_trans(*logistic_succs.begin());
1386
1387   // Check logistic shape
1388   EXPECT_EQ(1, g.logistic->dim(0).value());
1389   EXPECT_EQ(4, g.logistic->dim(1).value());
1390   EXPECT_EQ(4, g.logistic->dim(2).value());
1391   EXPECT_EQ(16, g.logistic->dim(3).value());
1392 }
1393
1394 TEST(ConvertNCHWToNHWC, Maximum)
1395 {
1396   MaximumGraph g;
1397   g.init();
1398
1399   run_phase(&g.g, false, false);
1400
1401   auto input_succs = loco::succs(g.input);
1402   EXPECT_EQ(1, input_succs.size());
1403   check_post_trans(*input_succs.begin());
1404
1405   check_pre_trans(g.max->x());
1406
1407   auto max_succs = loco::succs(g.max);
1408   EXPECT_EQ(1, max_succs.size());
1409   check_post_trans(*max_succs.begin());
1410
1411   check_pre_trans(g.output->from());
1412 }
1413
1414 TEST(ConvertNCHWToNHWC, Maximum_non_scalar_NEG)
1415 {
1416   MaximumGraph g;
1417   g.init();
1418
1419   g.limit->shape({3});
1420
1421   luci::ConvertNCHWToNHWCPass pass(true, true);
1422   EXPECT_FALSE(pass.run(&g.g));
1423 }
1424
1425 TEST(ConvertNCHWToNHWC, MaximumNonConst)
1426 {
1427   MaximumNonConstGraph g;
1428   g.init();
1429
1430   run_phase(&g.g, true, true);
1431
1432   check_pre_trans(g.max->x());
1433   check_pre_trans(g.max->y());
1434
1435   auto max_succs = loco::succs(g.max);
1436   EXPECT_EQ(1, max_succs.size());
1437   check_post_trans(*max_succs.begin());
1438 }
1439
1440 TEST(ConvertNCHWToNHWC, Mean)
1441 {
1442   MeanGraph g;
1443   g.init();
1444
1445   run_phase(&g.g, false, false);
1446
1447   check_pre_trans(g.mean->input());
1448
1449   auto mean_succs = loco::succs(g.mean);
1450   EXPECT_EQ(1, mean_succs.size());
1451   check_post_trans(*mean_succs.begin());
1452
1453   auto new_rindices = dynamic_cast<luci::CircleConst *>(g.mean->reduction_indices());
1454   EXPECT_NE(nullptr, new_rindices);
1455   EXPECT_EQ(1, new_rindices->rank());
1456   EXPECT_EQ(2, new_rindices->dim(0).value());
1457   EXPECT_EQ(2, new_rindices->size<loco::DataType::S32>());
1458   EXPECT_EQ(1, new_rindices->at<loco::DataType::S32>(0));
1459   EXPECT_EQ(2, new_rindices->at<loco::DataType::S32>(1));
1460 }
1461
1462 TEST(ConvertNCHWToNHWC, Mean_keep_dims_false)
1463 {
1464   struct TC
1465   {
1466     std::vector<int32_t> nchw_ind;
1467     std::vector<int32_t> nhwc_ind;
1468     std::initializer_list<uint32_t> shape;
1469     bool needs_transpose = false;
1470   };
1471
1472   uint32_t n = 1;
1473   uint32_t c = 16;
1474   uint32_t h = 4;
1475   uint32_t w = 4;
1476
1477   std::vector<TC> test_cases{{{0}, {0}, {c, h, w}, true},       {{1}, {3}, {n, h, w}, false},
1478                              {{2}, {1}, {n, c, w}, true},       {{3}, {2}, {n, c, h}, true},
1479                              {{0, 1}, {0, 3}, {h, w}, false},   {{0, 2}, {0, 1}, {c, w}, true},
1480                              {{0, 3}, {0, 2}, {c, h}, true},    {{1, 2}, {3, 1}, {n, w}, false},
1481                              {{1, 3}, {3, 2}, {n, h}, false},   {{2, 3}, {1, 2}, {n, c}, false},
1482                              {{0, 1, 2}, {0, 3, 1}, {w}, false}};
1483
1484   for (auto &tc : test_cases)
1485   {
1486     MeanGraph g;
1487     g.keep_dims(false);
1488     g.axes(tc.nchw_ind);
1489     g.shape(tc.shape);
1490     g.init();
1491
1492     run_phase(&g.g, false, true);
1493
1494     check_pre_trans(g.mean->input());
1495
1496     auto mean_succs = loco::succs(g.mean);
1497     EXPECT_EQ(1, mean_succs.size());
1498     if (tc.needs_transpose)
1499     {
1500       EXPECT_NE(nullptr, dynamic_cast<luci::CircleTranspose *>(*mean_succs.begin()));
1501     }
1502     else
1503     {
1504       EXPECT_NE(nullptr, dynamic_cast<luci::CircleOutput *>(*mean_succs.begin()));
1505     }
1506
1507     auto new_rindices = dynamic_cast<luci::CircleConst *>(g.mean->reduction_indices());
1508     EXPECT_NE(nullptr, new_rindices);
1509     EXPECT_EQ(1, new_rindices->rank());
1510     EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->dim(0).value());
1511     EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->size<loco::DataType::S32>());
1512     for (uint32_t i = 0; i < tc.nhwc_ind.size(); ++i)
1513     {
1514       EXPECT_EQ(tc.nhwc_ind[i], new_rindices->at<loco::DataType::S32>(i));
1515     }
1516   }
1517 }
1518
1519 TEST(ConvertNCHWToNHWC, ConvertNCHWToNHWC_Mean_keep_dims_false_NEG)
1520 {
1521   loco::Graph g;
1522   auto input = g.nodes()->create<luci::CircleInput>();
1523   auto output = g.nodes()->create<luci::CircleOutput>();
1524   input->name("input");
1525   output->name("output");
1526
1527   auto graph_input = g.inputs()->create();
1528   input->index(graph_input->index());
1529   auto graph_output = g.outputs()->create();
1530   output->index(graph_output->index());
1531
1532   graph_input->dtype(loco::DataType::FLOAT32);
1533   input->dtype(loco::DataType::FLOAT32);
1534   output->dtype(loco::DataType::FLOAT32);
1535   graph_output->dtype(loco::DataType::FLOAT32);
1536
1537   uint32_t channel_size = 16;
1538   graph_input->shape({channel_size, 4, 4});
1539   input->shape({channel_size, 4, 4});
1540   output->shape({channel_size});
1541   graph_output->shape({channel_size});
1542
1543   auto mean = g.nodes()->create<luci::CircleMean>();
1544   auto rindices = g.nodes()->create<luci::CircleConst>();
1545
1546   mean->dtype(loco::DataType::FLOAT32);
1547   rindices->dtype(loco::DataType::S32);
1548
1549   mean->shape({channel_size});
1550   rindices->shape({2});
1551
1552   rindices->size<loco::DataType::S32>(2);
1553   rindices->at<loco::DataType::S32>(0) = 1;
1554   rindices->at<loco::DataType::S32>(1) = 2;
1555
1556   mean->input(input);
1557   mean->reduction_indices(rindices);
1558   mean->keep_dims(false);
1559
1560   mean->name("mean");
1561   rindices->name("rindices");
1562
1563   output->from(mean);
1564
1565   run_phase(&g, true, true);
1566
1567   auto new_rindices = dynamic_cast<luci::CircleConst *>(mean->reduction_indices());
1568   EXPECT_NE(nullptr, new_rindices);
1569   EXPECT_EQ(1, new_rindices->rank());
1570   EXPECT_EQ(2, new_rindices->dim(0).value());
1571   EXPECT_EQ(2, new_rindices->size<loco::DataType::S32>());
1572   EXPECT_EQ(1, new_rindices->at<loco::DataType::S32>(0));
1573   EXPECT_EQ(2, new_rindices->at<loco::DataType::S32>(1));
1574 }
1575
1576 TEST(ConvertNCHWToNHWC, Minimum)
1577 {
1578   MinimumGraph g;
1579   g.init();
1580
1581   run_phase(&g.g, false, false);
1582
1583   auto input_succs = loco::succs(g.input);
1584   EXPECT_EQ(1, input_succs.size());
1585   check_post_trans(*input_succs.begin());
1586
1587   check_pre_trans(g.min->x());
1588
1589   auto min_succs = loco::succs(g.min);
1590   EXPECT_EQ(1, min_succs.size());
1591   check_post_trans(*min_succs.begin());
1592
1593   check_pre_trans(g.output->from());
1594 }
1595
1596 TEST(ConvertNCHWToNHWC, Minimum_non_scalar_NEG)
1597 {
1598   MinimumGraph g;
1599   g.init();
1600
1601   g.limit->shape({3});
1602
1603   luci::ConvertNCHWToNHWCPass pass(true, true);
1604   EXPECT_FALSE(pass.run(&g.g));
1605 }
1606
1607 TEST(ConvertNCHWToNHWC, Mul)
1608 {
1609   MulGraph g;
1610   g.init();
1611
1612   run_phase(&g.g, false, false);
1613
1614   auto input_succs = loco::succs(g.input);
1615   EXPECT_EQ(1, input_succs.size());
1616   check_post_trans(*input_succs.begin());
1617
1618   check_pre_trans(g.mul->x());
1619
1620   auto mul_succs = loco::succs(g.mul);
1621   EXPECT_EQ(1, mul_succs.size());
1622   check_post_trans(*mul_succs.begin());
1623
1624   uint32_t channel_size = 16;
1625   auto new_multiplier = dynamic_cast<luci::CircleConst *>(g.mul->y());
1626   EXPECT_NE(nullptr, new_multiplier);
1627   EXPECT_EQ(4, new_multiplier->rank());
1628   EXPECT_EQ(1, new_multiplier->dim(0).value());
1629   EXPECT_EQ(1, new_multiplier->dim(1).value());
1630   EXPECT_EQ(1, new_multiplier->dim(2).value());
1631   EXPECT_EQ(channel_size, new_multiplier->dim(3).value());
1632
1633   check_pre_trans(g.output->from());
1634 }
1635
1636 TEST(ConvertNCHWToNHWC, Mul_NCHW_const)
1637 {
1638   MulGraph g;
1639   g.init();
1640   g.update_const_shape_to_nchw();
1641
1642   run_phase(&g.g, false, false);
1643
1644   check_pre_trans(g.mul->x());
1645
1646   auto mul_succs = loco::succs(g.mul);
1647   EXPECT_EQ(1, mul_succs.size());
1648   check_post_trans(*mul_succs.begin());
1649
1650   uint32_t channel_size = 16;
1651   auto new_multiplier = dynamic_cast<luci::CircleConst *>(g.mul->y());
1652   EXPECT_NE(nullptr, new_multiplier);
1653   EXPECT_EQ(4, new_multiplier->rank());
1654   EXPECT_EQ(1, new_multiplier->dim(0).value());
1655   EXPECT_EQ(4, new_multiplier->dim(1).value());
1656   EXPECT_EQ(4, new_multiplier->dim(2).value());
1657   EXPECT_EQ(channel_size, new_multiplier->dim(3).value());
1658 }
1659
1660 TEST(ConvertNCHWToNHWC, MulScalar)
1661 {
1662   MulScalarGraph g;
1663   g.init();
1664
1665   run_phase(&g.g, false, false);
1666
1667   auto input_succs = loco::succs(g.input);
1668   EXPECT_EQ(1, input_succs.size());
1669   check_post_trans(*input_succs.begin());
1670
1671   check_pre_trans(g.mul->x());
1672
1673   auto mul_succs = loco::succs(g.mul);
1674   EXPECT_EQ(1, mul_succs.size());
1675   check_post_trans(*mul_succs.begin());
1676
1677   auto new_multiplier = dynamic_cast<luci::CircleConst *>(g.mul->y());
1678   EXPECT_NE(nullptr, new_multiplier);
1679   EXPECT_EQ(4, new_multiplier->rank());
1680   EXPECT_EQ(1, new_multiplier->dim(0).value());
1681   EXPECT_EQ(1, new_multiplier->dim(1).value());
1682   EXPECT_EQ(1, new_multiplier->dim(2).value());
1683   EXPECT_EQ(1, new_multiplier->dim(3).value());
1684
1685   check_pre_trans(g.output->from());
1686 }
1687
1688 TEST(ConvertNCHWToNHWC, MulBothNorm)
1689 {
1690   MulBothNormGraph g;
1691   g.init();
1692
1693   run_phase(&g.g, false, false);
1694
1695   auto input_succs = loco::succs(g.input);
1696   EXPECT_EQ(1, input_succs.size());
1697   check_post_trans(*input_succs.begin());
1698
1699   check_pre_trans(g.mul->x());
1700   check_pre_trans(g.mul->y());
1701
1702   auto mul_succs = loco::succs(g.mul);
1703   EXPECT_EQ(1, mul_succs.size());
1704   check_post_trans(*mul_succs.begin());
1705
1706   check_pre_trans(g.output->from());
1707 }
1708
1709 TEST(ConvertNCHWToNHWC, Neg)
1710 {
1711   NegGraph g;
1712   g.init();
1713
1714   run_phase(&g.g, true, true);
1715
1716   check_pre_trans(g.neg->x());
1717
1718   auto neg_succs = loco::succs(g.neg);
1719   EXPECT_EQ(1, neg_succs.size());
1720   check_post_trans(*neg_succs.begin());
1721
1722   // Check leakyrelu shape
1723   EXPECT_EQ(1, g.neg->dim(0).value());
1724   EXPECT_EQ(4, g.neg->dim(1).value());
1725   EXPECT_EQ(4, g.neg->dim(2).value());
1726   EXPECT_EQ(16, g.neg->dim(3).value());
1727 }
1728
1729 TEST(ConvertNCHWToNHWC, Pad)
1730 {
1731   PadGraph g;
1732   g.init();
1733
1734   run_phase(&g.g, false, false);
1735
1736   auto input_succs = loco::succs(g.input);
1737   EXPECT_EQ(1, input_succs.size());
1738   check_post_trans(*input_succs.begin());
1739
1740   check_pre_trans(g.pad->input());
1741
1742   auto pad_succs = loco::succs(g.pad);
1743   EXPECT_EQ(1, pad_succs.size());
1744   check_post_trans(*pad_succs.begin());
1745
1746   auto new_paddings = dynamic_cast<luci::CircleConst *>(g.pad->paddings());
1747   EXPECT_NE(nullptr, new_paddings);
1748   EXPECT_EQ(2, new_paddings->rank());
1749   EXPECT_EQ(4, new_paddings->dim(0).value());
1750   EXPECT_EQ(2, new_paddings->dim(1).value());
1751   EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(0));
1752   EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(1));
1753   EXPECT_EQ(1, new_paddings->at<loco::DataType::S32>(2));
1754   EXPECT_EQ(1, new_paddings->at<loco::DataType::S32>(3));
1755   EXPECT_EQ(2, new_paddings->at<loco::DataType::S32>(4));
1756   EXPECT_EQ(2, new_paddings->at<loco::DataType::S32>(5));
1757   EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(6));
1758   EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(7));
1759
1760   check_pre_trans(g.output->from());
1761 }
1762
1763 TEST(ConvertNCHWToNHWC, PadV2)
1764 {
1765   PadV2Graph g;
1766   g.init();
1767
1768   run_phase(&g.g, false, false);
1769
1770   check_pre_trans(g.pad->input());
1771
1772   auto pad_succs = loco::succs(g.pad);
1773   EXPECT_EQ(1, pad_succs.size());
1774   check_post_trans(*pad_succs.begin());
1775
1776   auto new_paddings = dynamic_cast<luci::CircleConst *>(g.pad->paddings());
1777   EXPECT_NE(nullptr, new_paddings);
1778   EXPECT_EQ(2, new_paddings->rank());
1779   EXPECT_EQ(4, new_paddings->dim(0).value());
1780   EXPECT_EQ(2, new_paddings->dim(1).value());
1781   EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(0));
1782   EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(1));
1783   EXPECT_EQ(1, new_paddings->at<loco::DataType::S32>(2));
1784   EXPECT_EQ(1, new_paddings->at<loco::DataType::S32>(3));
1785   EXPECT_EQ(2, new_paddings->at<loco::DataType::S32>(4));
1786   EXPECT_EQ(2, new_paddings->at<loco::DataType::S32>(5));
1787   EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(6));
1788   EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(7));
1789 }
1790
1791 TEST(ConvertNCHWToNHWC, Unknown_Shape_NEG)
1792 {
1793   AddGraph g;
1794   g.init();
1795
1796   // Unknown shape
1797   g.input->dim(0).unset();
1798   g.add->dim(0).unset();
1799   g.output->dim(0).unset();
1800
1801   luci::ConvertNCHWToNHWCPass pass(false, false);
1802   EXPECT_EQ(false, pass.run(&g.g));
1803 }
1804
1805 TEST(ConvertNCHWToNHWC, Preserve_Input_Output)
1806 {
1807   // Preserve input
1808   {
1809     AddGraph g;
1810     g.init();
1811
1812     run_phase(&g.g, true, false);
1813
1814     // Check input shape
1815     EXPECT_EQ(1, g.input->dim(0).value());
1816     EXPECT_EQ(16, g.input->dim(1).value());
1817     EXPECT_EQ(4, g.input->dim(2).value());
1818     EXPECT_EQ(4, g.input->dim(3).value());
1819
1820     // Check output shape
1821     EXPECT_EQ(1, g.output->dim(0).value());
1822     EXPECT_EQ(4, g.output->dim(1).value());
1823     EXPECT_EQ(4, g.output->dim(2).value());
1824     EXPECT_EQ(16, g.output->dim(3).value());
1825   }
1826
1827   // Preserve output
1828   {
1829     AddGraph g;
1830     g.init();
1831
1832     run_phase(&g.g, false, true);
1833
1834     // Check input shape
1835     EXPECT_EQ(1, g.input->dim(0).value());
1836     EXPECT_EQ(4, g.input->dim(1).value());
1837     EXPECT_EQ(4, g.input->dim(2).value());
1838     EXPECT_EQ(16, g.input->dim(3).value());
1839
1840     // Check output shape
1841     EXPECT_EQ(1, g.output->dim(0).value());
1842     EXPECT_EQ(16, g.output->dim(1).value());
1843     EXPECT_EQ(4, g.output->dim(2).value());
1844     EXPECT_EQ(4, g.output->dim(3).value());
1845   }
1846
1847   // Preserve both input and output
1848   {
1849     AddGraph g;
1850     g.init();
1851
1852     run_phase(&g.g, true, true);
1853
1854     // Check input shape
1855     EXPECT_EQ(1, g.input->dim(0).value());
1856     EXPECT_EQ(16, g.input->dim(1).value());
1857     EXPECT_EQ(4, g.input->dim(2).value());
1858     EXPECT_EQ(4, g.input->dim(3).value());
1859
1860     // Check output shape
1861     EXPECT_EQ(1, g.output->dim(0).value());
1862     EXPECT_EQ(16, g.output->dim(1).value());
1863     EXPECT_EQ(4, g.output->dim(2).value());
1864     EXPECT_EQ(4, g.output->dim(3).value());
1865   }
1866 }
1867
1868 TEST(ConvertNCHWToNHWC, ReduceMax)
1869 {
1870   ReduceMaxGraph g;
1871   g.init();
1872
1873   run_phase(&g.g, false, false);
1874
1875   check_pre_trans(g.rm->input());
1876
1877   auto rm_succs = loco::succs(g.rm);
1878   EXPECT_EQ(1, rm_succs.size());
1879   check_post_trans(*rm_succs.begin());
1880
1881   auto new_rindices = dynamic_cast<luci::CircleConst *>(g.rm->reduction_indices());
1882   EXPECT_NE(nullptr, new_rindices);
1883   EXPECT_EQ(1, new_rindices->rank());
1884   EXPECT_EQ(2, new_rindices->dim(0).value());
1885   EXPECT_EQ(2, new_rindices->size<loco::DataType::S32>());
1886   EXPECT_EQ(1, new_rindices->at<loco::DataType::S32>(0));
1887   EXPECT_EQ(2, new_rindices->at<loco::DataType::S32>(1));
1888 }
1889
1890 TEST(ConvertNCHWToNHWC, ReduceMax_keep_dims_false)
1891 {
1892   struct TC
1893   {
1894     std::vector<int32_t> nchw_ind;
1895     std::vector<int32_t> nhwc_ind;
1896     std::initializer_list<uint32_t> shape;
1897     bool needs_transpose = false;
1898   };
1899
1900   uint32_t n = 1;
1901   uint32_t c = 16;
1902   uint32_t h = 4;
1903   uint32_t w = 4;
1904
1905   std::vector<TC> test_cases{{{0}, {0}, {c, h, w}, true},       {{1}, {3}, {n, h, w}, false},
1906                              {{2}, {1}, {n, c, w}, true},       {{3}, {2}, {n, c, h}, true},
1907                              {{0, 1}, {0, 3}, {h, w}, false},   {{0, 2}, {0, 1}, {c, w}, true},
1908                              {{0, 3}, {0, 2}, {c, h}, true},    {{1, 2}, {3, 1}, {n, w}, false},
1909                              {{1, 3}, {3, 2}, {n, h}, false},   {{2, 3}, {1, 2}, {n, c}, false},
1910                              {{0, 1, 2}, {0, 3, 1}, {w}, false}};
1911
1912   for (auto &tc : test_cases)
1913   {
1914     ReduceMaxGraph g;
1915     g.keep_dims(false);
1916     g.axes(tc.nchw_ind);
1917     g.shape(tc.shape);
1918     g.init();
1919
1920     run_phase(&g.g, true, true);
1921
1922     check_pre_trans(g.rm->input());
1923
1924     auto rm_succs = loco::succs(g.rm);
1925     EXPECT_EQ(1, rm_succs.size());
1926     if (tc.needs_transpose)
1927     {
1928       EXPECT_NE(nullptr, dynamic_cast<luci::CircleTranspose *>(*rm_succs.begin()));
1929     }
1930     else
1931     {
1932       EXPECT_NE(nullptr, dynamic_cast<luci::CircleOutput *>(*rm_succs.begin()));
1933     }
1934
1935     auto new_rindices = dynamic_cast<luci::CircleConst *>(g.rm->reduction_indices());
1936     EXPECT_NE(nullptr, new_rindices);
1937     EXPECT_EQ(1, new_rindices->rank());
1938     EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->dim(0).value());
1939     EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->size<loco::DataType::S32>());
1940     for (uint32_t i = 0; i < tc.nhwc_ind.size(); ++i)
1941     {
1942       EXPECT_EQ(tc.nhwc_ind[i], new_rindices->at<loco::DataType::S32>(i));
1943     }
1944   }
1945 }
1946
1947 TEST(ConvertNCHWToNHWC, ReduceMin)
1948 {
1949   ReduceMinGraph g;
1950   g.init();
1951
1952   run_phase(&g.g, true, true);
1953
1954   check_pre_trans(g.rm->input());
1955
1956   auto rm_succs = loco::succs(g.rm);
1957   EXPECT_EQ(1, rm_succs.size());
1958   check_post_trans(*rm_succs.begin());
1959
1960   auto new_rindices = dynamic_cast<luci::CircleConst *>(g.rm->reduction_indices());
1961   EXPECT_NE(nullptr, new_rindices);
1962   EXPECT_EQ(1, new_rindices->rank());
1963   EXPECT_EQ(2, new_rindices->dim(0).value());
1964   EXPECT_EQ(2, new_rindices->size<loco::DataType::S32>());
1965   EXPECT_EQ(1, new_rindices->at<loco::DataType::S32>(0));
1966   EXPECT_EQ(2, new_rindices->at<loco::DataType::S32>(1));
1967 }
1968
1969 TEST(ConvertNCHWToNHWC, ReduceMin_keep_dims_false)
1970 {
1971   struct TC
1972   {
1973     std::vector<int32_t> nchw_ind;
1974     std::vector<int32_t> nhwc_ind;
1975     std::initializer_list<uint32_t> shape;
1976     bool needs_transpose = false;
1977   };
1978
1979   uint32_t n = 1;
1980   uint32_t c = 16;
1981   uint32_t h = 4;
1982   uint32_t w = 4;
1983
1984   std::vector<TC> test_cases{{{0}, {0}, {c, h, w}, true},       {{1}, {3}, {n, h, w}, false},
1985                              {{2}, {1}, {n, c, w}, true},       {{3}, {2}, {n, c, h}, true},
1986                              {{0, 1}, {0, 3}, {h, w}, false},   {{0, 2}, {0, 1}, {c, w}, true},
1987                              {{0, 3}, {0, 2}, {c, h}, true},    {{1, 2}, {3, 1}, {n, w}, false},
1988                              {{1, 3}, {3, 2}, {n, h}, false},   {{2, 3}, {1, 2}, {n, c}, false},
1989                              {{0, 1, 2}, {0, 3, 1}, {w}, false}};
1990
1991   for (auto &tc : test_cases)
1992   {
1993     ReduceMinGraph g;
1994     g.keep_dims(false);
1995     g.axes(tc.nchw_ind);
1996     g.shape(tc.shape);
1997     g.init();
1998
1999     run_phase(&g.g, true, true);
2000
2001     check_pre_trans(g.rm->input());
2002
2003     auto rm_succs = loco::succs(g.rm);
2004     EXPECT_EQ(1, rm_succs.size());
2005     if (tc.needs_transpose)
2006     {
2007       EXPECT_NE(nullptr, dynamic_cast<luci::CircleTranspose *>(*rm_succs.begin()));
2008     }
2009     else
2010     {
2011       EXPECT_NE(nullptr, dynamic_cast<luci::CircleOutput *>(*rm_succs.begin()));
2012     }
2013
2014     auto new_rindices = dynamic_cast<luci::CircleConst *>(g.rm->reduction_indices());
2015     EXPECT_NE(nullptr, new_rindices);
2016     EXPECT_EQ(1, new_rindices->rank());
2017     EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->dim(0).value());
2018     EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->size<loco::DataType::S32>());
2019     for (uint32_t i = 0; i < tc.nhwc_ind.size(); ++i)
2020     {
2021       EXPECT_EQ(tc.nhwc_ind[i], new_rindices->at<loco::DataType::S32>(i));
2022     }
2023   }
2024 }
2025
2026 TEST(ConvertNCHWToNHWC, Relu)
2027 {
2028   ReluGraph g;
2029   g.init();
2030
2031   run_phase(&g.g, true, true);
2032
2033   check_pre_trans(g.relu->features());
2034
2035   auto relu_succs = loco::succs(g.relu);
2036   EXPECT_EQ(1, relu_succs.size());
2037   check_post_trans(*relu_succs.begin());
2038
2039   // Check relu shape
2040   EXPECT_EQ(1, g.relu->dim(0).value());
2041   EXPECT_EQ(4, g.relu->dim(1).value());
2042   EXPECT_EQ(4, g.relu->dim(2).value());
2043   EXPECT_EQ(16, g.relu->dim(3).value());
2044 }
2045
2046 TEST(ConvertNCHWToNHWC, Relu6)
2047 {
2048   Relu6Graph g;
2049   g.init();
2050
2051   run_phase(&g.g, true, true);
2052
2053   check_pre_trans(g.relu6->features());
2054
2055   auto relu6_succs = loco::succs(g.relu6);
2056   EXPECT_EQ(1, relu6_succs.size());
2057   check_post_trans(*relu6_succs.begin());
2058
2059   // Check relu6 shape
2060   EXPECT_EQ(1, g.relu6->dim(0).value());
2061   EXPECT_EQ(4, g.relu6->dim(1).value());
2062   EXPECT_EQ(4, g.relu6->dim(2).value());
2063   EXPECT_EQ(16, g.relu6->dim(3).value());
2064 }
2065
2066 TEST(ConvertNCHWToNHWC, Rsqrt)
2067 {
2068   RsqrtGraph g;
2069   g.init();
2070
2071   run_phase(&g.g, true, true);
2072
2073   check_pre_trans(g.rsqrt->x());
2074
2075   auto rsqrt_succs = loco::succs(g.rsqrt);
2076   EXPECT_EQ(1, rsqrt_succs.size());
2077   check_post_trans(*rsqrt_succs.begin());
2078
2079   // Check rsqrt shape
2080   EXPECT_EQ(1, g.rsqrt->dim(0).value());
2081   EXPECT_EQ(4, g.rsqrt->dim(1).value());
2082   EXPECT_EQ(4, g.rsqrt->dim(2).value());
2083   EXPECT_EQ(16, g.rsqrt->dim(3).value());
2084 }
2085
2086 TEST(ConvertNCHWToNHWC, SplitV)
2087 {
2088   SplitVGraph g;
2089   g.init();
2090
2091   run_phase(g.g(), true, true);
2092
2093   check_pre_trans(g.splitv()->input());
2094
2095   auto splitv_succs = loco::succs(g.splitv());
2096   for (auto svo : loco::succs(g.splitv()))
2097   {
2098     for (auto succ : loco::succs(svo))
2099     {
2100       check_post_trans(succ);
2101     }
2102   }
2103
2104   // Check splitv() shape
2105   EXPECT_EQ(1, g.splitv()->dim(0).value());
2106   EXPECT_EQ(2, g.splitv()->dim(1).value());
2107   EXPECT_EQ(192, g.splitv()->dim(2).value());
2108   EXPECT_EQ(2, g.splitv()->dim(3).value());
2109
2110   // Check axis
2111   auto axis = dynamic_cast<luci::CircleConst *>(g.splitv()->split_dim());
2112   EXPECT_NE(nullptr, axis);
2113   EXPECT_EQ(1, axis->size<loco::DataType::S32>());
2114   EXPECT_EQ(2, axis->at<loco::DataType::S32>(0));
2115 }
2116
2117 TEST(ConvertNCHWToNHWC, SquaredDifference)
2118 {
2119   SquaredDifferenceGraph g;
2120   g.init();
2121
2122   run_phase(&g.g, true, true);
2123
2124   check_pre_trans(g.sqdiff->x());
2125   check_pre_trans(g.sqdiff->y());
2126
2127   auto sqdiff_succs = loco::succs(g.sqdiff);
2128   EXPECT_EQ(1, sqdiff_succs.size());
2129   check_post_trans(*sqdiff_succs.begin());
2130 }
2131
2132 TEST(ConvertNCHWToNHWC, Sub)
2133 {
2134   SubGraph g;
2135   g.init();
2136
2137   run_phase(&g.g, false, false);
2138
2139   auto input_succs = loco::succs(g.input);
2140   EXPECT_EQ(1, input_succs.size());
2141   check_post_trans(*input_succs.begin());
2142
2143   check_pre_trans(g.sub->x());
2144
2145   auto add_succs = loco::succs(g.sub);
2146   EXPECT_EQ(1, add_succs.size());
2147   check_post_trans(*add_succs.begin());
2148
2149   uint32_t channel_size = 16;
2150   auto new_beta = dynamic_cast<luci::CircleConst *>(g.sub->y());
2151   EXPECT_NE(nullptr, new_beta);
2152   EXPECT_EQ(4, new_beta->rank());
2153   EXPECT_EQ(1, new_beta->dim(0).value());
2154   EXPECT_EQ(1, new_beta->dim(1).value());
2155   EXPECT_EQ(1, new_beta->dim(2).value());
2156   EXPECT_EQ(channel_size, new_beta->dim(3).value());
2157
2158   check_pre_trans(g.output->from());
2159 }
2160
2161 TEST(ConvertNCHWToNHWC, Sub_NCHW_const)
2162 {
2163   SubGraph g;
2164   g.init();
2165   g.update_const_shape_to_nchw();
2166
2167   run_phase(&g.g, false, false);
2168
2169   check_pre_trans(g.sub->x());
2170
2171   auto sub_succs = loco::succs(g.sub);
2172   EXPECT_EQ(1, sub_succs.size());
2173   check_post_trans(*sub_succs.begin());
2174
2175   uint32_t channel_size = 16;
2176   auto new_beta = dynamic_cast<luci::CircleConst *>(g.sub->y());
2177   EXPECT_NE(nullptr, new_beta);
2178   EXPECT_EQ(4, new_beta->rank());
2179   EXPECT_EQ(1, new_beta->dim(0).value());
2180   EXPECT_EQ(4, new_beta->dim(1).value());
2181   EXPECT_EQ(4, new_beta->dim(2).value());
2182   EXPECT_EQ(channel_size, new_beta->dim(3).value());
2183 }
2184
2185 TEST(ConvertNCHWToNHWC, SubScalar)
2186 {
2187   SubScalarGraph g;
2188   g.init();
2189
2190   run_phase(&g.g, false, false);
2191
2192   auto input_succs = loco::succs(g.input);
2193   EXPECT_EQ(1, input_succs.size());
2194   check_post_trans(*input_succs.begin());
2195
2196   check_pre_trans(g.sub->y());
2197
2198   auto add_succs = loco::succs(g.sub);
2199   EXPECT_EQ(1, add_succs.size());
2200   check_post_trans(*add_succs.begin());
2201
2202   auto new_beta = dynamic_cast<luci::CircleConst *>(g.sub->x());
2203   EXPECT_NE(nullptr, new_beta);
2204   EXPECT_EQ(1, new_beta->rank());
2205
2206   check_pre_trans(g.output->from());
2207 }
2208
2209 TEST(ConvertNCHWToNHWC, Not_Closed_Case1_NEG)
2210 {
2211   NoPostReshapeGraph g;
2212   g.init();
2213
2214   run_phase(&g.g, true, true);
2215
2216   check_pre_trans(g.relu->features());
2217
2218   auto relu_succs = loco::succs(g.relu);
2219   EXPECT_EQ(1, relu_succs.size());
2220   check_post_trans(*relu_succs.begin());
2221 }
2222
2223 TEST(ConvertNCHWToNHWC, Not_Closed_Case2_NEG)
2224 {
2225   ReluNotClosedGraph g;
2226   g.init();
2227
2228   run_phase(&g.g, true, true);
2229
2230   check_pre_trans(g.relu->features());
2231
2232   auto relu_succs = loco::succs(g.relu);
2233   EXPECT_EQ(1, relu_succs.size());
2234   check_post_trans(*relu_succs.begin());
2235 }