Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / ConvertNCHWToNHWCPass.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/ConvertNCHWToNHWCPass.h"
18 #include "CircleOptimizerUtils.h"
19
20 #include <luci/IR/CircleNodes.h>
21 #include <luci/IR/CircleNodeVisitor.h>
22 #include <luci/Profile/CircleNodeOrigin.h>
23 #include <luci/Service/Nodes/CircleConst.h>
24 #include <luci/Log.h>
25
26 #include <functional>
27
28 namespace
29 {
30
31 // Return true if from can be broadcasted to to
32 // to's shape is [N, C, H, W]
33 bool broadcastable(const luci::CircleConst *from, const luci::CircleNode *to)
34 {
35   assert(to->rank() == 4); // FIX_CALLER_UNLESS
36
37   const auto from_rank = from->rank();
38   if (from_rank > 4)
39     return false;
40
41   // Scalar is always broadcastable
42   if (from_rank == 0)
43     return true;
44
45   for (uint32_t i = 1; i <= from_rank; i++)
46   {
47     auto to_index = 4 - i;
48     auto from_index = from_rank - i;
49
50     if (from->dim(from_index).value() != to->dim(to_index).value() and
51         from->dim(from_index).value() != 1)
52       return false;
53   }
54
55   return true;
56 }
57
58 // Return node with rank 4
59 // node should have rank less than or equal to 4
60 // 1 is inserted to the front of shape if rank is less than 4
61 // For example, [2] -> [1, 1, 1, 2]
62 luci::CircleConst *expand_to_rank_4(luci::CircleConst *node)
63 {
64   auto original_rank = node->rank();
65
66   assert(original_rank <= 4); // FIX_CALLER_UNLESS
67
68   if (original_rank == 4)
69     return node;
70
71   std::vector<uint32_t> original_shape;
72   for (uint32_t i = 0; i < original_rank; i++)
73   {
74     original_shape.emplace_back(node->dim(i).value());
75   }
76
77   auto cloned = luci::clone(node);
78   cloned->name(cloned->name() + "_rank4");
79
80   cloned->rank(4);
81   for (uint32_t i = 0; i < (4 - original_rank); i++)
82     cloned->dim(i) = 1;
83
84   for (uint32_t i = 0; i < original_rank; i++)
85     cloned->dim(i + (4 - original_rank)) = original_shape.at(i);
86
87   return cloned;
88 }
89
90 bool is_output(const loco::Node *node)
91 {
92   auto cnode = loco::must_cast<const luci::CircleNode *>(node);
93   auto opcode = cnode->opcode();
94   if (opcode == luci::CircleOpcode::CIRCLEOUTPUT ||
95       opcode == luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE)
96     return true;
97
98   return false;
99 }
100
101 bool is_same_shape(const luci::CircleNode *node, const std::vector<loco::Dimension> &shape)
102 {
103   if (not node)
104     return false;
105
106   if (shape.size() != node->rank())
107     return false;
108
109   for (uint32_t i = 0; i < shape.size(); i++)
110   {
111     if (not(node->dim(i) == shape[i]))
112       return false;
113   }
114   return true;
115 }
116
117 enum class DataFormat
118 {
119   NCHW,
120   NHWC
121 };
122
123 /**
124  * @brief Set annotation for DataFormat (NCHW, NHWC)
125  *
126  * @note DataFormatAnnotation will live longer than this Pass (until the
127  *       annotated loco::Node is erased). So, do not use large data in the
128  *       annotation to avoid excessive memory usage.
129  */
130 class DataFormatAnnotation final : public loco::NodeAnnotation
131 {
132 public:
133   DataFormatAnnotation(const DataFormat &format) : _format{format}
134   {
135     // DO NOTHING
136   }
137
138 public:
139   const DataFormat &format(void) const { return _format; }
140
141 private:
142   DataFormat _format;
143 };
144
145 void set_data_format(loco::Node *node, const DataFormat &format)
146 {
147   node->annot(std::make_unique<DataFormatAnnotation>(format));
148 }
149
150 DataFormat get_data_format(loco::Node *node)
151 {
152   assert(node->annot<DataFormatAnnotation>() != nullptr);
153   return node->annot<DataFormatAnnotation>()->format();
154 }
155
156 bool has_data_format(loco::Node *node) { return node->annot<DataFormatAnnotation>() != nullptr; }
157
158 bool check_4d_transpose(loco::Node *node, const std::vector<int32_t> indices)
159 {
160   assert(indices.size() == 4);
161
162   auto trans = dynamic_cast<luci::CircleTranspose *>(node);
163   if (not trans)
164     return false;
165
166   if (not trans->perm())
167     return false;
168
169   auto perm = dynamic_cast<luci::CircleConst *>(trans->perm());
170   // Only const perm is supported
171   if (not perm)
172     return false;
173
174   if (perm->dtype() != loco::DataType::S32)
175     return false;
176
177   if (perm->size<loco::DataType::S32>() != 4)
178     return false;
179
180   for (uint32_t i = 0; i < 4; i++)
181   {
182     if (perm->at<loco::DataType::S32>(i) != indices[i])
183       return false;
184   }
185
186   return true;
187 }
188
189 luci::CircleTranspose *create_4d_transpose(luci::CircleNode *node,
190                                            const std::vector<int32_t> indices)
191 {
192   assert(indices.size() == 4);
193
194   auto name = node->name();
195   assert(name.length() > 0);
196
197   auto perm = node->graph()->nodes()->create<luci::CircleConst>();
198   perm->dtype(loco::DataType::S32);
199   perm->size<loco::DataType::S32>(4);
200   perm->rank(1);
201   perm->dim(0) = 4;
202   for (uint32_t i = 0; i < 4; i++)
203     perm->at<loco::DataType::S32>(i) = indices[i];
204   perm->shape_status(luci::ShapeStatus::VALID);
205
206   auto make_string = [](const std::vector<int32_t> &nums) {
207     std::string str;
208     for (auto num : nums)
209     {
210       if (str.length() > 0)
211         str += ".";
212       str += std::to_string(num);
213     }
214     return str;
215   };
216
217   auto str_indices = make_string(indices);
218
219   perm->name(name + "/Transpose_" + str_indices + "/perm");
220
221   auto trans = node->graph()->nodes()->create<luci::CircleTranspose>();
222   trans->perm(perm);
223   trans->name(name + "/Transpose_" + str_indices);
224   luci::add_origin(trans, luci::get_origin(node));
225
226   return trans;
227 }
228
229 luci::CircleTranspose *create_Nd_transpose(luci::CircleNode *node,
230                                            const std::vector<int32_t> indices)
231 {
232   auto name = node->name();
233   assert(name.length() > 0);
234
235   auto perm = node->graph()->nodes()->create<luci::CircleConst>();
236   perm->dtype(loco::DataType::S32);
237   perm->size<loco::DataType::S32>(indices.size());
238   perm->rank(1);
239   perm->dim(0) = indices.size();
240   for (uint32_t i = 0; i < indices.size(); i++)
241     perm->at<loco::DataType::S32>(i) = indices[i];
242   perm->shape_status(luci::ShapeStatus::VALID);
243
244   auto make_string = [](const std::vector<int32_t> &nums) {
245     std::string str;
246     for (auto num : nums)
247     {
248       if (str.length() > 0)
249         str += ".";
250       str += std::to_string(num);
251     }
252     return str;
253   };
254
255   auto str_indices = make_string(indices);
256
257   perm->name(name + "/Transpose_" + str_indices + "/perm");
258
259   auto trans = node->graph()->nodes()->create<luci::CircleTranspose>();
260   trans->perm(perm);
261   trans->name(name + "/Transpose_" + str_indices);
262   luci::add_origin(trans, luci::get_origin(node));
263
264   return trans;
265 }
266
267 int32_t nchw_axis_to_nhwc(int32_t axis)
268 {
269   uint32_t pos_axis = axis >= 0 ? static_cast<uint32_t>(axis) : static_cast<uint32_t>(axis + 4);
270   static const uint32_t to_nhwc[4] = {0, 3, 1, 2};
271   if (pos_axis > 3)
272     throw std::runtime_error("Concat axis must be in range [-4, 4)");
273   return to_nhwc[pos_axis];
274 }
275
276 luci::CircleTranspose *create_post_transpose(luci::CircleNode *node)
277 {
278   return create_4d_transpose(node, {0, 3, 1, 2});
279 }
280
281 luci::CircleTranspose *create_pre_transpose(luci::CircleNode *node)
282 {
283   return create_4d_transpose(node, {0, 2, 3, 1});
284 }
285
286 bool check_4d_reshape(loco::Node *node, const std::vector<int32_t> indices)
287 {
288   assert(indices.size() == 4); // FIX_CALLER_UNLESS
289
290   auto reshape = dynamic_cast<luci::CircleReshape *>(node);
291   if (not reshape)
292     return false;
293
294   if (reshape->rank() != 4)
295     return false;
296
297   auto input = loco::must_cast<luci::CircleNode *>(reshape->tensor());
298   if (input->shape_status() != luci::ShapeStatus::VALID)
299     return false;
300
301   if (input->rank() != 4)
302     return false;
303
304   if (reshape->shape_status() != luci::ShapeStatus::VALID)
305     return false;
306
307   if (!(input->dim(0) == reshape->dim(indices[0])) ||
308       !(input->dim(1) == reshape->dim(indices[1])) ||
309       !(input->dim(2) == reshape->dim(indices[2])) || !(input->dim(3) == reshape->dim(indices[3])))
310     return false;
311
312   return true;
313 }
314
315 // Check if Reshape that converts NCHW -> NHWC
316 bool is_pre_reshape(loco::Node *node) { return check_4d_reshape(node, {0, 3, 1, 2}); }
317
318 // Check if Reshape that converts NHWC -> NCHW
319 bool is_post_reshape(loco::Node *node) { return check_4d_reshape(node, {0, 2, 3, 1}); }
320
321 bool is_post_transpose(loco::Node *node) { return check_4d_transpose(node, {0, 3, 1, 2}); }
322
323 bool is_pre_transpose(loco::Node *node) { return check_4d_transpose(node, {0, 2, 3, 1}); }
324
325 uint32_t cal_offset(const loco::TensorShape &dimension, const uint32_t *indices)
326 {
327   return indices[0] * dimension.dim(1).value() * dimension.dim(2).value() *
328            dimension.dim(3).value() +
329          indices[1] * dimension.dim(2).value() * dimension.dim(3).value() +
330          indices[2] * dimension.dim(3).value() + indices[3];
331 }
332
333 luci::CircleConst *create_NHWC_paddings(luci::CircleConst *paddings)
334 {
335   // paddings shape is (4,2) (it was checked by is_NCHW)
336   assert(paddings != nullptr);
337   assert(paddings->rank() == 2);
338   assert(paddings->dim(0).value() == 4);
339   assert(paddings->dim(1).value() == 2);
340
341   // paddings for idx 0~3 are 0 (checked by is_NCHW)
342   assert(paddings->at<loco::DataType::S32>(0) == 0);
343   assert(paddings->at<loco::DataType::S32>(1) == 0);
344   assert(paddings->at<loco::DataType::S32>(2) == 0);
345   assert(paddings->at<loco::DataType::S32>(3) == 0);
346
347   auto name = paddings->name();
348   assert(name.length() > 0);
349
350   auto nhwc_paddings = paddings->graph()->nodes()->create<luci::CircleConst>();
351   nhwc_paddings->dtype(loco::DataType::S32);
352   nhwc_paddings->shape({4, 2});
353   nhwc_paddings->shape_status(luci::ShapeStatus::VALID);
354   nhwc_paddings->size<loco::DataType::S32>(4 * 2);
355   nhwc_paddings->name(name + "_NHWC");
356
357   for (uint32_t dim = 0; dim < 4; dim++)
358   {
359     for (uint32_t i = 0; i < 2; i++)
360     {
361       int32_t data = 0;
362
363       if (dim == 1)
364       {
365         // get third dimension (H in NCHW)
366         data = paddings->at<loco::DataType::S32>(2 * 2 + i);
367       }
368       else if (dim == 2)
369       {
370         // get fourth dimension (W in NCHW)
371         data = paddings->at<loco::DataType::S32>(3 * 2 + i);
372       }
373
374       nhwc_paddings->at<loco::DataType::S32>(dim * 2 + i) = data;
375     }
376   }
377   return nhwc_paddings;
378 }
379
380 luci::CircleConst *create_NHWC_rindices(luci::CircleConst *rindices)
381 {
382   assert(rindices != nullptr); // FIX_CALLER_UNLESS
383
384   if (rindices->dtype() != loco::DataType::S32)
385     return nullptr;
386
387   auto nhwc_rindices = luci::clone(rindices);
388   auto name = rindices->name();
389   assert(name.length() > 0); // FIX_CALLER_UNLESS
390   nhwc_rindices->name(name + "_NHWC");
391
392   auto size = nhwc_rindices->size<loco::DataType::S32>();
393   for (uint32_t i = 0; i < size; i++)
394   {
395     nhwc_rindices->at<loco::DataType::S32>(i) =
396       nchw_axis_to_nhwc(rindices->at<loco::DataType::S32>(i));
397   }
398
399   return nhwc_rindices;
400 }
401
402 luci::CircleConst *create_NHWC_from_NCHW(luci::CircleConst *constant)
403 {
404   LOGGER(l);
405   assert(constant->rank() == 4);
406
407   // TODO: Support non-float types
408   if (constant->dtype() != loco::DataType::FLOAT32)
409   {
410     INFO(l) << "Non-float type constant: " << constant->name() << std::endl;
411     return nullptr;
412   }
413
414   loco::TensorShape nchw_dimension{constant->dim(0), constant->dim(1), constant->dim(2),
415                                    constant->dim(3)};
416   loco::TensorShape nhwc_dimension{constant->dim(0), constant->dim(2), constant->dim(3),
417                                    constant->dim(1)};
418
419   auto name = constant->name();
420   assert(name.length() > 0);
421
422   auto nhwc_const = constant->graph()->nodes()->create<luci::CircleConst>();
423   nhwc_const->dtype(constant->dtype());
424   nhwc_const->rank(4);
425   nhwc_const->dim(0).set(constant->dim(0).value());
426   nhwc_const->dim(1).set(constant->dim(2).value());
427   nhwc_const->dim(2).set(constant->dim(3).value());
428   nhwc_const->dim(3).set(constant->dim(1).value());
429   nhwc_const->shape_status(luci::ShapeStatus::VALID);
430   nhwc_const->size<loco::DataType::FLOAT32>(constant->size<loco::DataType::FLOAT32>());
431   nhwc_const->name(name + "_NHWC");
432
433   for (uint32_t n = 0; n < nchw_dimension.dim(0).value(); n++)
434   {
435     for (uint32_t c = 0; c < nchw_dimension.dim(1).value(); c++)
436     {
437       for (uint32_t h = 0; h < nchw_dimension.dim(2).value(); h++)
438       {
439         for (uint32_t w = 0; w < nchw_dimension.dim(3).value(); w++)
440         {
441           uint32_t nchw_indices[4] = {n, c, h, w};
442           uint32_t nhwc_indices[4] = {n, h, w, c};
443           auto data =
444             constant->at<loco::DataType::FLOAT32>(cal_offset(nchw_dimension, nchw_indices));
445           nhwc_const->at<loco::DataType::FLOAT32>(cal_offset(nhwc_dimension, nhwc_indices)) = data;
446         }
447       }
448     }
449   }
450   return nhwc_const;
451 }
452
453 // NOTE Following conditions can be extended later
454 //
455 // Find PAD with an NCHW pattern described below
456 //   - Paddings shape : [4, 2]
457 //   - Paddings value : [[0, 0], [0, 0], [h_t, h_b], [w_t, w_b]]]
458 bool is_NCHW(const luci::CirclePad *node)
459 {
460   const auto paddings = dynamic_cast<luci::CircleConst *>(node->paddings());
461   // Non-const paddings is not supported
462   if (paddings == nullptr)
463     return false;
464
465   if (paddings->rank() != 2)
466     return false;
467
468   if (paddings->dim(0).value() != 4 || paddings->dim(1).value() != 2)
469     return false;
470
471   // Only check the first two dimensions
472   for (uint32_t dim = 0; dim < 2; dim++)
473   {
474     for (uint32_t i = 0; i < 2; i++)
475     {
476       auto data = paddings->at<loco::DataType::S32>(dim * 2 + i);
477       if (data != 0)
478         return false;
479     }
480   }
481
482   return true;
483 }
484
485 // NOTE Copied from is_NCHW(CirclePad)
486 bool is_NCHW(const luci::CirclePadV2 *node)
487 {
488   const auto paddings = dynamic_cast<luci::CircleConst *>(node->paddings());
489   // Non-const paddings is not supported
490   if (paddings == nullptr)
491     return false;
492
493   if (paddings->rank() != 2)
494     return false;
495
496   if (paddings->dim(0).value() != 4 || paddings->dim(1).value() != 2)
497     return false;
498
499   // Only check the first two dimensions
500   for (uint32_t dim = 0; dim < 2; dim++)
501   {
502     for (uint32_t i = 0; i < 2; i++)
503     {
504       auto data = paddings->at<loco::DataType::S32>(dim * 2 + i);
505       if (data != 0)
506         return false;
507     }
508   }
509
510   return true;
511 }
512
513 bool is_const(const loco::Node *node)
514 {
515   if (not dynamic_cast<const luci::CircleConst *>(node))
516     return false;
517
518   return true;
519 }
520
521 bool is_scalar_const(const loco::Node *node)
522 {
523   auto const_node = dynamic_cast<const luci::CircleConst *>(node);
524   if (not const_node)
525     return false;
526
527   const auto const_rank = const_node->rank();
528   // shape of scalar
529   // 1. rank = 0
530   // 2. rank = 1, dimension = 1
531   if (const_rank == 0)
532     return true;
533
534   if (const_rank == 1 && const_node->dim(0).value() == 1)
535     return true;
536
537   return false;
538 }
539
540 // NOTE Following conditions can be extended later
541 //
542 // Find MUL with an NCHW pattern described below
543 //   - Input (non-constant) shape : [N, C, H, W]
544 //   - Input (constant) shape : broadcastable to [N, C, H, W]
545 //   - Output shape : [N, C, H, W]
546 bool is_NCHW_with_const(const luci::CircleMul *node, luci::CircleNode *&pred_node,
547                         luci::CircleConst *&multiplier)
548 {
549   auto x = dynamic_cast<luci::CircleConst *>(node->x());
550   auto y = dynamic_cast<luci::CircleConst *>(node->y());
551
552   if (x != nullptr && y == nullptr)
553   {
554     pred_node = loco::must_cast<luci::CircleNode *>(node->y());
555     multiplier = x;
556   }
557   else if (x == nullptr && y != nullptr)
558   {
559     pred_node = loco::must_cast<luci::CircleNode *>(node->x());
560     multiplier = y;
561   }
562   else
563   {
564     // Ignore if MUL does not have a multiplier input.
565     return false;
566   }
567
568   if (pred_node->rank() != 4)
569     return false;
570
571   if (not broadcastable(multiplier, node))
572     return false;
573
574   multiplier = expand_to_rank_4(multiplier);
575
576   return true;
577 }
578
579 // We assume ADD with const input is NCHW if,
580 // Input shape: (N, C, H, W)
581 // Output shape: (N, C, H, W)
582 // 1. Const shape is (1, C, 1, 1), (N, C, H, W) or a scalar (1)
583 // 2. Input, Output, Const have the same C.
584 bool is_NCHW_with_const(const luci::CircleAdd *node, luci::CircleNode *&pred_node,
585                         luci::CircleConst *&beta)
586 {
587   auto x = dynamic_cast<luci::CircleConst *>(node->x());
588   auto y = dynamic_cast<luci::CircleConst *>(node->y());
589
590   if (x != nullptr && y == nullptr)
591   {
592     pred_node = loco::must_cast<luci::CircleNode *>(node->y());
593     beta = x;
594   }
595   else if (x == nullptr && y != nullptr)
596   {
597     pred_node = loco::must_cast<luci::CircleNode *>(node->x());
598     beta = y;
599   }
600   else
601   {
602     // Ignore if ADD does not have a constant input.
603     return false;
604   }
605
606   if (pred_node->rank() != 4)
607     return false;
608
609   if (not broadcastable(beta, node))
610     return false;
611
612   beta = expand_to_rank_4(beta);
613
614   return true;
615 }
616
617 // We assume SUB with const input is NCHW if,
618 // Input shape: (N, C, H, W)
619 // Output shape: (N, C, H, W)
620 // 1. Const shape is (1, C, 1, 1), (N, C, H, W) or a scalar (1)
621 // 2. Input, Output, Const have the same C.
622 bool is_NCHW_with_const(const luci::CircleSub *node, const luci::CircleNode *pred_node,
623                         const luci::CircleConst *subtract)
624 {
625   assert(pred_node != nullptr);
626   assert(subtract != nullptr);
627
628   if (pred_node->rank() != 4)
629     return false;
630
631   const auto const_rank = subtract->rank();
632   // Support Rank 4 or scalar (rank 0 or 1)
633   if (const_rank != 4 && const_rank != 0 && const_rank != 1)
634     return false;
635
636   const auto input_cdim = pred_node->dim(1);
637   const auto output_cdim = node->dim(1);
638
639   if (const_rank == 4)
640   {
641     bool supported_shape = false;
642
643     // Check subtract is (1, C, 1, 1)
644     if (is_same_shape(subtract, {1, node->dim(1), 1, 1}))
645       supported_shape = true;
646
647     // Check subtract is (N, C, H, W)
648     if (is_same_shape(subtract, {node->dim(0), node->dim(1), node->dim(2), node->dim(3)}))
649       supported_shape = true;
650
651     return supported_shape;
652   }
653   if (input_cdim == output_cdim)
654     return true;
655   else
656     return false;
657 }
658
659 template <class T> bool convert_unary_features(T *node)
660 {
661   const auto pred_node = loco::must_cast<luci::CircleNode *>(node->features());
662   auto pre_trans = create_pre_transpose(node);
663   pre_trans->a(pred_node);
664   node->features(pre_trans);
665
666   // Do shape inference for this node again.
667   node->shape_status(luci::ShapeStatus::UNDEFINED);
668
669   auto post_trans = create_post_transpose(node);
670   loco::replace(node).with(post_trans);
671
672   post_trans->a(node);
673
674   return true;
675 }
676
677 template <class T> bool convert_unary_x(T *node)
678 {
679   const auto pred_node = loco::must_cast<luci::CircleNode *>(node->x());
680   auto pre_trans = create_pre_transpose(node);
681   pre_trans->a(pred_node);
682   node->x(pre_trans);
683
684   // Do shape inference for this node again.
685   node->shape_status(luci::ShapeStatus::UNDEFINED);
686
687   auto post_trans = create_post_transpose(node);
688   loco::replace(node).with(post_trans);
689
690   post_trans->a(node);
691
692   return true;
693 }
694
695 template <class T> bool convert_unary_logits(T *node)
696 {
697   const auto pred_node = loco::must_cast<luci::CircleNode *>(node->logits());
698   auto pre_trans = create_pre_transpose(node);
699   pre_trans->a(pred_node);
700   node->logits(pre_trans);
701
702   // Do shape inference for this node again.
703   node->shape_status(luci::ShapeStatus::UNDEFINED);
704
705   auto post_trans = create_post_transpose(node);
706   loco::replace(node).with(post_trans);
707
708   post_trans->a(node);
709
710   return true;
711 }
712
713 class ConvertNCHWToNHWC final : public luci::CircleNodeMutableVisitor<bool>
714 {
715   // Default
716   bool visit(luci::CircleNode *node)
717   {
718     throw std::runtime_error(node->name() + " is an unsupported operator.");
719   }
720
721   bool visit(luci::CircleInput *node)
722   {
723     const auto n = node->dim(0);
724     const auto c = node->dim(1);
725     const auto h = node->dim(2);
726     const auto w = node->dim(3);
727
728     node->dim(1) = h;
729     node->dim(2) = w;
730     node->dim(3) = c;
731
732     // Do shape inference for this node again.
733     node->shape_status(luci::ShapeStatus::UNDEFINED);
734
735     // Insert post-tranpose
736     auto post_trans = create_post_transpose(node);
737     loco::replace(node).with(post_trans);
738
739     post_trans->a(node);
740
741     // Update graph input
742     auto graph_inputs = node->graph()->inputs();
743     auto graph_input = graph_inputs->at(node->index());
744     graph_input->shape({n, h, w, c});
745
746     return true;
747   }
748
749   bool visit(luci::CircleOutput *node)
750   {
751     // Insert pre-transpose
752     auto pre_trans = create_pre_transpose(node);
753     pre_trans->a(node->from());
754
755     node->from(pre_trans);
756
757     // Do shape inference for this node again.
758     node->shape_status(luci::ShapeStatus::UNDEFINED);
759
760     // Update graph output
761     const auto n = node->dim(0).value();
762     const auto c = node->dim(1).value();
763     const auto h = node->dim(2).value();
764     const auto w = node->dim(3).value();
765
766     auto graph_outputs = node->graph()->outputs();
767     auto graph_output = graph_outputs->at(node->index());
768     graph_output->shape({n, h, w, c});
769
770     return true;
771   }
772
773   bool visit(luci::CircleAdd *node)
774   {
775     luci::CircleNode *pred_node = nullptr;
776     luci::CircleConst *beta = nullptr;
777
778     if (is_NCHW_with_const(node, pred_node, beta))
779     {
780       assert(beta->rank() == 4); // FIX is_NCHW_with_const unless
781       auto nhwc_const = create_NHWC_from_NCHW(beta);
782       if (nhwc_const == nullptr)
783         return false;
784       node->y(nhwc_const);
785
786       auto pre_trans = create_pre_transpose(node);
787       pre_trans->a(pred_node);
788       node->x(pre_trans);
789     }
790     else if (beta == nullptr)
791     {
792       // Both inputs are not constant.
793       // In this case, we cannot distinguish NCHW from NHWC,
794       // so just insert Transpose Ops.
795       auto pre_trans_x = create_pre_transpose(node);
796       pre_trans_x->a(node->x());
797       node->x(pre_trans_x);
798
799       auto pre_trans_y = create_pre_transpose(node);
800       pre_trans_y->a(node->y());
801       node->y(pre_trans_y);
802     }
803     else
804     {
805       return false;
806     }
807
808     // Do shape inference for this node again.
809     node->shape_status(luci::ShapeStatus::UNDEFINED);
810
811     auto post_trans = create_post_transpose(node);
812     loco::replace(node).with(post_trans);
813
814     post_trans->a(node);
815     return true;
816   }
817
818   bool visit(luci::CircleConcatenation *node)
819   {
820     const auto num_values = node->numValues();
821     for (uint32_t i = 0; i < num_values; i++)
822     {
823       auto pred_node = loco::must_cast<luci::CircleNode *>(node->values(i));
824       auto pre_trans = create_pre_transpose(node);
825       pre_trans->a(pred_node);
826       node->values(i, pre_trans);
827     }
828
829     // Do shape inference for this node again.
830     node->shape_status(luci::ShapeStatus::UNDEFINED);
831
832     node->axis(nchw_axis_to_nhwc(node->axis()));
833
834     auto post_trans = create_post_transpose(node);
835     loco::replace(node).with(post_trans);
836
837     post_trans->a(node);
838
839     return true;
840   }
841
842   bool visit(luci::CircleElu *node) { return convert_unary_features<luci::CircleElu>(node); }
843
844   bool visit(luci::CircleGelu *node) { return convert_unary_features<luci::CircleGelu>(node); }
845
846   bool visit(luci::CircleLeakyRelu *node)
847   {
848     return convert_unary_features<luci::CircleLeakyRelu>(node);
849   }
850
851   bool visit(luci::CircleLogistic *node) { return convert_unary_x<luci::CircleLogistic>(node); }
852
853   bool visit(luci::CircleMaximum *node)
854   {
855     if ((not is_const(node->x())) and is_scalar_const(node->y()))
856     {
857       auto pre_trans = create_pre_transpose(node);
858       pre_trans->a(node->x());
859       node->x(pre_trans);
860     }
861     else if (is_scalar_const(node->x()) and (not is_const(node->y())))
862     {
863       auto pre_trans = create_pre_transpose(node);
864       pre_trans->a(node->y());
865       node->y(pre_trans);
866     }
867     else if ((not is_const(node->x())) and (not is_const(node->y())))
868     {
869       auto pre_trans_x = create_pre_transpose(node);
870       pre_trans_x->a(node->x());
871       node->x(pre_trans_x);
872
873       auto pre_trans_y = create_pre_transpose(node);
874       pre_trans_y->a(node->y());
875       node->y(pre_trans_y);
876     }
877     else
878     {
879       // TODO support other cases
880       return false;
881     }
882
883     // Do shape inference for this node again.
884     node->shape_status(luci::ShapeStatus::UNDEFINED);
885
886     auto post_trans = create_post_transpose(node);
887     loco::replace(node).with(post_trans);
888
889     post_trans->a(node);
890     return true;
891   }
892
893   bool visit(luci::CircleMean *node)
894   {
895     auto input = loco::must_cast<luci::CircleNode *>(node->input());
896     if (input->rank() != 4)
897       return false;
898
899     auto rindices = dynamic_cast<luci::CircleConst *>(node->reduction_indices());
900     if (not rindices)
901       return false;
902
903     auto nhwc_rindices = create_NHWC_rindices(rindices);
904     if (not nhwc_rindices)
905       return false;
906
907     auto pre_trans = create_pre_transpose(node);
908     pre_trans->a(input);
909     node->input(pre_trans);
910
911     // Do shape inference for this node again.
912     node->shape_status(luci::ShapeStatus::UNDEFINED);
913
914     node->reduction_indices(nhwc_rindices);
915
916     if (node->keep_dims())
917     {
918       auto post_trans = create_post_transpose(node);
919       loco::replace(node).with(post_trans);
920
921       post_trans->a(node);
922
923       return true;
924     }
925
926     // node->keep_dims() == false
927     // 1D output never needs a transpose
928     if (node->rank() <= 1)
929       return true;
930
931     std::vector<bool> reduced_dims_nhwc(4, false);
932     uint32_t num_reduced_indices = nhwc_rindices->size<loco::DataType::S32>();
933
934     for (uint32_t ri = 0; ri < num_reduced_indices; ++ri)
935     {
936       reduced_dims_nhwc[nhwc_rindices->at<loco::DataType::S32>(ri)] = true;
937     }
938
939     // if channel dimension has been reduced, we don't need a transpose
940     if (reduced_dims_nhwc[3])
941       return true;
942
943     // likewise, if both space dimensions are reduced, no transpose is needed
944     if (reduced_dims_nhwc[1] && reduced_dims_nhwc[2])
945       return true;
946
947     std::vector<int32_t> post_trans_ind;
948     // case 1: only N is reduced
949     if (num_reduced_indices == 1 && reduced_dims_nhwc[0])
950       post_trans_ind = {2, 0, 1};
951
952     // case 2: only H or W is reduced
953     if (num_reduced_indices == 1 && (reduced_dims_nhwc[1] || reduced_dims_nhwc[2]))
954       post_trans_ind = {0, 2, 1};
955
956     // case 3: N and either H or W are reduced
957     if (num_reduced_indices == 2)
958       post_trans_ind = {1, 0};
959
960     auto post_trans = create_Nd_transpose(node, post_trans_ind);
961     loco::replace(node).with(post_trans);
962
963     post_trans->a(node);
964
965     return true;
966   }
967
968   bool visit(luci::CircleMinimum *node)
969   {
970     if ((not is_const(node->x())) and is_scalar_const(node->y()))
971     {
972       auto pre_trans = create_pre_transpose(node);
973       pre_trans->a(node->x());
974       node->x(pre_trans);
975     }
976     else if (is_scalar_const(node->x()) and (not is_const(node->y())))
977     {
978       auto pre_trans = create_pre_transpose(node);
979       pre_trans->a(node->y());
980       node->y(pre_trans);
981     }
982     else
983     {
984       // TODO support other cases
985       return false;
986     }
987
988     // Do shape inference for this node again.
989     node->shape_status(luci::ShapeStatus::UNDEFINED);
990
991     auto post_trans = create_post_transpose(node);
992     loco::replace(node).with(post_trans);
993
994     post_trans->a(node);
995     return true;
996   }
997
998   bool visit(luci::CircleMul *node)
999   {
1000     LOGGER(l);
1001
1002     luci::CircleNode *pred_node = nullptr;
1003     luci::CircleConst *multiplier = nullptr;
1004
1005     if (is_NCHW_with_const(node, pred_node, multiplier))
1006     {
1007       assert(multiplier->rank() == 4); // FIX is_NCHW_with_const unless
1008       auto nhwc_const = create_NHWC_from_NCHW(multiplier);
1009       if (nhwc_const == nullptr)
1010         return false;
1011       node->y(nhwc_const);
1012
1013       auto pre_trans = create_pre_transpose(node);
1014       pre_trans->a(pred_node);
1015       node->x(pre_trans);
1016     }
1017     else if (multiplier == nullptr)
1018     {
1019       // Only support for input rank 4
1020       auto input_x = loco::must_cast<luci::CircleNode *>(node->x());
1021       if (input_x->rank() != 4)
1022         return false;
1023       auto input_y = loco::must_cast<luci::CircleNode *>(node->y());
1024       if (input_y->rank() != 4)
1025         return false;
1026
1027       auto pre_trans_x = create_pre_transpose(node);
1028       pre_trans_x->a(input_x);
1029       node->x(pre_trans_x);
1030
1031       auto pre_trans_y = create_pre_transpose(node);
1032       pre_trans_y->a(input_y);
1033       node->y(pre_trans_y);
1034     }
1035     else
1036     {
1037       return false;
1038     }
1039
1040     // Do shape inference for this node again.
1041     node->shape_status(luci::ShapeStatus::UNDEFINED);
1042
1043     auto post_trans = create_post_transpose(node);
1044     loco::replace(node).with(post_trans);
1045
1046     post_trans->a(node);
1047     return true;
1048   }
1049
1050   bool visit(luci::CircleNeg *node) { return convert_unary_x<luci::CircleNeg>(node); }
1051
1052   bool visit(luci::CirclePad *node)
1053   {
1054     if (!is_NCHW(node))
1055       return false;
1056
1057     const auto pred_node = loco::must_cast<luci::CircleNode *>(node->input());
1058     auto pre_trans = create_pre_transpose(node);
1059     pre_trans->a(pred_node);
1060     node->input(pre_trans);
1061
1062     auto nchw_paddings = loco::must_cast<luci::CircleConst *>(node->paddings());
1063     const auto nhwc_paddings = create_NHWC_paddings(nchw_paddings);
1064     node->paddings(nhwc_paddings);
1065
1066     // Do shape inference for this node again.
1067     node->shape_status(luci::ShapeStatus::UNDEFINED);
1068
1069     auto post_trans = create_post_transpose(node);
1070     loco::replace(node).with(post_trans);
1071
1072     post_trans->a(node);
1073
1074     return true;
1075   }
1076
1077   bool visit(luci::CirclePadV2 *node)
1078   {
1079     if (!is_NCHW(node))
1080       return false;
1081
1082     const auto pred_node = loco::must_cast<luci::CircleNode *>(node->input());
1083     auto pre_trans = create_pre_transpose(node);
1084     pre_trans->a(pred_node);
1085     node->input(pre_trans);
1086
1087     auto nchw_paddings = loco::must_cast<luci::CircleConst *>(node->paddings());
1088     const auto nhwc_paddings = create_NHWC_paddings(nchw_paddings);
1089     node->paddings(nhwc_paddings);
1090
1091     // Do shape inference for this node again.
1092     node->shape_status(luci::ShapeStatus::UNDEFINED);
1093
1094     auto post_trans = create_post_transpose(node);
1095     loco::replace(node).with(post_trans);
1096
1097     post_trans->a(node);
1098
1099     return true;
1100   }
1101
1102   // TODO Reduce duplicate code with CircleMean
1103   bool visit(luci::CircleReduceMax *node)
1104   {
1105     auto input = loco::must_cast<luci::CircleNode *>(node->input());
1106     if (input->rank() != 4)
1107       return false;
1108
1109     auto rindices = dynamic_cast<luci::CircleConst *>(node->reduction_indices());
1110     if (not rindices)
1111       return false;
1112
1113     auto nhwc_rindices = create_NHWC_rindices(rindices);
1114     if (not nhwc_rindices)
1115       return false;
1116
1117     auto pre_trans = create_pre_transpose(node);
1118     pre_trans->a(input);
1119     node->input(pre_trans);
1120
1121     // Do shape inference for this node again.
1122     node->shape_status(luci::ShapeStatus::UNDEFINED);
1123
1124     node->reduction_indices(nhwc_rindices);
1125
1126     if (node->keep_dims())
1127     {
1128       auto post_trans = create_post_transpose(node);
1129       loco::replace(node).with(post_trans);
1130
1131       post_trans->a(node);
1132
1133       return true;
1134     }
1135
1136     // The below codes handle the cases where node->keep_dims() == false
1137     // 1D output never needs a transpose
1138     if (node->rank() <= 1)
1139       return true;
1140
1141     std::vector<bool> reduced_dims_nhwc(4, false);
1142     uint32_t num_reduced_indices = nhwc_rindices->size<loco::DataType::S32>();
1143
1144     for (uint32_t ri = 0; ri < num_reduced_indices; ++ri)
1145     {
1146       reduced_dims_nhwc[nhwc_rindices->at<loco::DataType::S32>(ri)] = true;
1147     }
1148
1149     // if channel dimension has been reduced, we don't need a transpose
1150     if (reduced_dims_nhwc[3])
1151       return true;
1152
1153     // likewise, if both space dimensions are reduced, no transpose is needed
1154     if (reduced_dims_nhwc[1] && reduced_dims_nhwc[2])
1155       return true;
1156
1157     std::vector<int32_t> post_trans_ind;
1158     // case 1: only N is reduced
1159     if (num_reduced_indices == 1 && reduced_dims_nhwc[0])
1160       post_trans_ind = {2, 0, 1};
1161
1162     // case 2: only H or W is reduced
1163     if (num_reduced_indices == 1 && (reduced_dims_nhwc[1] || reduced_dims_nhwc[2]))
1164       post_trans_ind = {0, 2, 1};
1165
1166     // case 3: N and either H or W are reduced
1167     if (num_reduced_indices == 2)
1168       post_trans_ind = {1, 0};
1169
1170     auto post_trans = create_Nd_transpose(node, post_trans_ind);
1171     loco::replace(node).with(post_trans);
1172
1173     post_trans->a(node);
1174
1175     return true;
1176   }
1177
1178   // TODO Reduce duplicate codes with CircleReduceMax
1179   bool visit(luci::CircleReduceMin *node)
1180   {
1181     auto input = loco::must_cast<luci::CircleNode *>(node->input());
1182     if (input->rank() != 4)
1183       return false;
1184
1185     auto rindices = dynamic_cast<luci::CircleConst *>(node->reduction_indices());
1186     if (not rindices)
1187       return false;
1188
1189     auto nhwc_rindices = create_NHWC_rindices(rindices);
1190     if (not nhwc_rindices)
1191       return false;
1192
1193     auto pre_trans = create_pre_transpose(node);
1194     pre_trans->a(input);
1195     node->input(pre_trans);
1196
1197     // Do shape inference for this node again.
1198     node->shape_status(luci::ShapeStatus::UNDEFINED);
1199
1200     node->reduction_indices(nhwc_rindices);
1201
1202     if (node->keep_dims())
1203     {
1204       auto post_trans = create_post_transpose(node);
1205       loco::replace(node).with(post_trans);
1206
1207       post_trans->a(node);
1208
1209       return true;
1210     }
1211
1212     // The below codes handle the cases where node->keep_dims() == false
1213     // 1D output never needs a transpose
1214     if (node->rank() <= 1)
1215       return true;
1216
1217     std::vector<bool> reduced_dims_nhwc(4, false);
1218     uint32_t num_reduced_indices = nhwc_rindices->size<loco::DataType::S32>();
1219
1220     for (uint32_t ri = 0; ri < num_reduced_indices; ++ri)
1221     {
1222       reduced_dims_nhwc[nhwc_rindices->at<loco::DataType::S32>(ri)] = true;
1223     }
1224
1225     // if channel dimension has been reduced, we don't need a transpose
1226     if (reduced_dims_nhwc[3])
1227       return true;
1228
1229     // likewise, if both space dimensions are reduced, no transpose is needed
1230     if (reduced_dims_nhwc[1] && reduced_dims_nhwc[2])
1231       return true;
1232
1233     std::vector<int32_t> post_trans_ind;
1234     // case 1: only N is reduced
1235     if (num_reduced_indices == 1 && reduced_dims_nhwc[0])
1236       post_trans_ind = {2, 0, 1};
1237
1238     // case 2: only H or W is reduced
1239     if (num_reduced_indices == 1 && (reduced_dims_nhwc[1] || reduced_dims_nhwc[2]))
1240       post_trans_ind = {0, 2, 1};
1241
1242     // case 3: N and either H or W are reduced
1243     if (num_reduced_indices == 2)
1244       post_trans_ind = {1, 0};
1245
1246     auto post_trans = create_Nd_transpose(node, post_trans_ind);
1247     loco::replace(node).with(post_trans);
1248
1249     post_trans->a(node);
1250
1251     return true;
1252   }
1253
1254   bool visit(luci::CircleRelu *node) { return convert_unary_features<luci::CircleRelu>(node); }
1255
1256   bool visit(luci::CircleRelu6 *node) { return convert_unary_features<luci::CircleRelu6>(node); }
1257
1258   bool visit(luci::CircleRsqrt *node) { return convert_unary_x<luci::CircleRsqrt>(node); }
1259
1260   bool visit(luci::CircleSplitV *node)
1261   {
1262     // Change split dimension
1263     auto axis = dynamic_cast<luci::CircleConst *>(node->split_dim());
1264     if (not axis)
1265       return false;
1266
1267     if (axis->dtype() != loco::DataType::S32)
1268       return false;
1269
1270     if (axis->size<loco::DataType::S32>() != 1)
1271       return false;
1272
1273     axis->at<loco::DataType::S32>(0) = nchw_axis_to_nhwc(axis->at<loco::DataType::S32>(0));
1274
1275     // Insert pre-transpose
1276     const auto pred_node = loco::must_cast<luci::CircleNode *>(node->input());
1277     auto pre_trans = create_pre_transpose(node);
1278     pre_trans->a(pred_node);
1279     node->input(pre_trans);
1280
1281     // Do shape inference for this node again.
1282     node->shape_status(luci::ShapeStatus::UNDEFINED);
1283
1284     // Insert post-transposes
1285     for (auto succ : loco::succs(node))
1286     {
1287       auto svo = loco::must_cast<luci::CircleSplitVOut *>(succ);
1288
1289       auto post_trans = create_post_transpose(svo);
1290       loco::replace(svo).with(post_trans);
1291       post_trans->a(svo);
1292     }
1293
1294     return true;
1295   }
1296
1297   bool visit(luci::CircleSquaredDifference *node)
1298   {
1299     // TODO support CircleConst input
1300     if (dynamic_cast<luci::CircleConst *>(node->x()) != nullptr)
1301       return false;
1302     if (dynamic_cast<luci::CircleConst *>(node->y()) != nullptr)
1303       return false;
1304
1305     auto input_x = loco::must_cast<luci::CircleNode *>(node->x());
1306     if (input_x->rank() != 4)
1307       return false;
1308     auto input_y = loco::must_cast<luci::CircleNode *>(node->y());
1309     if (input_y->rank() != 4)
1310       return false;
1311
1312     auto pre_trans_x = create_pre_transpose(node);
1313     pre_trans_x->a(input_x);
1314     node->x(pre_trans_x);
1315
1316     auto pre_trans_y = create_pre_transpose(node);
1317     pre_trans_y->a(input_y);
1318     node->y(pre_trans_y);
1319
1320     // Do shape inference for this node again.
1321     node->shape_status(luci::ShapeStatus::UNDEFINED);
1322
1323     auto post_trans = create_post_transpose(node);
1324     loco::replace(node).with(post_trans);
1325
1326     post_trans->a(node);
1327     return true;
1328   }
1329
1330   bool visit(luci::CircleSub *node)
1331   {
1332     luci::CircleNode *pred_node = nullptr;
1333     luci::CircleConst *subtract = nullptr;
1334
1335     auto const_x = dynamic_cast<luci::CircleConst *>(node->x());
1336     auto const_y = dynamic_cast<luci::CircleConst *>(node->y());
1337
1338     if (const_x != nullptr && const_y == nullptr)
1339     {
1340       // case of subtract - pred_node
1341       pred_node = loco::must_cast<luci::CircleNode *>(node->y());
1342       subtract = const_x;
1343
1344       if (!is_NCHW_with_const(node, pred_node, subtract))
1345         return false;
1346
1347       auto pre_trans = create_pre_transpose(node);
1348       pre_trans->a(pred_node);
1349
1350       if (subtract->rank() == 4)
1351       {
1352         auto nhwc_const = create_NHWC_from_NCHW(subtract);
1353         if (nhwc_const == nullptr)
1354           return false;
1355         node->x(nhwc_const);
1356       }
1357       node->y(pre_trans);
1358     }
1359     else if (const_x == nullptr && const_y != nullptr)
1360     {
1361       // case of pred_node - subtract
1362       pred_node = loco::must_cast<luci::CircleNode *>(node->x());
1363       subtract = const_y;
1364
1365       if (!is_NCHW_with_const(node, pred_node, subtract))
1366         return false;
1367
1368       auto pre_trans = create_pre_transpose(node);
1369       pre_trans->a(pred_node);
1370
1371       if (subtract->rank() == 4)
1372       {
1373         auto nhwc_const = create_NHWC_from_NCHW(subtract);
1374         if (nhwc_const == nullptr)
1375           return false;
1376         node->y(nhwc_const);
1377       }
1378
1379       node->x(pre_trans);
1380     }
1381     else if (const_x == nullptr && const_y == nullptr)
1382     {
1383       // Both inputs are not constant.
1384       // In this case, we cannot distinguish NCHW from NHWC,
1385       // so just insert Transpose Ops.
1386       // Only support for input rank 4.
1387       auto input_x = loco::must_cast<luci::CircleNode *>(node->x());
1388       if (input_x->rank() != 4)
1389         return false;
1390       auto input_y = loco::must_cast<luci::CircleNode *>(node->y());
1391       if (input_y->rank() != 4)
1392         return false;
1393
1394       auto pre_trans_x = create_pre_transpose(node);
1395       pre_trans_x->a(input_x);
1396       node->x(pre_trans_x);
1397
1398       auto pre_trans_y = create_pre_transpose(node);
1399       pre_trans_y->a(input_y);
1400       node->y(pre_trans_y);
1401     }
1402
1403     // Do shape inference for this node again.
1404     node->shape_status(luci::ShapeStatus::UNDEFINED);
1405
1406     auto post_trans = create_post_transpose(node);
1407     loco::replace(node).with(post_trans);
1408
1409     post_trans->a(node);
1410     return true;
1411   }
1412 };
1413
1414 } // namespace
1415
1416 namespace luci
1417 {
1418
1419 bool ConvertNCHWToNHWCPass::run(loco::Graph *g)
1420 {
1421   LOGGER(l);
1422   INFO(l) << "ConvertNCHWToNHWCPass Start" << std::endl;
1423
1424   // Annotate NHWC operators
1425   // NHWC operators are detected by pattern matching
1426   //
1427   // Pattern
1428   //    pre-Transose (or pre-Reshape) + [intermediate Ops] + post-Transpose (or post-Reshape)
1429   //
1430   // [intermediate Ops] are annotated as NHWC
1431   //
1432   // NOTE A single pre-Transpose/Reshape can have multiple post-Transpose/Reshape.
1433   // For example,
1434   // pre-Transpose --- [intermediate Ops] --- post-Transpose
1435   //                |
1436   //                +--[intermediate Ops] --- post-Transpose
1437   //
1438   // NOTE Intermediate Ops SHOULD NOT contain pre-Transpose/Reshape
1439   for (auto node : loco::postorder_traversal(loco::output_nodes(g)))
1440   {
1441     if (has_data_format(node))
1442       continue;
1443
1444     if (is_pre_transpose(node) || is_pre_reshape(node))
1445     {
1446       std::set<loco::Node *> intermediate;
1447
1448       // Variable to check intermediate Ops contain pre-Transpose/Reshape
1449       bool has_pre = false;
1450
1451       // Variable to check the pattern is closed with post-Transpose/Reshape
1452       bool is_closed = true;
1453
1454       // For recursive call of lambda
1455       std::function<void(loco::Node *)> collect_intermediate;
1456       collect_intermediate = [&](loco::Node *n) {
1457         for (auto succ : loco::succs(n))
1458         {
1459           // Skip unnecessary traversal
1460           if (intermediate.find(succ) != intermediate.end())
1461             continue;
1462
1463           // Exit condition
1464           if (is_post_transpose(succ) || is_post_reshape(succ))
1465             continue;
1466
1467           if (is_pre_transpose(succ) || is_pre_reshape(succ))
1468           {
1469             has_pre = true;
1470             break;
1471           }
1472
1473           if (is_output(succ))
1474           {
1475             is_closed = false;
1476             break;
1477           }
1478
1479           intermediate.emplace(succ);
1480
1481           collect_intermediate(succ);
1482         }
1483       };
1484
1485       collect_intermediate(node);
1486
1487       if (has_pre or not is_closed)
1488         continue;
1489
1490       for (auto inter : intermediate)
1491       {
1492         if (not has_data_format(inter))
1493           set_data_format(inter, DataFormat::NHWC);
1494       }
1495     }
1496   }
1497
1498   // Annotate NCHW operators
1499   for (auto node : loco::active_nodes(loco::output_nodes(g)))
1500   {
1501     auto circle_node = loco::must_cast<luci::CircleNode *>(node);
1502     switch (circle_node->opcode())
1503     {
1504       // List of supported Ops
1505       case luci::CircleOpcode::CIRCLEINPUT:
1506         if (!_preserve_input && !has_data_format(node))
1507         {
1508           set_data_format(node, DataFormat::NCHW);
1509         }
1510         break;
1511       case luci::CircleOpcode::CIRCLEOUTPUT:
1512         if (!_preserve_output && !has_data_format(node))
1513         {
1514           set_data_format(node, DataFormat::NCHW);
1515         }
1516         break;
1517       // SOFTMAX, LOG_SOFTMAX are not converted, because
1518       // tflite/circle assumes the last channel is always axis
1519       case luci::CircleOpcode::ADD:
1520       case luci::CircleOpcode::CONCATENATION:
1521       case luci::CircleOpcode::ELU:
1522       case luci::CircleOpcode::GELU:
1523       case luci::CircleOpcode::LEAKY_RELU:
1524       case luci::CircleOpcode::LOGISTIC:
1525       case luci::CircleOpcode::MAXIMUM:
1526       case luci::CircleOpcode::MEAN:
1527       case luci::CircleOpcode::MINIMUM:
1528       case luci::CircleOpcode::MUL:
1529       case luci::CircleOpcode::NEG:
1530       case luci::CircleOpcode::PAD:
1531       case luci::CircleOpcode::PADV2:
1532       case luci::CircleOpcode::REDUCE_MAX:
1533       case luci::CircleOpcode::REDUCE_MIN:
1534       case luci::CircleOpcode::RELU:
1535       case luci::CircleOpcode::RELU6:
1536       case luci::CircleOpcode::RSQRT:
1537       case luci::CircleOpcode::SPLIT_V:
1538       case luci::CircleOpcode::SQUARED_DIFFERENCE:
1539       case luci::CircleOpcode::SUB:
1540         if (!has_data_format(node))
1541         {
1542           set_data_format(node, DataFormat::NCHW);
1543         }
1544         break;
1545       default:
1546         break;
1547     }
1548   }
1549
1550   bool changed = false;
1551   for (auto node : loco::active_nodes(loco::output_nodes(g)))
1552   {
1553     if (!has_data_format(node))
1554     {
1555       // Unsupported Op
1556       continue;
1557     }
1558     else if (get_data_format(node) == DataFormat::NHWC)
1559     {
1560       // Already converted to NHWC
1561       continue;
1562     }
1563     else if (has_dynamic_shape(node))
1564     {
1565       // This pass only works for static-shaped node
1566       INFO(l) << "Skip the node with a dynamic shape." << std::endl;
1567       continue;
1568     }
1569     else
1570     {
1571       ConvertNCHWToNHWC converter;
1572       auto circle_node = loco::must_cast<luci::CircleNode *>(node);
1573       if (circle_node->rank() != 4)
1574       {
1575         // TODO replace the check above with the input rank check, and remove the condition below
1576         if (not dynamic_cast<luci::CircleMean *>(node) and
1577             not dynamic_cast<luci::CircleReduceMax *>(node) and
1578             not dynamic_cast<luci::CircleReduceMin *>(node))
1579           continue;
1580       }
1581
1582       if (circle_node->accept(&converter))
1583       {
1584         set_data_format(node, DataFormat::NHWC);
1585         changed = true;
1586       }
1587       else
1588       {
1589         continue;
1590       }
1591     }
1592   }
1593
1594   INFO(l) << "ConvertNCHWToNHWCPass End" << std::endl;
1595   return changed;
1596 }
1597
1598 } // namespace luci