Imported Upstream version 1.18.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / ConvertNCHWToNHWCPass.test.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <logo/Phase.h>
18
19 #include "luci/Pass/ConvertNCHWToNHWCPass.h"
20 #include "luci/Pass/CircleShapeInferencePass.h"
21
22 #include <luci/IR/CircleNodes.h>
23
24 #include <gtest/gtest.h>
25
26 namespace
27 {
28
29 /**
30  *  Graph with a single Op (example: Add).
31  *
32  *  BEFORE
33  *  - All Ops including Input/Output are NCHW.
34  *
35  *             [Input] [beta]
36  *                |  /
37  *              [Add]
38  *                |
39  *             [Output]
40  *
41  *  AFTER
42  *  - All Ops including Input/Output are NHWC.
43  *
44  *             [Input]
45  *                |
46  *         [Transpose]
47  *                |
48  *        [Transpose] [beta]
49  *                |  /
50  *              [Add]
51  *                |
52  *         [Transpose]
53  *                |
54  *         [Transpose]
55  *                |
56  *             [Output]
57  */
58 class SimpleGraph
59 {
60 public:
61   SimpleGraph() = default;
62
63 public:
64   void init()
65   {
66     input = g.nodes()->create<luci::CircleInput>();
67     output = g.nodes()->create<luci::CircleOutput>();
68     input->name("input");
69     output->name("output");
70
71     auto graph_input = g.inputs()->create();
72     input->index(graph_input->index());
73     auto graph_output = g.outputs()->create();
74     output->index(graph_output->index());
75
76     graph_input->dtype(loco::DataType::FLOAT32);
77     input->dtype(loco::DataType::FLOAT32);
78     output->dtype(loco::DataType::FLOAT32);
79     graph_output->dtype(loco::DataType::FLOAT32);
80
81     uint32_t channel_size = 16;
82     graph_input->shape({1, channel_size, 4, 4});
83     input->shape({1, channel_size, 4, 4});
84     output->shape({1, channel_size, 4, 4});
85     graph_output->shape({1, channel_size, 4, 4});
86
87     auto graph_body = insertGraphBody(input);
88     output->from(graph_body);
89   }
90
91   virtual ~SimpleGraph() = default;
92
93 protected:
94   virtual loco::Node *insertGraphBody(loco::Node *input) = 0;
95
96 public:
97   loco::Graph g;
98   luci::CircleInput *input = nullptr;
99   luci::CircleOutput *output = nullptr;
100 };
101
102 class AddGraph final : public SimpleGraph
103 {
104 protected:
105   loco::Node *insertGraphBody(loco::Node *input) override
106   {
107     add = g.nodes()->create<luci::CircleAdd>();
108     beta = g.nodes()->create<luci::CircleConst>();
109
110     add->dtype(loco::DataType::FLOAT32);
111     beta->dtype(loco::DataType::FLOAT32);
112
113     uint32_t channel_size = 16;
114     add->shape({1, channel_size, 4, 4});
115     beta->shape({1, channel_size, 1, 1});
116
117     beta->size<loco::DataType::FLOAT32>(channel_size);
118     for (uint32_t i = 0; i < channel_size; i++)
119     {
120       beta->at<loco::DataType::FLOAT32>(i) = i;
121     }
122
123     add->x(input);
124     add->y(beta);
125
126     add->name("add");
127     beta->name("beta");
128
129     return add;
130   }
131
132 public:
133   void update_const_shape_to_nchw(void)
134   {
135     uint32_t channel_size = 16;
136     beta->shape({1, channel_size, 4, 4});
137
138     beta->size<loco::DataType::FLOAT32>(channel_size * 4 * 4);
139     for (uint32_t i = 0; i < channel_size; i++)
140     {
141       beta->at<loco::DataType::FLOAT32>(i) = i;
142     }
143   }
144
145 public:
146   luci::CircleAdd *add = nullptr;
147   luci::CircleConst *beta = nullptr;
148 };
149
150 class NHWCReluGraph final : public SimpleGraph
151 {
152 protected:
153   loco::Node *insertGraphBody(loco::Node *input) override
154   {
155     relu = g.nodes()->create<luci::CircleRelu>();
156     pre_reshape = g.nodes()->create<luci::CircleReshape>();
157     post_reshape = g.nodes()->create<luci::CircleReshape>();
158     pre_shape = g.nodes()->create<luci::CircleConst>();
159     post_shape = g.nodes()->create<luci::CircleConst>();
160
161     pre_shape->dtype(loco::DataType::S32);
162     post_shape->dtype(loco::DataType::S32);
163
164     uint32_t channel_size = 16;
165     auto in = loco::must_cast<luci::CircleNode *>(input);
166     in->shape({1, channel_size, 4, 4});
167     pre_shape->shape({4});
168     post_shape->shape({4});
169
170     pre_shape->size<loco::DataType::S32>(4);
171     pre_shape->at<loco::DataType::S32>(0) = 1;
172     pre_shape->at<loco::DataType::S32>(1) = 4;
173     pre_shape->at<loco::DataType::S32>(2) = 4;
174     pre_shape->at<loco::DataType::S32>(3) = channel_size;
175
176     post_shape->size<loco::DataType::S32>(4);
177     post_shape->at<loco::DataType::S32>(0) = 1;
178     post_shape->at<loco::DataType::S32>(1) = channel_size;
179     post_shape->at<loco::DataType::S32>(2) = 4;
180     post_shape->at<loco::DataType::S32>(3) = 4;
181
182     pre_reshape->tensor(input);
183     pre_reshape->shape(pre_shape);
184
185     relu->features(pre_reshape);
186
187     post_reshape->tensor(relu);
188     post_reshape->shape(post_shape);
189
190     relu->name("Relu");
191     pre_reshape->name("pre-reshape");
192     post_reshape->name("post-reshape");
193
194     return post_reshape;
195   }
196
197 public:
198   luci::CircleRelu *relu = nullptr;
199   luci::CircleReshape *pre_reshape = nullptr;
200   luci::CircleReshape *post_reshape = nullptr;
201   luci::CircleConst *pre_shape = nullptr;
202   luci::CircleConst *post_shape = nullptr;
203 };
204
205 class AddScalarGraph final : public SimpleGraph
206 {
207 protected:
208   loco::Node *insertGraphBody(loco::Node *input) override
209   {
210     add = g.nodes()->create<luci::CircleAdd>();
211     beta = g.nodes()->create<luci::CircleConst>();
212
213     add->dtype(loco::DataType::FLOAT32);
214     beta->dtype(loco::DataType::FLOAT32);
215
216     uint32_t channel_size = 16;
217     add->shape({1, channel_size, 4, 4});
218     beta->shape({1});
219
220     beta->size<loco::DataType::FLOAT32>(1);
221     beta->at<loco::DataType::FLOAT32>(0) = 3.14;
222
223     add->x(input);
224     add->y(beta);
225
226     add->name("add");
227     beta->name("beta");
228
229     return add;
230   }
231
232 public:
233   luci::CircleAdd *add = nullptr;
234   luci::CircleConst *beta = nullptr;
235 };
236
237 class ConcatenationGraph final : public SimpleGraph
238 {
239 protected:
240   loco::Node *insertGraphBody(loco::Node *input) override
241   {
242     concat = g.nodes()->create<luci::CircleConcatenation>(2);
243     concat->values(0, input);
244     concat->axis(1);
245
246     input2 = g.nodes()->create<luci::CircleConst>();
247     input2->dtype(loco::DataType::FLOAT32);
248     input2->shape({1, 16, 4, 4});
249     input2->size<loco::DataType::FLOAT32>(16 * 4 * 4);
250     for (uint32_t i = 0; i < 16 * 4 * 4; i++)
251     {
252       input2->at<loco::DataType::FLOAT32>(i) = i;
253     }
254     concat->values(1, input2);
255
256     concat->name("concat");
257     input2->name("input2");
258
259     return concat;
260   }
261
262 public:
263   luci::CircleConcatenation *concat = nullptr;
264   luci::CircleConst *input2 = nullptr;
265 };
266
267 class LeakyReluGraph final : public SimpleGraph
268 {
269 protected:
270   loco::Node *insertGraphBody(loco::Node *input) override
271   {
272     leakyrelu = g.nodes()->create<luci::CircleLeakyRelu>();
273     leakyrelu->features(input);
274     leakyrelu->name("leakyrelu");
275
276     return leakyrelu;
277   }
278
279 public:
280   luci::CircleLeakyRelu *leakyrelu = nullptr;
281 };
282
283 class LogisticGraph final : public SimpleGraph
284 {
285 protected:
286   loco::Node *insertGraphBody(loco::Node *input) override
287   {
288     logistic = g.nodes()->create<luci::CircleLogistic>();
289     logistic->x(input);
290     logistic->name("logistic");
291
292     return logistic;
293   }
294
295 public:
296   luci::CircleLogistic *logistic = nullptr;
297 };
298
299 class MaximumGraph final : public SimpleGraph
300 {
301 protected:
302   loco::Node *insertGraphBody(loco::Node *input) override
303   {
304     max = g.nodes()->create<luci::CircleMaximum>();
305     limit = g.nodes()->create<luci::CircleConst>();
306
307     max->dtype(loco::DataType::FLOAT32);
308     limit->dtype(loco::DataType::FLOAT32);
309
310     max->shape({1, 16, 4, 4});
311     limit->shape({});
312
313     limit->size<loco::DataType::FLOAT32>(1);
314     limit->at<loco::DataType::FLOAT32>(0) = 100;
315
316     max->x(input);
317     max->y(limit);
318
319     max->name("max");
320     limit->name("limit");
321
322     return max;
323   }
324
325 public:
326   luci::CircleMaximum *max = nullptr;
327   luci::CircleConst *limit = nullptr;
328 };
329
330 class MeanGraph final : public SimpleGraph
331 {
332 protected:
333   loco::Node *insertGraphBody(loco::Node *input) override
334   {
335     mean = g.nodes()->create<luci::CircleMean>();
336     rindices = g.nodes()->create<luci::CircleConst>();
337
338     mean->dtype(loco::DataType::FLOAT32);
339     rindices->dtype(loco::DataType::S32);
340
341     mean->shape(_shape);
342     rindices->shape({static_cast<uint32_t>(_axes.size())});
343
344     rindices->size<loco::DataType::S32>(_axes.size());
345     for (uint32_t i = 0; i < _axes.size(); ++i)
346     {
347       rindices->at<loco::DataType::S32>(i) = _axes[i];
348     }
349
350     mean->input(input);
351     mean->reduction_indices(rindices);
352     mean->keep_dims(_keep_dims);
353
354     mean->name("mean");
355     rindices->name("rindices");
356
357     return mean;
358   }
359
360 public:
361   void keep_dims(bool val) { _keep_dims = val; }
362   void axes(std::vector<int32_t> val) { _axes = val; }
363   void shape(std::initializer_list<uint32_t> val) { _shape = val; }
364
365 public:
366   luci::CircleMean *mean = nullptr;
367   luci::CircleConst *rindices = nullptr;
368
369 private:
370   bool _keep_dims = true;
371   std::vector<int32_t> _axes = {2, 3};
372   std::initializer_list<uint32_t> _shape = {1, 16, 1, 1};
373 };
374
375 class MinimumGraph final : public SimpleGraph
376 {
377 protected:
378   loco::Node *insertGraphBody(loco::Node *input) override
379   {
380     min = g.nodes()->create<luci::CircleMinimum>();
381     limit = g.nodes()->create<luci::CircleConst>();
382
383     min->dtype(loco::DataType::FLOAT32);
384     limit->dtype(loco::DataType::FLOAT32);
385
386     min->shape({1, 16, 4, 4});
387     limit->shape({});
388
389     limit->size<loco::DataType::FLOAT32>(1);
390     limit->at<loco::DataType::FLOAT32>(0) = 100;
391
392     min->x(input);
393     min->y(limit);
394
395     min->name("min");
396     limit->name("limit");
397
398     return min;
399   }
400
401 public:
402   luci::CircleMinimum *min = nullptr;
403   luci::CircleConst *limit = nullptr;
404 };
405
406 class MulGraph final : public SimpleGraph
407 {
408 protected:
409   loco::Node *insertGraphBody(loco::Node *input) override
410   {
411     mul = g.nodes()->create<luci::CircleMul>();
412     multiplier = g.nodes()->create<luci::CircleConst>();
413
414     mul->dtype(loco::DataType::FLOAT32);
415     multiplier->dtype(loco::DataType::FLOAT32);
416
417     uint32_t channel_size = 16;
418     mul->shape({1, channel_size, 4, 4});
419     multiplier->shape({1, channel_size, 1, 1});
420
421     multiplier->size<loco::DataType::FLOAT32>(channel_size);
422     for (uint32_t i = 0; i < channel_size; i++)
423     {
424       multiplier->at<loco::DataType::FLOAT32>(i) = i;
425     }
426
427     mul->x(input);
428     mul->y(multiplier);
429
430     mul->name("mul");
431     multiplier->name("multiplier");
432
433     return mul;
434   }
435
436 public:
437   void update_const_shape_to_nchw(void)
438   {
439     uint32_t channel_size = 16;
440     multiplier->shape({1, channel_size, 4, 4});
441
442     multiplier->size<loco::DataType::FLOAT32>(channel_size * 4 * 4);
443     for (uint32_t i = 0; i < channel_size; i++)
444     {
445       multiplier->at<loco::DataType::FLOAT32>(i) = i;
446     }
447   }
448
449 public:
450   luci::CircleMul *mul = nullptr;
451   luci::CircleConst *multiplier = nullptr;
452 };
453
454 class MulScalarGraph final : public SimpleGraph
455 {
456 protected:
457   loco::Node *insertGraphBody(loco::Node *input) override
458   {
459     mul = g.nodes()->create<luci::CircleMul>();
460     multiplier = g.nodes()->create<luci::CircleConst>();
461
462     mul->dtype(loco::DataType::FLOAT32);
463     multiplier->dtype(loco::DataType::FLOAT32);
464
465     uint32_t channel_size = 16;
466     mul->shape({1, channel_size, 4, 4});
467     multiplier->shape({1});
468
469     multiplier->size<loco::DataType::FLOAT32>(1);
470     multiplier->at<loco::DataType::FLOAT32>(0) = 2;
471
472     mul->x(input);
473     mul->y(multiplier);
474
475     mul->name("mul");
476     multiplier->name("multiplier");
477
478     return mul;
479   }
480
481 public:
482   luci::CircleMul *mul = nullptr;
483   luci::CircleConst *multiplier = nullptr;
484 };
485
486 class MulBothNormGraph final : public SimpleGraph
487 {
488 protected:
489   loco::Node *insertGraphBody(loco::Node *input) override
490   {
491     mul = g.nodes()->create<luci::CircleMul>();
492
493     mul->dtype(loco::DataType::FLOAT32);
494
495     uint32_t channel_size = 16;
496     mul->shape({1, channel_size, 4, 4});
497
498     mul->x(input);
499     mul->y(input);
500
501     mul->name("mul");
502
503     return mul;
504   }
505
506 public:
507   luci::CircleMul *mul = nullptr;
508 };
509
510 class NegGraph final : public SimpleGraph
511 {
512 protected:
513   loco::Node *insertGraphBody(loco::Node *input) override
514   {
515     neg = g.nodes()->create<luci::CircleNeg>();
516     neg->x(input);
517     neg->name("neg");
518
519     return neg;
520   }
521
522 public:
523   luci::CircleNeg *neg = nullptr;
524 };
525
526 class PadGraph final : public SimpleGraph
527 {
528 protected:
529   loco::Node *insertGraphBody(loco::Node *input) override
530   {
531     pad = g.nodes()->create<luci::CirclePad>();
532     paddings = g.nodes()->create<luci::CircleConst>();
533
534     pad->dtype(loco::DataType::FLOAT32);
535     paddings->dtype(loco::DataType::S32);
536
537     uint32_t channel_size = 16;
538     pad->shape({1, channel_size, 4, 4});
539     paddings->shape({4, 2});
540
541     // paddings data (NCHW)
542     // [[0,0], [0,0], [1,1], [2,2]]
543     paddings->size<loco::DataType::S32>(8);
544     for (uint32_t dim = 0; dim < 4; dim++)
545     {
546       for (uint32_t i = 0; i < 2; i++)
547       {
548         int32_t data = 0;
549
550         if (dim == 2)
551           data = 1;
552         else if (dim == 3)
553           data = 2;
554
555         paddings->at<loco::DataType::S32>(dim * 2 + i) = data;
556       }
557     }
558
559     pad->input(input);
560     pad->paddings(paddings);
561
562     pad->name("pad");
563     paddings->name("paddings");
564
565     return pad;
566   }
567
568 public:
569   luci::CirclePad *pad = nullptr;
570   luci::CircleConst *paddings = nullptr;
571 };
572
573 class PadV2Graph final : public SimpleGraph
574 {
575 protected:
576   loco::Node *insertGraphBody(loco::Node *input) override
577   {
578     pad = g.nodes()->create<luci::CirclePadV2>();
579     paddings = g.nodes()->create<luci::CircleConst>();
580     const_value = g.nodes()->create<luci::CircleConst>();
581
582     pad->dtype(loco::DataType::FLOAT32);
583     paddings->dtype(loco::DataType::S32);
584     const_value->dtype(loco::DataType::FLOAT32);
585
586     uint32_t channel_size = 16;
587     pad->shape({1, channel_size, 4, 4});
588     paddings->shape({4, 2});
589     const_value->shape({1});
590
591     // paddings data (NCHW)
592     // [[0,0], [0,0], [1,1], [2,2]]
593     paddings->size<loco::DataType::S32>(8);
594     for (uint32_t dim = 0; dim < 4; dim++)
595     {
596       for (uint32_t i = 0; i < 2; i++)
597       {
598         int32_t data = 0;
599
600         if (dim == 2)
601           data = 1;
602         else if (dim == 3)
603           data = 2;
604
605         paddings->at<loco::DataType::S32>(dim * 2 + i) = data;
606       }
607     }
608
609     const_value->size<loco::DataType::FLOAT32>(1);
610     const_value->at<loco::DataType::FLOAT32>(0) = -3.4;
611
612     pad->input(input);
613     pad->paddings(paddings);
614     pad->constant_values(paddings);
615
616     pad->name("padV2");
617     paddings->name("paddings");
618     const_value->name("constant_values");
619
620     return pad;
621   }
622
623 public:
624   luci::CirclePadV2 *pad = nullptr;
625   luci::CircleConst *paddings = nullptr;
626   luci::CircleConst *const_value = nullptr;
627 };
628
629 class ReluGraph final : public SimpleGraph
630 {
631 protected:
632   loco::Node *insertGraphBody(loco::Node *input) override
633   {
634     relu = g.nodes()->create<luci::CircleRelu>();
635     relu->features(input);
636     relu->name("Relu");
637
638     return relu;
639   }
640
641 public:
642   luci::CircleRelu *relu = nullptr;
643 };
644
645 class Relu6Graph final : public SimpleGraph
646 {
647 protected:
648   loco::Node *insertGraphBody(loco::Node *input) override
649   {
650     relu6 = g.nodes()->create<luci::CircleRelu6>();
651     relu6->features(input);
652     relu6->name("relu6");
653
654     return relu6;
655   }
656
657 public:
658   luci::CircleRelu6 *relu6 = nullptr;
659 };
660
661 class RsqrtGraph final : public SimpleGraph
662 {
663 protected:
664   loco::Node *insertGraphBody(loco::Node *input) override
665   {
666     rsqrt = g.nodes()->create<luci::CircleRsqrt>();
667     rsqrt->x(input);
668     rsqrt->name("rsqrt");
669
670     return rsqrt;
671   }
672
673 public:
674   luci::CircleRsqrt *rsqrt = nullptr;
675 };
676
677 class SquaredDifferenceGraph final : public SimpleGraph
678 {
679 protected:
680   loco::Node *insertGraphBody(loco::Node *input) override
681   {
682     sqdiff = g.nodes()->create<luci::CircleSquaredDifference>();
683     sqdiff->x(input);
684     sqdiff->y(input);
685     sqdiff->name("sqdiff");
686
687     return sqdiff;
688   }
689
690 public:
691   luci::CircleSquaredDifference *sqdiff = nullptr;
692 };
693
694 class SubGraph final : public SimpleGraph
695 {
696 protected:
697   loco::Node *insertGraphBody(loco::Node *input) override
698   {
699     sub = g.nodes()->create<luci::CircleSub>();
700     beta = g.nodes()->create<luci::CircleConst>();
701
702     sub->dtype(loco::DataType::FLOAT32);
703     beta->dtype(loco::DataType::FLOAT32);
704
705     uint32_t channel_size = 16;
706     sub->shape({1, channel_size, 4, 4});
707     beta->shape({1, channel_size, 1, 1});
708
709     beta->size<loco::DataType::FLOAT32>(channel_size);
710     for (uint32_t i = 0; i < channel_size; i++)
711     {
712       beta->at<loco::DataType::FLOAT32>(i) = i;
713     }
714
715     sub->x(input);
716     sub->y(beta);
717
718     sub->name("sub");
719     beta->name("beta");
720
721     return sub;
722   }
723
724 public:
725   void update_const_shape_to_nchw(void)
726   {
727     uint32_t channel_size = 16;
728     beta->shape({1, channel_size, 4, 4});
729
730     beta->size<loco::DataType::FLOAT32>(channel_size * 4 * 4);
731     for (uint32_t i = 0; i < channel_size; i++)
732     {
733       beta->at<loco::DataType::FLOAT32>(i) = i;
734     }
735   }
736
737 public:
738   luci::CircleSub *sub = nullptr;
739   luci::CircleConst *beta = nullptr;
740 };
741
742 class SubScalarGraph final : public SimpleGraph
743 {
744 protected:
745   loco::Node *insertGraphBody(loco::Node *input) override
746   {
747     sub = g.nodes()->create<luci::CircleSub>();
748     beta = g.nodes()->create<luci::CircleConst>();
749
750     sub->dtype(loco::DataType::FLOAT32);
751     beta->dtype(loco::DataType::FLOAT32);
752
753     uint32_t channel_size = 16;
754     sub->shape({1, channel_size, 4, 4});
755     beta->shape({1});
756
757     beta->size<loco::DataType::FLOAT32>(1);
758     beta->at<loco::DataType::FLOAT32>(0) = 5;
759
760     sub->x(beta);
761     sub->y(input);
762
763     sub->name("sub");
764     beta->name("beta");
765
766     return sub;
767   }
768
769 public:
770   luci::CircleSub *sub = nullptr;
771   luci::CircleConst *beta = nullptr;
772 };
773
774 void check_pre_trans(loco::Node *node)
775 {
776   auto pre_trans = dynamic_cast<luci::CircleTranspose *>(node);
777   EXPECT_NE(nullptr, pre_trans);
778   auto pre_trans_perm = dynamic_cast<luci::CircleConst *>(pre_trans->perm());
779   EXPECT_NE(nullptr, pre_trans_perm);
780   EXPECT_EQ(1, pre_trans_perm->rank());
781   EXPECT_EQ(4, pre_trans_perm->dim(0).value());
782   EXPECT_EQ(loco::DataType::S32, pre_trans_perm->dtype());
783   EXPECT_EQ(0, pre_trans_perm->at<loco::DataType::S32>(0));
784   EXPECT_EQ(2, pre_trans_perm->at<loco::DataType::S32>(1));
785   EXPECT_EQ(3, pre_trans_perm->at<loco::DataType::S32>(2));
786   EXPECT_EQ(1, pre_trans_perm->at<loco::DataType::S32>(3));
787 }
788
789 void check_post_trans(loco::Node *node)
790 {
791   auto post_trans = dynamic_cast<luci::CircleTranspose *>(node);
792   EXPECT_NE(nullptr, post_trans);
793   auto post_trans_perm = dynamic_cast<luci::CircleConst *>(post_trans->perm());
794   EXPECT_NE(nullptr, post_trans_perm);
795   EXPECT_EQ(1, post_trans_perm->rank());
796   EXPECT_EQ(4, post_trans_perm->dim(0).value());
797   EXPECT_EQ(loco::DataType::S32, post_trans_perm->dtype());
798   EXPECT_EQ(0, post_trans_perm->at<loco::DataType::S32>(0));
799   EXPECT_EQ(3, post_trans_perm->at<loco::DataType::S32>(1));
800   EXPECT_EQ(1, post_trans_perm->at<loco::DataType::S32>(2));
801   EXPECT_EQ(2, post_trans_perm->at<loco::DataType::S32>(3));
802 }
803
804 void run_phase(loco::Graph *g, bool preserve_input, bool preserve_output)
805 {
806   logo::Phase phase;
807
808   // Default passes.
809   phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
810
811   // Pass to test
812   phase.emplace_back(
813     std::make_unique<luci::ConvertNCHWToNHWCPass>(preserve_input, preserve_output));
814
815   logo::PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{g};
816   phase_runner.run(phase);
817 }
818
819 } // namespace
820
821 TEST(ConvertNCHWToNHWCPassTest, name)
822 {
823   luci::ConvertNCHWToNHWCPass pass(false, false);
824   auto const name = pass.name();
825   ASSERT_NE(nullptr, name);
826 }
827
828 TEST(ConvertNCHWToNHWC, Add)
829 {
830   AddGraph g;
831   g.init();
832
833   run_phase(&g.g, false, false);
834
835   auto input_succs = loco::succs(g.input);
836   EXPECT_EQ(1, input_succs.size());
837   check_post_trans(*input_succs.begin());
838
839   check_pre_trans(g.add->x());
840
841   auto add_succs = loco::succs(g.add);
842   EXPECT_EQ(1, add_succs.size());
843   check_post_trans(*add_succs.begin());
844
845   uint32_t channel_size = 16;
846   auto new_beta = dynamic_cast<luci::CircleConst *>(g.add->y());
847   EXPECT_NE(nullptr, new_beta);
848   EXPECT_EQ(4, new_beta->rank());
849   EXPECT_EQ(1, new_beta->dim(0).value());
850   EXPECT_EQ(1, new_beta->dim(1).value());
851   EXPECT_EQ(1, new_beta->dim(2).value());
852   EXPECT_EQ(channel_size, new_beta->dim(3).value());
853
854   check_pre_trans(g.output->from());
855 }
856
857 TEST(ConvertNCHWToNHWC, Add_NCHW_const)
858 {
859   AddGraph g;
860   g.init();
861   g.update_const_shape_to_nchw();
862
863   run_phase(&g.g, false, false);
864
865   check_pre_trans(g.add->x());
866
867   auto add_succs = loco::succs(g.add);
868   EXPECT_EQ(1, add_succs.size());
869   check_post_trans(*add_succs.begin());
870
871   uint32_t channel_size = 16;
872   auto new_beta = dynamic_cast<luci::CircleConst *>(g.add->y());
873   EXPECT_NE(nullptr, new_beta);
874   EXPECT_EQ(4, new_beta->rank());
875   EXPECT_EQ(1, new_beta->dim(0).value());
876   EXPECT_EQ(4, new_beta->dim(1).value());
877   EXPECT_EQ(4, new_beta->dim(2).value());
878   EXPECT_EQ(channel_size, new_beta->dim(3).value());
879 }
880
881 TEST(ConvertNCHWToNHWC, NHWC_Relu)
882 {
883   // Relu is already NHWC, so it should not be converted
884   // i.e., the graph is not changed
885   NHWCReluGraph g;
886   g.init();
887
888   run_phase(&g.g, false, false);
889
890   EXPECT_EQ(g.pre_reshape, g.relu->features());
891
892   auto relu_succs = loco::succs(g.relu);
893   EXPECT_EQ(1, relu_succs.size());
894   EXPECT_EQ(g.post_reshape, *relu_succs.begin());
895 }
896
897 TEST(ConvertNCHWToNHWC, AddScalar)
898 {
899   AddScalarGraph g;
900   g.init();
901
902   run_phase(&g.g, false, false);
903
904   auto input_succs = loco::succs(g.input);
905   EXPECT_EQ(1, input_succs.size());
906   check_post_trans(*input_succs.begin());
907
908   check_pre_trans(g.add->x());
909
910   auto add_succs = loco::succs(g.add);
911   EXPECT_EQ(1, add_succs.size());
912   check_post_trans(*add_succs.begin());
913
914   auto new_beta = dynamic_cast<luci::CircleConst *>(g.add->y());
915   EXPECT_NE(nullptr, new_beta);
916   EXPECT_EQ(1, new_beta->rank());
917   EXPECT_EQ(1, new_beta->dim(0).value());
918
919   check_pre_trans(g.output->from());
920 }
921
922 TEST(ConvertNCHWToNHWC, Concatenation)
923 {
924   ConcatenationGraph g;
925   g.init();
926
927   run_phase(&g.g, true, true);
928
929   check_pre_trans(g.concat->values(0));
930   check_pre_trans(g.concat->values(1));
931
932   auto concat_succs = loco::succs(g.concat);
933   EXPECT_EQ(1, concat_succs.size());
934   check_post_trans(*concat_succs.begin());
935
936   // Check concat shape, axis
937   EXPECT_EQ(1, g.concat->dim(0).value());
938   EXPECT_EQ(4, g.concat->dim(1).value());
939   EXPECT_EQ(4, g.concat->dim(2).value());
940   EXPECT_EQ(32, g.concat->dim(3).value());
941   EXPECT_EQ(3, g.concat->axis());
942 }
943
944 TEST(ConvertNCHWToNHWC, LeakyRelu)
945 {
946   LeakyReluGraph g;
947   g.init();
948
949   run_phase(&g.g, true, true);
950
951   check_pre_trans(g.leakyrelu->features());
952
953   auto leakyrelu_succs = loco::succs(g.leakyrelu);
954   EXPECT_EQ(1, leakyrelu_succs.size());
955   check_post_trans(*leakyrelu_succs.begin());
956
957   // Check leakyrelu shape
958   EXPECT_EQ(1, g.leakyrelu->dim(0).value());
959   EXPECT_EQ(4, g.leakyrelu->dim(1).value());
960   EXPECT_EQ(4, g.leakyrelu->dim(2).value());
961   EXPECT_EQ(16, g.leakyrelu->dim(3).value());
962 }
963
964 TEST(ConvertNCHWToNHWC, Logistic)
965 {
966   LogisticGraph g;
967   g.init();
968
969   run_phase(&g.g, true, true);
970
971   check_pre_trans(g.logistic->x());
972
973   auto logistic_succs = loco::succs(g.logistic);
974   EXPECT_EQ(1, logistic_succs.size());
975   check_post_trans(*logistic_succs.begin());
976
977   // Check logistic shape
978   EXPECT_EQ(1, g.logistic->dim(0).value());
979   EXPECT_EQ(4, g.logistic->dim(1).value());
980   EXPECT_EQ(4, g.logistic->dim(2).value());
981   EXPECT_EQ(16, g.logistic->dim(3).value());
982 }
983
984 TEST(ConvertNCHWToNHWC, Maximum)
985 {
986   MaximumGraph g;
987   g.init();
988
989   run_phase(&g.g, false, false);
990
991   auto input_succs = loco::succs(g.input);
992   EXPECT_EQ(1, input_succs.size());
993   check_post_trans(*input_succs.begin());
994
995   check_pre_trans(g.max->x());
996
997   auto max_succs = loco::succs(g.max);
998   EXPECT_EQ(1, max_succs.size());
999   check_post_trans(*max_succs.begin());
1000
1001   check_pre_trans(g.output->from());
1002 }
1003
1004 TEST(ConvertNCHWToNHWC, Mean)
1005 {
1006   MeanGraph g;
1007   g.init();
1008
1009   run_phase(&g.g, false, false);
1010
1011   check_pre_trans(g.mean->input());
1012
1013   auto mean_succs = loco::succs(g.mean);
1014   EXPECT_EQ(1, mean_succs.size());
1015   check_post_trans(*mean_succs.begin());
1016
1017   auto new_rindices = dynamic_cast<luci::CircleConst *>(g.mean->reduction_indices());
1018   EXPECT_NE(nullptr, new_rindices);
1019   EXPECT_EQ(1, new_rindices->rank());
1020   EXPECT_EQ(2, new_rindices->dim(0).value());
1021   EXPECT_EQ(2, new_rindices->size<loco::DataType::S32>());
1022   EXPECT_EQ(1, new_rindices->at<loco::DataType::S32>(0));
1023   EXPECT_EQ(2, new_rindices->at<loco::DataType::S32>(1));
1024 }
1025
1026 TEST(ConvertNCHWToNHWC, Mean_keep_dims_false)
1027 {
1028   struct TC
1029   {
1030     std::vector<int32_t> nchw_ind;
1031     std::vector<int32_t> nhwc_ind;
1032     std::initializer_list<uint32_t> shape;
1033     bool needs_transpose = false;
1034   };
1035
1036   uint32_t n = 1;
1037   uint32_t c = 16;
1038   uint32_t h = 4;
1039   uint32_t w = 4;
1040
1041   std::vector<TC> test_cases{{{0}, {0}, {c, h, w}, true},       {{1}, {3}, {n, h, w}, false},
1042                              {{2}, {1}, {n, c, w}, true},       {{3}, {2}, {n, c, h}, true},
1043                              {{0, 1}, {0, 3}, {h, w}, false},   {{0, 2}, {0, 1}, {c, w}, true},
1044                              {{0, 3}, {0, 2}, {c, h}, true},    {{1, 2}, {3, 1}, {n, w}, false},
1045                              {{1, 3}, {3, 2}, {n, h}, false},   {{2, 3}, {1, 2}, {n, c}, false},
1046                              {{0, 1, 2}, {0, 3, 1}, {w}, false}};
1047
1048   for (auto &tc : test_cases)
1049   {
1050     MeanGraph g;
1051     g.keep_dims(false);
1052     g.axes(tc.nchw_ind);
1053     g.shape(tc.shape);
1054     g.init();
1055
1056     run_phase(&g.g, false, true);
1057
1058     check_pre_trans(g.mean->input());
1059
1060     auto mean_succs = loco::succs(g.mean);
1061     EXPECT_EQ(1, mean_succs.size());
1062     if (tc.needs_transpose)
1063     {
1064       EXPECT_NE(nullptr, dynamic_cast<luci::CircleTranspose *>(*mean_succs.begin()));
1065     }
1066     else
1067     {
1068       EXPECT_NE(nullptr, dynamic_cast<luci::CircleOutput *>(*mean_succs.begin()));
1069     }
1070
1071     auto new_rindices = dynamic_cast<luci::CircleConst *>(g.mean->reduction_indices());
1072     EXPECT_NE(nullptr, new_rindices);
1073     EXPECT_EQ(1, new_rindices->rank());
1074     EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->dim(0).value());
1075     EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->size<loco::DataType::S32>());
1076     for (uint32_t i = 0; i < tc.nhwc_ind.size(); ++i)
1077     {
1078       EXPECT_EQ(tc.nhwc_ind[i], new_rindices->at<loco::DataType::S32>(i));
1079     }
1080   }
1081 }
1082
1083 TEST(ConvertNCHWToNHWC, ConvertNCHWToNHWC_Mean_keep_dims_false_NEG)
1084 {
1085   loco::Graph g;
1086   auto input = g.nodes()->create<luci::CircleInput>();
1087   auto output = g.nodes()->create<luci::CircleOutput>();
1088   input->name("input");
1089   output->name("output");
1090
1091   auto graph_input = g.inputs()->create();
1092   input->index(graph_input->index());
1093   auto graph_output = g.outputs()->create();
1094   output->index(graph_output->index());
1095
1096   graph_input->dtype(loco::DataType::FLOAT32);
1097   input->dtype(loco::DataType::FLOAT32);
1098   output->dtype(loco::DataType::FLOAT32);
1099   graph_output->dtype(loco::DataType::FLOAT32);
1100
1101   uint32_t channel_size = 16;
1102   graph_input->shape({channel_size, 4, 4});
1103   input->shape({channel_size, 4, 4});
1104   output->shape({channel_size});
1105   graph_output->shape({channel_size});
1106
1107   auto mean = g.nodes()->create<luci::CircleMean>();
1108   auto rindices = g.nodes()->create<luci::CircleConst>();
1109
1110   mean->dtype(loco::DataType::FLOAT32);
1111   rindices->dtype(loco::DataType::S32);
1112
1113   mean->shape({channel_size});
1114   rindices->shape({2});
1115
1116   rindices->size<loco::DataType::S32>(2);
1117   rindices->at<loco::DataType::S32>(0) = 1;
1118   rindices->at<loco::DataType::S32>(1) = 2;
1119
1120   mean->input(input);
1121   mean->reduction_indices(rindices);
1122   mean->keep_dims(false);
1123
1124   mean->name("mean");
1125   rindices->name("rindices");
1126
1127   output->from(mean);
1128
1129   run_phase(&g, true, true);
1130
1131   auto new_rindices = dynamic_cast<luci::CircleConst *>(mean->reduction_indices());
1132   EXPECT_NE(nullptr, new_rindices);
1133   EXPECT_EQ(1, new_rindices->rank());
1134   EXPECT_EQ(2, new_rindices->dim(0).value());
1135   EXPECT_EQ(2, new_rindices->size<loco::DataType::S32>());
1136   EXPECT_EQ(1, new_rindices->at<loco::DataType::S32>(0));
1137   EXPECT_EQ(2, new_rindices->at<loco::DataType::S32>(1));
1138 }
1139
1140 TEST(ConvertNCHWToNHWC, Minimum)
1141 {
1142   MinimumGraph g;
1143   g.init();
1144
1145   run_phase(&g.g, false, false);
1146
1147   auto input_succs = loco::succs(g.input);
1148   EXPECT_EQ(1, input_succs.size());
1149   check_post_trans(*input_succs.begin());
1150
1151   check_pre_trans(g.min->x());
1152
1153   auto min_succs = loco::succs(g.min);
1154   EXPECT_EQ(1, min_succs.size());
1155   check_post_trans(*min_succs.begin());
1156
1157   check_pre_trans(g.output->from());
1158 }
1159
1160 TEST(ConvertNCHWToNHWC, Mul)
1161 {
1162   MulGraph g;
1163   g.init();
1164
1165   run_phase(&g.g, false, false);
1166
1167   auto input_succs = loco::succs(g.input);
1168   EXPECT_EQ(1, input_succs.size());
1169   check_post_trans(*input_succs.begin());
1170
1171   check_pre_trans(g.mul->x());
1172
1173   auto mul_succs = loco::succs(g.mul);
1174   EXPECT_EQ(1, mul_succs.size());
1175   check_post_trans(*mul_succs.begin());
1176
1177   uint32_t channel_size = 16;
1178   auto new_multiplier = dynamic_cast<luci::CircleConst *>(g.mul->y());
1179   EXPECT_NE(nullptr, new_multiplier);
1180   EXPECT_EQ(4, new_multiplier->rank());
1181   EXPECT_EQ(1, new_multiplier->dim(0).value());
1182   EXPECT_EQ(1, new_multiplier->dim(1).value());
1183   EXPECT_EQ(1, new_multiplier->dim(2).value());
1184   EXPECT_EQ(channel_size, new_multiplier->dim(3).value());
1185
1186   check_pre_trans(g.output->from());
1187 }
1188
1189 TEST(ConvertNCHWToNHWC, Mul_NCHW_const)
1190 {
1191   MulGraph g;
1192   g.init();
1193   g.update_const_shape_to_nchw();
1194
1195   run_phase(&g.g, false, false);
1196
1197   check_pre_trans(g.mul->x());
1198
1199   auto mul_succs = loco::succs(g.mul);
1200   EXPECT_EQ(1, mul_succs.size());
1201   check_post_trans(*mul_succs.begin());
1202
1203   uint32_t channel_size = 16;
1204   auto new_multiplier = dynamic_cast<luci::CircleConst *>(g.mul->y());
1205   EXPECT_NE(nullptr, new_multiplier);
1206   EXPECT_EQ(4, new_multiplier->rank());
1207   EXPECT_EQ(1, new_multiplier->dim(0).value());
1208   EXPECT_EQ(4, new_multiplier->dim(1).value());
1209   EXPECT_EQ(4, new_multiplier->dim(2).value());
1210   EXPECT_EQ(channel_size, new_multiplier->dim(3).value());
1211 }
1212
1213 TEST(ConvertNCHWToNHWC, MulScalar)
1214 {
1215   MulScalarGraph g;
1216   g.init();
1217
1218   run_phase(&g.g, false, false);
1219
1220   auto input_succs = loco::succs(g.input);
1221   EXPECT_EQ(1, input_succs.size());
1222   check_post_trans(*input_succs.begin());
1223
1224   check_pre_trans(g.mul->x());
1225
1226   auto mul_succs = loco::succs(g.mul);
1227   EXPECT_EQ(1, mul_succs.size());
1228   check_post_trans(*mul_succs.begin());
1229
1230   auto new_multiplier = dynamic_cast<luci::CircleConst *>(g.mul->y());
1231   EXPECT_NE(nullptr, new_multiplier);
1232   EXPECT_EQ(1, new_multiplier->rank());
1233   EXPECT_EQ(1, new_multiplier->dim(0).value());
1234
1235   check_pre_trans(g.output->from());
1236 }
1237
1238 TEST(ConvertNCHWToNHWC, MulBothNorm)
1239 {
1240   MulBothNormGraph g;
1241   g.init();
1242
1243   run_phase(&g.g, false, false);
1244
1245   auto input_succs = loco::succs(g.input);
1246   EXPECT_EQ(1, input_succs.size());
1247   check_post_trans(*input_succs.begin());
1248
1249   check_pre_trans(g.mul->x());
1250   check_pre_trans(g.mul->y());
1251
1252   auto mul_succs = loco::succs(g.mul);
1253   EXPECT_EQ(1, mul_succs.size());
1254   check_post_trans(*mul_succs.begin());
1255
1256   check_pre_trans(g.output->from());
1257 }
1258
1259 TEST(ConvertNCHWToNHWC, Neg)
1260 {
1261   NegGraph g;
1262   g.init();
1263
1264   run_phase(&g.g, true, true);
1265
1266   check_pre_trans(g.neg->x());
1267
1268   auto neg_succs = loco::succs(g.neg);
1269   EXPECT_EQ(1, neg_succs.size());
1270   check_post_trans(*neg_succs.begin());
1271
1272   // Check leakyrelu shape
1273   EXPECT_EQ(1, g.neg->dim(0).value());
1274   EXPECT_EQ(4, g.neg->dim(1).value());
1275   EXPECT_EQ(4, g.neg->dim(2).value());
1276   EXPECT_EQ(16, g.neg->dim(3).value());
1277 }
1278
1279 TEST(ConvertNCHWToNHWC, Pad)
1280 {
1281   PadGraph g;
1282   g.init();
1283
1284   run_phase(&g.g, false, false);
1285
1286   auto input_succs = loco::succs(g.input);
1287   EXPECT_EQ(1, input_succs.size());
1288   check_post_trans(*input_succs.begin());
1289
1290   check_pre_trans(g.pad->input());
1291
1292   auto pad_succs = loco::succs(g.pad);
1293   EXPECT_EQ(1, pad_succs.size());
1294   check_post_trans(*pad_succs.begin());
1295
1296   auto new_paddings = dynamic_cast<luci::CircleConst *>(g.pad->paddings());
1297   EXPECT_NE(nullptr, new_paddings);
1298   EXPECT_EQ(2, new_paddings->rank());
1299   EXPECT_EQ(4, new_paddings->dim(0).value());
1300   EXPECT_EQ(2, new_paddings->dim(1).value());
1301   EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(0));
1302   EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(1));
1303   EXPECT_EQ(1, new_paddings->at<loco::DataType::S32>(2));
1304   EXPECT_EQ(1, new_paddings->at<loco::DataType::S32>(3));
1305   EXPECT_EQ(2, new_paddings->at<loco::DataType::S32>(4));
1306   EXPECT_EQ(2, new_paddings->at<loco::DataType::S32>(5));
1307   EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(6));
1308   EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(7));
1309
1310   check_pre_trans(g.output->from());
1311 }
1312
1313 TEST(ConvertNCHWToNHWC, PadV2)
1314 {
1315   PadV2Graph g;
1316   g.init();
1317
1318   run_phase(&g.g, false, false);
1319
1320   check_pre_trans(g.pad->input());
1321
1322   auto pad_succs = loco::succs(g.pad);
1323   EXPECT_EQ(1, pad_succs.size());
1324   check_post_trans(*pad_succs.begin());
1325
1326   auto new_paddings = dynamic_cast<luci::CircleConst *>(g.pad->paddings());
1327   EXPECT_NE(nullptr, new_paddings);
1328   EXPECT_EQ(2, new_paddings->rank());
1329   EXPECT_EQ(4, new_paddings->dim(0).value());
1330   EXPECT_EQ(2, new_paddings->dim(1).value());
1331   EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(0));
1332   EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(1));
1333   EXPECT_EQ(1, new_paddings->at<loco::DataType::S32>(2));
1334   EXPECT_EQ(1, new_paddings->at<loco::DataType::S32>(3));
1335   EXPECT_EQ(2, new_paddings->at<loco::DataType::S32>(4));
1336   EXPECT_EQ(2, new_paddings->at<loco::DataType::S32>(5));
1337   EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(6));
1338   EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(7));
1339 }
1340
1341 TEST(ConvertNCHWToNHWC, Unknown_Shape_NEG)
1342 {
1343   AddGraph g;
1344   g.init();
1345
1346   // Unknown shape
1347   g.input->dim(0).unset();
1348   g.add->dim(0).unset();
1349   g.output->dim(0).unset();
1350
1351   luci::ConvertNCHWToNHWCPass pass(false, false);
1352   EXPECT_EQ(false, pass.run(&g.g));
1353 }
1354
1355 TEST(ConvertNCHWToNHWC, Preserve_Input_Output)
1356 {
1357   // Preserve input
1358   {
1359     AddGraph g;
1360     g.init();
1361
1362     run_phase(&g.g, true, false);
1363
1364     // Check input shape
1365     EXPECT_EQ(1, g.input->dim(0).value());
1366     EXPECT_EQ(16, g.input->dim(1).value());
1367     EXPECT_EQ(4, g.input->dim(2).value());
1368     EXPECT_EQ(4, g.input->dim(3).value());
1369
1370     // Check output shape
1371     EXPECT_EQ(1, g.output->dim(0).value());
1372     EXPECT_EQ(4, g.output->dim(1).value());
1373     EXPECT_EQ(4, g.output->dim(2).value());
1374     EXPECT_EQ(16, g.output->dim(3).value());
1375   }
1376
1377   // Preserve output
1378   {
1379     AddGraph g;
1380     g.init();
1381
1382     run_phase(&g.g, false, true);
1383
1384     // Check input shape
1385     EXPECT_EQ(1, g.input->dim(0).value());
1386     EXPECT_EQ(4, g.input->dim(1).value());
1387     EXPECT_EQ(4, g.input->dim(2).value());
1388     EXPECT_EQ(16, g.input->dim(3).value());
1389
1390     // Check output shape
1391     EXPECT_EQ(1, g.output->dim(0).value());
1392     EXPECT_EQ(16, g.output->dim(1).value());
1393     EXPECT_EQ(4, g.output->dim(2).value());
1394     EXPECT_EQ(4, g.output->dim(3).value());
1395   }
1396
1397   // Preserve both input and output
1398   {
1399     AddGraph g;
1400     g.init();
1401
1402     run_phase(&g.g, true, true);
1403
1404     // Check input shape
1405     EXPECT_EQ(1, g.input->dim(0).value());
1406     EXPECT_EQ(16, g.input->dim(1).value());
1407     EXPECT_EQ(4, g.input->dim(2).value());
1408     EXPECT_EQ(4, g.input->dim(3).value());
1409
1410     // Check output shape
1411     EXPECT_EQ(1, g.output->dim(0).value());
1412     EXPECT_EQ(16, g.output->dim(1).value());
1413     EXPECT_EQ(4, g.output->dim(2).value());
1414     EXPECT_EQ(4, g.output->dim(3).value());
1415   }
1416 }
1417
1418 TEST(ConvertNCHWToNHWC, Relu)
1419 {
1420   ReluGraph g;
1421   g.init();
1422
1423   run_phase(&g.g, true, true);
1424
1425   check_pre_trans(g.relu->features());
1426
1427   auto relu_succs = loco::succs(g.relu);
1428   EXPECT_EQ(1, relu_succs.size());
1429   check_post_trans(*relu_succs.begin());
1430
1431   // Check relu shape
1432   EXPECT_EQ(1, g.relu->dim(0).value());
1433   EXPECT_EQ(4, g.relu->dim(1).value());
1434   EXPECT_EQ(4, g.relu->dim(2).value());
1435   EXPECT_EQ(16, g.relu->dim(3).value());
1436 }
1437
1438 TEST(ConvertNCHWToNHWC, Relu6)
1439 {
1440   Relu6Graph g;
1441   g.init();
1442
1443   run_phase(&g.g, true, true);
1444
1445   check_pre_trans(g.relu6->features());
1446
1447   auto relu6_succs = loco::succs(g.relu6);
1448   EXPECT_EQ(1, relu6_succs.size());
1449   check_post_trans(*relu6_succs.begin());
1450
1451   // Check relu6 shape
1452   EXPECT_EQ(1, g.relu6->dim(0).value());
1453   EXPECT_EQ(4, g.relu6->dim(1).value());
1454   EXPECT_EQ(4, g.relu6->dim(2).value());
1455   EXPECT_EQ(16, g.relu6->dim(3).value());
1456 }
1457
1458 TEST(ConvertNCHWToNHWC, Rsqrt)
1459 {
1460   RsqrtGraph g;
1461   g.init();
1462
1463   run_phase(&g.g, true, true);
1464
1465   check_pre_trans(g.rsqrt->x());
1466
1467   auto rsqrt_succs = loco::succs(g.rsqrt);
1468   EXPECT_EQ(1, rsqrt_succs.size());
1469   check_post_trans(*rsqrt_succs.begin());
1470
1471   // Check rsqrt shape
1472   EXPECT_EQ(1, g.rsqrt->dim(0).value());
1473   EXPECT_EQ(4, g.rsqrt->dim(1).value());
1474   EXPECT_EQ(4, g.rsqrt->dim(2).value());
1475   EXPECT_EQ(16, g.rsqrt->dim(3).value());
1476 }
1477
1478 TEST(ConvertNCHWToNHWC, SquaredDifference)
1479 {
1480   SquaredDifferenceGraph g;
1481   g.init();
1482
1483   run_phase(&g.g, true, true);
1484
1485   check_pre_trans(g.sqdiff->x());
1486   check_pre_trans(g.sqdiff->y());
1487
1488   auto sqdiff_succs = loco::succs(g.sqdiff);
1489   EXPECT_EQ(1, sqdiff_succs.size());
1490   check_post_trans(*sqdiff_succs.begin());
1491 }
1492
1493 TEST(ConvertNCHWToNHWC, Sub)
1494 {
1495   SubGraph g;
1496   g.init();
1497
1498   run_phase(&g.g, false, false);
1499
1500   auto input_succs = loco::succs(g.input);
1501   EXPECT_EQ(1, input_succs.size());
1502   check_post_trans(*input_succs.begin());
1503
1504   check_pre_trans(g.sub->x());
1505
1506   auto add_succs = loco::succs(g.sub);
1507   EXPECT_EQ(1, add_succs.size());
1508   check_post_trans(*add_succs.begin());
1509
1510   uint32_t channel_size = 16;
1511   auto new_beta = dynamic_cast<luci::CircleConst *>(g.sub->y());
1512   EXPECT_NE(nullptr, new_beta);
1513   EXPECT_EQ(4, new_beta->rank());
1514   EXPECT_EQ(1, new_beta->dim(0).value());
1515   EXPECT_EQ(1, new_beta->dim(1).value());
1516   EXPECT_EQ(1, new_beta->dim(2).value());
1517   EXPECT_EQ(channel_size, new_beta->dim(3).value());
1518
1519   check_pre_trans(g.output->from());
1520 }
1521
1522 TEST(ConvertNCHWToNHWC, Sub_NCHW_const)
1523 {
1524   SubGraph g;
1525   g.init();
1526   g.update_const_shape_to_nchw();
1527
1528   run_phase(&g.g, false, false);
1529
1530   check_pre_trans(g.sub->x());
1531
1532   auto sub_succs = loco::succs(g.sub);
1533   EXPECT_EQ(1, sub_succs.size());
1534   check_post_trans(*sub_succs.begin());
1535
1536   uint32_t channel_size = 16;
1537   auto new_beta = dynamic_cast<luci::CircleConst *>(g.sub->y());
1538   EXPECT_NE(nullptr, new_beta);
1539   EXPECT_EQ(4, new_beta->rank());
1540   EXPECT_EQ(1, new_beta->dim(0).value());
1541   EXPECT_EQ(4, new_beta->dim(1).value());
1542   EXPECT_EQ(4, new_beta->dim(2).value());
1543   EXPECT_EQ(channel_size, new_beta->dim(3).value());
1544 }
1545
1546 TEST(ConvertNCHWToNHWC, SubScalar)
1547 {
1548   SubScalarGraph g;
1549   g.init();
1550
1551   run_phase(&g.g, false, false);
1552
1553   auto input_succs = loco::succs(g.input);
1554   EXPECT_EQ(1, input_succs.size());
1555   check_post_trans(*input_succs.begin());
1556
1557   check_pre_trans(g.sub->y());
1558
1559   auto add_succs = loco::succs(g.sub);
1560   EXPECT_EQ(1, add_succs.size());
1561   check_post_trans(*add_succs.begin());
1562
1563   auto new_beta = dynamic_cast<luci::CircleConst *>(g.sub->x());
1564   EXPECT_NE(nullptr, new_beta);
1565   EXPECT_EQ(1, new_beta->rank());
1566
1567   check_pre_trans(g.output->from());
1568 }