99e1e293990fd3d104c9e1449aafa94ea3605b1c
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / ConvertNCHWToNHWCPass.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/ConvertNCHWToNHWCPass.h"
18 #include "CircleOptimizerUtils.h"
19
20 #include <luci/IR/CircleNodes.h>
21 #include <luci/IR/CircleNodeVisitor.h>
22 #include <luci/Profile/CircleNodeOrigin.h>
23 #include <luci/Service/Nodes/CircleConst.h>
24 #include <luci/Log.h>
25
26 #include <functional>
27
28 namespace
29 {
30
31 // Return true if from can be broadcasted to to
32 // to's shape is [N, C, H, W]
33 bool broadcastable(const luci::CircleConst *from, const luci::CircleNode *to)
34 {
35   assert(to->rank() == 4); // FIX_CALLER_UNLESS
36
37   const auto from_rank = from->rank();
38   if (from_rank > 4)
39     return false;
40
41   // Scalar is always broadcastable
42   if (from_rank == 0)
43     return true;
44
45   for (uint32_t i = 1; i <= from_rank; i++)
46   {
47     auto to_index = 4 - i;
48     auto from_index = from_rank - i;
49
50     if (from->dim(from_index).value() != to->dim(to_index).value() and
51         from->dim(from_index).value() != 1)
52       return false;
53   }
54
55   return true;
56 }
57
58 // Expand node to rank 4
59 // node should have rank less than or equal to 4
60 void expand_to_rank_4(luci::CircleConst *node)
61 {
62   auto original_rank = node->rank();
63
64   assert(original_rank <= 4); // FIX_CALLER_UNLESS
65
66   if (original_rank == 4)
67     return;
68
69   std::vector<uint32_t> original_shape;
70   for (uint32_t i = 0; i < original_rank; i++)
71   {
72     original_shape.emplace_back(node->dim(i).value());
73   }
74
75   node->rank(4);
76   for (uint32_t i = 0; i < (4 - original_rank); i++)
77     node->dim(i) = 1;
78
79   for (uint32_t i = 0; i < original_rank; i++)
80     node->dim(i + (4 - original_rank)) = original_shape.at(i);
81 }
82
83 bool is_output(const loco::Node *node)
84 {
85   auto cnode = loco::must_cast<const luci::CircleNode *>(node);
86   auto opcode = cnode->opcode();
87   if (opcode == luci::CircleOpcode::CIRCLEOUTPUT ||
88       opcode == luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE)
89     return true;
90
91   return false;
92 }
93
94 bool is_same_shape(const luci::CircleNode *node, const std::vector<loco::Dimension> &shape)
95 {
96   if (not node)
97     return false;
98
99   if (shape.size() != node->rank())
100     return false;
101
102   for (uint32_t i = 0; i < shape.size(); i++)
103   {
104     if (not(node->dim(i) == shape[i]))
105       return false;
106   }
107   return true;
108 }
109
110 enum class DataFormat
111 {
112   NCHW,
113   NHWC
114 };
115
116 /**
117  * @brief Set annotation for DataFormat (NCHW, NHWC)
118  *
119  * @note DataFormatAnnotation will live longer than this Pass (until the
120  *       annotated loco::Node is erased). So, do not use large data in the
121  *       annotation to avoid excessive memory usage.
122  */
123 class DataFormatAnnotation final : public loco::NodeAnnotation
124 {
125 public:
126   DataFormatAnnotation(const DataFormat &format) : _format{format}
127   {
128     // DO NOTHING
129   }
130
131 public:
132   const DataFormat &format(void) const { return _format; }
133
134 private:
135   DataFormat _format;
136 };
137
138 void set_data_format(loco::Node *node, const DataFormat &format)
139 {
140   node->annot(std::make_unique<DataFormatAnnotation>(format));
141 }
142
143 DataFormat get_data_format(loco::Node *node)
144 {
145   assert(node->annot<DataFormatAnnotation>() != nullptr);
146   return node->annot<DataFormatAnnotation>()->format();
147 }
148
149 bool has_data_format(loco::Node *node) { return node->annot<DataFormatAnnotation>() != nullptr; }
150
151 bool check_4d_transpose(loco::Node *node, const std::vector<int32_t> indices)
152 {
153   assert(indices.size() == 4);
154
155   auto trans = dynamic_cast<luci::CircleTranspose *>(node);
156   if (not trans)
157     return false;
158
159   if (not trans->perm())
160     return false;
161
162   auto perm = dynamic_cast<luci::CircleConst *>(trans->perm());
163   // Only const perm is supported
164   if (not perm)
165     return false;
166
167   if (perm->dtype() != loco::DataType::S32)
168     return false;
169
170   if (perm->size<loco::DataType::S32>() != 4)
171     return false;
172
173   for (uint32_t i = 0; i < 4; i++)
174   {
175     if (perm->at<loco::DataType::S32>(i) != indices[i])
176       return false;
177   }
178
179   return true;
180 }
181
182 luci::CircleTranspose *create_4d_transpose(luci::CircleNode *node,
183                                            const std::vector<int32_t> indices)
184 {
185   assert(indices.size() == 4);
186
187   auto name = node->name();
188   assert(name.length() > 0);
189
190   auto perm = node->graph()->nodes()->create<luci::CircleConst>();
191   perm->dtype(loco::DataType::S32);
192   perm->size<loco::DataType::S32>(4);
193   perm->rank(1);
194   perm->dim(0) = 4;
195   for (uint32_t i = 0; i < 4; i++)
196     perm->at<loco::DataType::S32>(i) = indices[i];
197   perm->shape_status(luci::ShapeStatus::VALID);
198
199   auto make_string = [](const std::vector<int32_t> &nums) {
200     std::string str;
201     for (auto num : nums)
202     {
203       if (str.length() > 0)
204         str += ".";
205       str += std::to_string(num);
206     }
207     return str;
208   };
209
210   auto str_indices = make_string(indices);
211
212   perm->name(name + "/Transpose_" + str_indices + "/perm");
213
214   auto trans = node->graph()->nodes()->create<luci::CircleTranspose>();
215   trans->perm(perm);
216   trans->name(name + "/Transpose_" + str_indices);
217   luci::add_origin(trans, luci::get_origin(node));
218
219   return trans;
220 }
221
222 luci::CircleTranspose *create_Nd_transpose(luci::CircleNode *node,
223                                            const std::vector<int32_t> indices)
224 {
225   auto name = node->name();
226   assert(name.length() > 0);
227
228   auto perm = node->graph()->nodes()->create<luci::CircleConst>();
229   perm->dtype(loco::DataType::S32);
230   perm->size<loco::DataType::S32>(indices.size());
231   perm->rank(1);
232   perm->dim(0) = indices.size();
233   for (uint32_t i = 0; i < indices.size(); i++)
234     perm->at<loco::DataType::S32>(i) = indices[i];
235   perm->shape_status(luci::ShapeStatus::VALID);
236
237   auto make_string = [](const std::vector<int32_t> &nums) {
238     std::string str;
239     for (auto num : nums)
240     {
241       if (str.length() > 0)
242         str += ".";
243       str += std::to_string(num);
244     }
245     return str;
246   };
247
248   auto str_indices = make_string(indices);
249
250   perm->name(name + "/Transpose_" + str_indices + "/perm");
251
252   auto trans = node->graph()->nodes()->create<luci::CircleTranspose>();
253   trans->perm(perm);
254   trans->name(name + "/Transpose_" + str_indices);
255   luci::add_origin(trans, luci::get_origin(node));
256
257   return trans;
258 }
259
260 int32_t nchw_axis_to_nhwc(int32_t axis)
261 {
262   uint32_t pos_axis = axis >= 0 ? static_cast<uint32_t>(axis) : static_cast<uint32_t>(axis + 4);
263   static const uint32_t to_nhwc[4] = {0, 3, 1, 2};
264   if (pos_axis > 3)
265     throw std::runtime_error("Concat axis must be in range [-4, 4)");
266   return to_nhwc[pos_axis];
267 }
268
269 luci::CircleTranspose *create_post_transpose(luci::CircleNode *node)
270 {
271   return create_4d_transpose(node, {0, 3, 1, 2});
272 }
273
274 luci::CircleTranspose *create_pre_transpose(luci::CircleNode *node)
275 {
276   return create_4d_transpose(node, {0, 2, 3, 1});
277 }
278
279 bool check_4d_reshape(loco::Node *node, const std::vector<int32_t> indices)
280 {
281   assert(indices.size() == 4); // FIX_CALLER_UNLESS
282
283   auto reshape = dynamic_cast<luci::CircleReshape *>(node);
284   if (not reshape)
285     return false;
286
287   if (reshape->rank() != 4)
288     return false;
289
290   auto input = loco::must_cast<luci::CircleNode *>(reshape->tensor());
291   if (input->shape_status() != luci::ShapeStatus::VALID)
292     return false;
293
294   if (input->rank() != 4)
295     return false;
296
297   if (reshape->shape_status() != luci::ShapeStatus::VALID)
298     return false;
299
300   if (!(input->dim(0) == reshape->dim(indices[0])) ||
301       !(input->dim(1) == reshape->dim(indices[1])) ||
302       !(input->dim(2) == reshape->dim(indices[2])) || !(input->dim(3) == reshape->dim(indices[3])))
303     return false;
304
305   return true;
306 }
307
308 // Check if Reshape that converts NCHW -> NHWC
309 bool is_pre_reshape(loco::Node *node) { return check_4d_reshape(node, {0, 3, 1, 2}); }
310
311 // Check if Reshape that converts NHWC -> NCHW
312 bool is_post_reshape(loco::Node *node) { return check_4d_reshape(node, {0, 2, 3, 1}); }
313
314 bool is_post_transpose(loco::Node *node) { return check_4d_transpose(node, {0, 3, 1, 2}); }
315
316 bool is_pre_transpose(loco::Node *node) { return check_4d_transpose(node, {0, 2, 3, 1}); }
317
318 uint32_t cal_offset(const loco::TensorShape &dimension, const uint32_t *indices)
319 {
320   return indices[0] * dimension.dim(1).value() * dimension.dim(2).value() *
321            dimension.dim(3).value() +
322          indices[1] * dimension.dim(2).value() * dimension.dim(3).value() +
323          indices[2] * dimension.dim(3).value() + indices[3];
324 }
325
326 luci::CircleConst *create_NHWC_paddings(luci::CircleConst *paddings)
327 {
328   // paddings shape is (4,2) (it was checked by is_NCHW)
329   assert(paddings != nullptr);
330   assert(paddings->rank() == 2);
331   assert(paddings->dim(0).value() == 4);
332   assert(paddings->dim(1).value() == 2);
333
334   // paddings for idx 0~3 are 0 (checked by is_NCHW)
335   assert(paddings->at<loco::DataType::S32>(0) == 0);
336   assert(paddings->at<loco::DataType::S32>(1) == 0);
337   assert(paddings->at<loco::DataType::S32>(2) == 0);
338   assert(paddings->at<loco::DataType::S32>(3) == 0);
339
340   auto name = paddings->name();
341   assert(name.length() > 0);
342
343   auto nhwc_paddings = paddings->graph()->nodes()->create<luci::CircleConst>();
344   nhwc_paddings->dtype(loco::DataType::S32);
345   nhwc_paddings->shape({4, 2});
346   nhwc_paddings->shape_status(luci::ShapeStatus::VALID);
347   nhwc_paddings->size<loco::DataType::S32>(4 * 2);
348   nhwc_paddings->name(name + "_NHWC");
349
350   for (uint32_t dim = 0; dim < 4; dim++)
351   {
352     for (uint32_t i = 0; i < 2; i++)
353     {
354       int32_t data = 0;
355
356       if (dim == 1)
357       {
358         // get third dimension (H in NCHW)
359         data = paddings->at<loco::DataType::S32>(2 * 2 + i);
360       }
361       else if (dim == 2)
362       {
363         // get fourth dimension (W in NCHW)
364         data = paddings->at<loco::DataType::S32>(3 * 2 + i);
365       }
366
367       nhwc_paddings->at<loco::DataType::S32>(dim * 2 + i) = data;
368     }
369   }
370   return nhwc_paddings;
371 }
372
373 luci::CircleConst *create_NHWC_rindices(luci::CircleConst *rindices)
374 {
375   assert(rindices != nullptr); // FIX_CALLER_UNLESS
376
377   if (rindices->dtype() != loco::DataType::S32)
378     return nullptr;
379
380   auto nhwc_rindices = luci::clone(rindices);
381   auto name = rindices->name();
382   assert(name.length() > 0); // FIX_CALLER_UNLESS
383   nhwc_rindices->name(name + "_NHWC");
384
385   auto size = nhwc_rindices->size<loco::DataType::S32>();
386   for (uint32_t i = 0; i < size; i++)
387   {
388     nhwc_rindices->at<loco::DataType::S32>(i) =
389       nchw_axis_to_nhwc(rindices->at<loco::DataType::S32>(i));
390   }
391
392   return nhwc_rindices;
393 }
394
395 luci::CircleConst *create_NHWC_from_NCHW(luci::CircleConst *constant)
396 {
397   LOGGER(l);
398   assert(constant->rank() == 4);
399
400   // TODO: Support non-float types
401   if (constant->dtype() != loco::DataType::FLOAT32)
402   {
403     INFO(l) << "Non-float type constant: " << constant->name() << std::endl;
404     return nullptr;
405   }
406
407   loco::TensorShape nchw_dimension{constant->dim(0), constant->dim(1), constant->dim(2),
408                                    constant->dim(3)};
409   loco::TensorShape nhwc_dimension{constant->dim(0), constant->dim(2), constant->dim(3),
410                                    constant->dim(1)};
411
412   auto name = constant->name();
413   assert(name.length() > 0);
414
415   auto nhwc_const = constant->graph()->nodes()->create<luci::CircleConst>();
416   nhwc_const->dtype(constant->dtype());
417   nhwc_const->rank(4);
418   nhwc_const->dim(0).set(constant->dim(0).value());
419   nhwc_const->dim(1).set(constant->dim(2).value());
420   nhwc_const->dim(2).set(constant->dim(3).value());
421   nhwc_const->dim(3).set(constant->dim(1).value());
422   nhwc_const->shape_status(luci::ShapeStatus::VALID);
423   nhwc_const->size<loco::DataType::FLOAT32>(constant->size<loco::DataType::FLOAT32>());
424   nhwc_const->name(name + "_NHWC");
425
426   for (uint32_t n = 0; n < nchw_dimension.dim(0).value(); n++)
427   {
428     for (uint32_t c = 0; c < nchw_dimension.dim(1).value(); c++)
429     {
430       for (uint32_t h = 0; h < nchw_dimension.dim(2).value(); h++)
431       {
432         for (uint32_t w = 0; w < nchw_dimension.dim(3).value(); w++)
433         {
434           uint32_t nchw_indices[4] = {n, c, h, w};
435           uint32_t nhwc_indices[4] = {n, h, w, c};
436           auto data =
437             constant->at<loco::DataType::FLOAT32>(cal_offset(nchw_dimension, nchw_indices));
438           nhwc_const->at<loco::DataType::FLOAT32>(cal_offset(nhwc_dimension, nhwc_indices)) = data;
439         }
440       }
441     }
442   }
443   return nhwc_const;
444 }
445
446 // NOTE Following conditions can be extended later
447 //
448 // Find PAD with an NCHW pattern described below
449 //   - Paddings shape : [4, 2]
450 //   - Paddings value : [[0, 0], [0, 0], [h_t, h_b], [w_t, w_b]]]
451 bool is_NCHW(const luci::CirclePad *node)
452 {
453   const auto paddings = dynamic_cast<luci::CircleConst *>(node->paddings());
454   // Non-const paddings is not supported
455   if (paddings == nullptr)
456     return false;
457
458   if (paddings->rank() != 2)
459     return false;
460
461   if (paddings->dim(0).value() != 4 || paddings->dim(1).value() != 2)
462     return false;
463
464   // Only check the first two dimensions
465   for (uint32_t dim = 0; dim < 2; dim++)
466   {
467     for (uint32_t i = 0; i < 2; i++)
468     {
469       auto data = paddings->at<loco::DataType::S32>(dim * 2 + i);
470       if (data != 0)
471         return false;
472     }
473   }
474
475   return true;
476 }
477
478 // NOTE Copied from is_NCHW(CirclePad)
479 bool is_NCHW(const luci::CirclePadV2 *node)
480 {
481   const auto paddings = dynamic_cast<luci::CircleConst *>(node->paddings());
482   // Non-const paddings is not supported
483   if (paddings == nullptr)
484     return false;
485
486   if (paddings->rank() != 2)
487     return false;
488
489   if (paddings->dim(0).value() != 4 || paddings->dim(1).value() != 2)
490     return false;
491
492   // Only check the first two dimensions
493   for (uint32_t dim = 0; dim < 2; dim++)
494   {
495     for (uint32_t i = 0; i < 2; i++)
496     {
497       auto data = paddings->at<loco::DataType::S32>(dim * 2 + i);
498       if (data != 0)
499         return false;
500     }
501   }
502
503   return true;
504 }
505
506 bool is_const(const loco::Node *node)
507 {
508   if (not dynamic_cast<const luci::CircleConst *>(node))
509     return false;
510
511   return true;
512 }
513
514 bool is_scalar_const(const loco::Node *node)
515 {
516   auto const_node = dynamic_cast<const luci::CircleConst *>(node);
517   if (not const_node)
518     return false;
519
520   const auto const_rank = const_node->rank();
521   // shape of scalar
522   // 1. rank = 0
523   // 2. rank = 1, dimension = 1
524   if (const_rank == 0)
525     return true;
526
527   if (const_rank == 1 && const_node->dim(0).value() == 1)
528     return true;
529
530   return false;
531 }
532
533 // NOTE Following conditions can be extended later
534 //
535 // Find MUL with an NCHW pattern described below
536 //   - Input (non-constant) shape : [N, C, H, W]
537 //   - Input (constant) shape : broadcastable to [N, C, H, W]
538 //   - Output shape : [N, C, H, W]
539 bool is_NCHW_with_const(const luci::CircleMul *node, luci::CircleNode *&pred_node,
540                         luci::CircleConst *&multiplier)
541 {
542   auto x = dynamic_cast<luci::CircleConst *>(node->x());
543   auto y = dynamic_cast<luci::CircleConst *>(node->y());
544
545   if (x != nullptr && y == nullptr)
546   {
547     pred_node = loco::must_cast<luci::CircleNode *>(node->y());
548     multiplier = x;
549   }
550   else if (x == nullptr && y != nullptr)
551   {
552     pred_node = loco::must_cast<luci::CircleNode *>(node->x());
553     multiplier = y;
554   }
555   else
556   {
557     // Ignore if MUL does not have a multiplier input.
558     return false;
559   }
560
561   if (pred_node->rank() != 4)
562     return false;
563
564   if (not broadcastable(multiplier, node))
565     return false;
566
567   expand_to_rank_4(multiplier);
568
569   return true;
570 }
571
572 // We assume ADD with const input is NCHW if,
573 // Input shape: (N, C, H, W)
574 // Output shape: (N, C, H, W)
575 // 1. Const shape is (1, C, 1, 1), (N, C, H, W) or a scalar (1)
576 // 2. Input, Output, Const have the same C.
577 bool is_NCHW_with_const(const luci::CircleAdd *node, luci::CircleNode *&pred_node,
578                         luci::CircleConst *&beta)
579 {
580   auto x = dynamic_cast<luci::CircleConst *>(node->x());
581   auto y = dynamic_cast<luci::CircleConst *>(node->y());
582
583   if (x != nullptr && y == nullptr)
584   {
585     pred_node = loco::must_cast<luci::CircleNode *>(node->y());
586     beta = x;
587   }
588   else if (x == nullptr && y != nullptr)
589   {
590     pred_node = loco::must_cast<luci::CircleNode *>(node->x());
591     beta = y;
592   }
593   else
594   {
595     // Ignore if ADD does not have a constant input.
596     return false;
597   }
598
599   if (pred_node->rank() != 4)
600     return false;
601
602   if (not broadcastable(beta, node))
603     return false;
604
605   expand_to_rank_4(beta);
606
607   return true;
608 }
609
610 // We assume SUB with const input is NCHW if,
611 // Input shape: (N, C, H, W)
612 // Output shape: (N, C, H, W)
613 // 1. Const shape is (1, C, 1, 1), (N, C, H, W) or a scalar (1)
614 // 2. Input, Output, Const have the same C.
615 bool is_NCHW_with_const(const luci::CircleSub *node, const luci::CircleNode *pred_node,
616                         const luci::CircleConst *subtract)
617 {
618   assert(pred_node != nullptr);
619   assert(subtract != nullptr);
620
621   if (pred_node->rank() != 4)
622     return false;
623
624   const auto const_rank = subtract->rank();
625   // Support Rank 4 or scalar (rank 0 or 1)
626   if (const_rank != 4 && const_rank != 0 && const_rank != 1)
627     return false;
628
629   const auto input_cdim = pred_node->dim(1);
630   const auto output_cdim = node->dim(1);
631
632   if (const_rank == 4)
633   {
634     bool supported_shape = false;
635
636     // Check subtract is (1, C, 1, 1)
637     if (is_same_shape(subtract, {1, node->dim(1), 1, 1}))
638       supported_shape = true;
639
640     // Check subtract is (N, C, H, W)
641     if (is_same_shape(subtract, {node->dim(0), node->dim(1), node->dim(2), node->dim(3)}))
642       supported_shape = true;
643
644     return supported_shape;
645   }
646   if (input_cdim == output_cdim)
647     return true;
648   else
649     return false;
650 }
651
652 template <class T> bool convert_unary_features(T *node)
653 {
654   const auto pred_node = loco::must_cast<luci::CircleNode *>(node->features());
655   auto pre_trans = create_pre_transpose(node);
656   pre_trans->a(pred_node);
657   node->features(pre_trans);
658
659   // Do shape inference for this node again.
660   node->shape_status(luci::ShapeStatus::UNDEFINED);
661
662   auto post_trans = create_post_transpose(node);
663   loco::replace(node).with(post_trans);
664
665   post_trans->a(node);
666
667   return true;
668 }
669
670 template <class T> bool convert_unary_x(T *node)
671 {
672   const auto pred_node = loco::must_cast<luci::CircleNode *>(node->x());
673   auto pre_trans = create_pre_transpose(node);
674   pre_trans->a(pred_node);
675   node->x(pre_trans);
676
677   // Do shape inference for this node again.
678   node->shape_status(luci::ShapeStatus::UNDEFINED);
679
680   auto post_trans = create_post_transpose(node);
681   loco::replace(node).with(post_trans);
682
683   post_trans->a(node);
684
685   return true;
686 }
687
688 template <class T> bool convert_unary_logits(T *node)
689 {
690   const auto pred_node = loco::must_cast<luci::CircleNode *>(node->logits());
691   auto pre_trans = create_pre_transpose(node);
692   pre_trans->a(pred_node);
693   node->logits(pre_trans);
694
695   // Do shape inference for this node again.
696   node->shape_status(luci::ShapeStatus::UNDEFINED);
697
698   auto post_trans = create_post_transpose(node);
699   loco::replace(node).with(post_trans);
700
701   post_trans->a(node);
702
703   return true;
704 }
705
706 class ConvertNCHWToNHWC final : public luci::CircleNodeMutableVisitor<bool>
707 {
708   // Default
709   bool visit(luci::CircleNode *node)
710   {
711     throw std::runtime_error(node->name() + " is an unsupported operator.");
712   }
713
714   bool visit(luci::CircleInput *node)
715   {
716     const auto n = node->dim(0);
717     const auto c = node->dim(1);
718     const auto h = node->dim(2);
719     const auto w = node->dim(3);
720
721     node->dim(1) = h;
722     node->dim(2) = w;
723     node->dim(3) = c;
724
725     // Do shape inference for this node again.
726     node->shape_status(luci::ShapeStatus::UNDEFINED);
727
728     // Insert post-tranpose
729     auto post_trans = create_post_transpose(node);
730     loco::replace(node).with(post_trans);
731
732     post_trans->a(node);
733
734     // Update graph input
735     auto graph_inputs = node->graph()->inputs();
736     auto graph_input = graph_inputs->at(node->index());
737     graph_input->shape({n, h, w, c});
738
739     return true;
740   }
741
742   bool visit(luci::CircleOutput *node)
743   {
744     // Insert pre-transpose
745     auto pre_trans = create_pre_transpose(node);
746     pre_trans->a(node->from());
747
748     node->from(pre_trans);
749
750     // Do shape inference for this node again.
751     node->shape_status(luci::ShapeStatus::UNDEFINED);
752
753     // Update graph output
754     const auto n = node->dim(0).value();
755     const auto c = node->dim(1).value();
756     const auto h = node->dim(2).value();
757     const auto w = node->dim(3).value();
758
759     auto graph_outputs = node->graph()->outputs();
760     auto graph_output = graph_outputs->at(node->index());
761     graph_output->shape({n, h, w, c});
762
763     return true;
764   }
765
766   bool visit(luci::CircleAdd *node)
767   {
768     luci::CircleNode *pred_node = nullptr;
769     luci::CircleConst *beta = nullptr;
770
771     if (is_NCHW_with_const(node, pred_node, beta))
772     {
773       assert(beta->rank() == 4); // FIX is_NCHW_with_const unless
774       auto nhwc_const = create_NHWC_from_NCHW(beta);
775       if (nhwc_const == nullptr)
776         return false;
777       node->y(nhwc_const);
778
779       auto pre_trans = create_pre_transpose(node);
780       pre_trans->a(pred_node);
781       node->x(pre_trans);
782     }
783     else if (beta == nullptr)
784     {
785       // Both inputs are not constant.
786       // In this case, we cannot distinguish NCHW from NHWC,
787       // so just insert Transpose Ops.
788       auto pre_trans_x = create_pre_transpose(node);
789       pre_trans_x->a(node->x());
790       node->x(pre_trans_x);
791
792       auto pre_trans_y = create_pre_transpose(node);
793       pre_trans_y->a(node->y());
794       node->y(pre_trans_y);
795     }
796     else
797     {
798       return false;
799     }
800
801     // Do shape inference for this node again.
802     node->shape_status(luci::ShapeStatus::UNDEFINED);
803
804     auto post_trans = create_post_transpose(node);
805     loco::replace(node).with(post_trans);
806
807     post_trans->a(node);
808     return true;
809   }
810
811   bool visit(luci::CircleConcatenation *node)
812   {
813     const auto num_values = node->numValues();
814     for (uint32_t i = 0; i < num_values; i++)
815     {
816       auto pred_node = loco::must_cast<luci::CircleNode *>(node->values(i));
817       auto pre_trans = create_pre_transpose(node);
818       pre_trans->a(pred_node);
819       node->values(i, pre_trans);
820     }
821
822     // Do shape inference for this node again.
823     node->shape_status(luci::ShapeStatus::UNDEFINED);
824
825     node->axis(nchw_axis_to_nhwc(node->axis()));
826
827     auto post_trans = create_post_transpose(node);
828     loco::replace(node).with(post_trans);
829
830     post_trans->a(node);
831
832     return true;
833   }
834
835   bool visit(luci::CircleElu *node) { return convert_unary_features<luci::CircleElu>(node); }
836
837   bool visit(luci::CircleLeakyRelu *node)
838   {
839     return convert_unary_features<luci::CircleLeakyRelu>(node);
840   }
841
842   bool visit(luci::CircleLogistic *node) { return convert_unary_x<luci::CircleLogistic>(node); }
843
844   bool visit(luci::CircleMaximum *node)
845   {
846     if ((not is_const(node->x())) and is_scalar_const(node->y()))
847     {
848       auto pre_trans = create_pre_transpose(node);
849       pre_trans->a(node->x());
850       node->x(pre_trans);
851     }
852     else if (is_scalar_const(node->x()) and (not is_const(node->y())))
853     {
854       auto pre_trans = create_pre_transpose(node);
855       pre_trans->a(node->y());
856       node->y(pre_trans);
857     }
858     else if ((not is_const(node->x())) and (not is_const(node->y())))
859     {
860       auto pre_trans_x = create_pre_transpose(node);
861       pre_trans_x->a(node->x());
862       node->x(pre_trans_x);
863
864       auto pre_trans_y = create_pre_transpose(node);
865       pre_trans_y->a(node->y());
866       node->y(pre_trans_y);
867     }
868     else
869     {
870       // TODO support other cases
871       return false;
872     }
873
874     // Do shape inference for this node again.
875     node->shape_status(luci::ShapeStatus::UNDEFINED);
876
877     auto post_trans = create_post_transpose(node);
878     loco::replace(node).with(post_trans);
879
880     post_trans->a(node);
881     return true;
882   }
883
884   bool visit(luci::CircleMean *node)
885   {
886     auto input = loco::must_cast<luci::CircleNode *>(node->input());
887     if (input->rank() != 4)
888       return false;
889
890     auto rindices = dynamic_cast<luci::CircleConst *>(node->reduction_indices());
891     if (not rindices)
892       return false;
893
894     auto nhwc_rindices = create_NHWC_rindices(rindices);
895     if (not nhwc_rindices)
896       return false;
897
898     auto pre_trans = create_pre_transpose(node);
899     pre_trans->a(input);
900     node->input(pre_trans);
901
902     // Do shape inference for this node again.
903     node->shape_status(luci::ShapeStatus::UNDEFINED);
904
905     node->reduction_indices(nhwc_rindices);
906
907     if (node->keep_dims())
908     {
909       auto post_trans = create_post_transpose(node);
910       loco::replace(node).with(post_trans);
911
912       post_trans->a(node);
913
914       return true;
915     }
916
917     // node->keep_dims() == false
918     // 1D output never needs a transpose
919     if (node->rank() <= 1)
920       return true;
921
922     std::vector<bool> reduced_dims_nhwc(4, false);
923     uint32_t num_reduced_indices = nhwc_rindices->size<loco::DataType::S32>();
924
925     for (uint32_t ri = 0; ri < num_reduced_indices; ++ri)
926     {
927       reduced_dims_nhwc[nhwc_rindices->at<loco::DataType::S32>(ri)] = true;
928     }
929
930     // if channel dimension has been reduced, we don't need a transpose
931     if (reduced_dims_nhwc[3])
932       return true;
933
934     // likewise, if both space dimensions are reduced, no transpose is needed
935     if (reduced_dims_nhwc[1] && reduced_dims_nhwc[2])
936       return true;
937
938     std::vector<int32_t> post_trans_ind;
939     // case 1: only N is reduced
940     if (num_reduced_indices == 1 && reduced_dims_nhwc[0])
941       post_trans_ind = {2, 0, 1};
942
943     // case 2: only H or W is reduced
944     if (num_reduced_indices == 1 && (reduced_dims_nhwc[1] || reduced_dims_nhwc[2]))
945       post_trans_ind = {0, 2, 1};
946
947     // case 3: N and either H or W are reduced
948     if (num_reduced_indices == 2)
949       post_trans_ind = {1, 0};
950
951     auto post_trans = create_Nd_transpose(node, post_trans_ind);
952     loco::replace(node).with(post_trans);
953
954     post_trans->a(node);
955
956     return true;
957   }
958
959   bool visit(luci::CircleMinimum *node)
960   {
961     if ((not is_const(node->x())) and is_scalar_const(node->y()))
962     {
963       auto pre_trans = create_pre_transpose(node);
964       pre_trans->a(node->x());
965       node->x(pre_trans);
966     }
967     else if (is_scalar_const(node->x()) and (not is_const(node->y())))
968     {
969       auto pre_trans = create_pre_transpose(node);
970       pre_trans->a(node->y());
971       node->y(pre_trans);
972     }
973     else
974     {
975       // TODO support other cases
976       return false;
977     }
978
979     // Do shape inference for this node again.
980     node->shape_status(luci::ShapeStatus::UNDEFINED);
981
982     auto post_trans = create_post_transpose(node);
983     loco::replace(node).with(post_trans);
984
985     post_trans->a(node);
986     return true;
987   }
988
989   bool visit(luci::CircleMul *node)
990   {
991     LOGGER(l);
992
993     luci::CircleNode *pred_node = nullptr;
994     luci::CircleConst *multiplier = nullptr;
995
996     if (is_NCHW_with_const(node, pred_node, multiplier))
997     {
998       assert(multiplier->rank() == 4); // FIX is_NCHW_with_const unless
999       auto nhwc_const = create_NHWC_from_NCHW(multiplier);
1000       if (nhwc_const == nullptr)
1001         return false;
1002       node->y(nhwc_const);
1003
1004       auto pre_trans = create_pre_transpose(node);
1005       pre_trans->a(pred_node);
1006       node->x(pre_trans);
1007     }
1008     else if (multiplier == nullptr)
1009     {
1010       // Only support for input rank 4
1011       auto input_x = loco::must_cast<luci::CircleNode *>(node->x());
1012       if (input_x->rank() != 4)
1013         return false;
1014       auto input_y = loco::must_cast<luci::CircleNode *>(node->y());
1015       if (input_y->rank() != 4)
1016         return false;
1017
1018       auto pre_trans_x = create_pre_transpose(node);
1019       pre_trans_x->a(input_x);
1020       node->x(pre_trans_x);
1021
1022       auto pre_trans_y = create_pre_transpose(node);
1023       pre_trans_y->a(input_y);
1024       node->y(pre_trans_y);
1025     }
1026     else
1027     {
1028       return false;
1029     }
1030
1031     // Do shape inference for this node again.
1032     node->shape_status(luci::ShapeStatus::UNDEFINED);
1033
1034     auto post_trans = create_post_transpose(node);
1035     loco::replace(node).with(post_trans);
1036
1037     post_trans->a(node);
1038     return true;
1039   }
1040
1041   bool visit(luci::CircleNeg *node) { return convert_unary_x<luci::CircleNeg>(node); }
1042
1043   bool visit(luci::CirclePad *node)
1044   {
1045     if (!is_NCHW(node))
1046       return false;
1047
1048     const auto pred_node = loco::must_cast<luci::CircleNode *>(node->input());
1049     auto pre_trans = create_pre_transpose(node);
1050     pre_trans->a(pred_node);
1051     node->input(pre_trans);
1052
1053     auto nchw_paddings = loco::must_cast<luci::CircleConst *>(node->paddings());
1054     const auto nhwc_paddings = create_NHWC_paddings(nchw_paddings);
1055     node->paddings(nhwc_paddings);
1056
1057     // Do shape inference for this node again.
1058     node->shape_status(luci::ShapeStatus::UNDEFINED);
1059
1060     auto post_trans = create_post_transpose(node);
1061     loco::replace(node).with(post_trans);
1062
1063     post_trans->a(node);
1064
1065     return true;
1066   }
1067
1068   bool visit(luci::CirclePadV2 *node)
1069   {
1070     if (!is_NCHW(node))
1071       return false;
1072
1073     const auto pred_node = loco::must_cast<luci::CircleNode *>(node->input());
1074     auto pre_trans = create_pre_transpose(node);
1075     pre_trans->a(pred_node);
1076     node->input(pre_trans);
1077
1078     auto nchw_paddings = loco::must_cast<luci::CircleConst *>(node->paddings());
1079     const auto nhwc_paddings = create_NHWC_paddings(nchw_paddings);
1080     node->paddings(nhwc_paddings);
1081
1082     // Do shape inference for this node again.
1083     node->shape_status(luci::ShapeStatus::UNDEFINED);
1084
1085     auto post_trans = create_post_transpose(node);
1086     loco::replace(node).with(post_trans);
1087
1088     post_trans->a(node);
1089
1090     return true;
1091   }
1092
1093   // TODO Reduce duplicate code with CircleMean
1094   bool visit(luci::CircleReduceMax *node)
1095   {
1096     auto input = loco::must_cast<luci::CircleNode *>(node->input());
1097     if (input->rank() != 4)
1098       return false;
1099
1100     auto rindices = dynamic_cast<luci::CircleConst *>(node->reduction_indices());
1101     if (not rindices)
1102       return false;
1103
1104     auto nhwc_rindices = create_NHWC_rindices(rindices);
1105     if (not nhwc_rindices)
1106       return false;
1107
1108     auto pre_trans = create_pre_transpose(node);
1109     pre_trans->a(input);
1110     node->input(pre_trans);
1111
1112     // Do shape inference for this node again.
1113     node->shape_status(luci::ShapeStatus::UNDEFINED);
1114
1115     node->reduction_indices(nhwc_rindices);
1116
1117     if (node->keep_dims())
1118     {
1119       auto post_trans = create_post_transpose(node);
1120       loco::replace(node).with(post_trans);
1121
1122       post_trans->a(node);
1123
1124       return true;
1125     }
1126
1127     // The below codes handle the cases where node->keep_dims() == false
1128     // 1D output never needs a transpose
1129     if (node->rank() <= 1)
1130       return true;
1131
1132     std::vector<bool> reduced_dims_nhwc(4, false);
1133     uint32_t num_reduced_indices = nhwc_rindices->size<loco::DataType::S32>();
1134
1135     for (uint32_t ri = 0; ri < num_reduced_indices; ++ri)
1136     {
1137       reduced_dims_nhwc[nhwc_rindices->at<loco::DataType::S32>(ri)] = true;
1138     }
1139
1140     // if channel dimension has been reduced, we don't need a transpose
1141     if (reduced_dims_nhwc[3])
1142       return true;
1143
1144     // likewise, if both space dimensions are reduced, no transpose is needed
1145     if (reduced_dims_nhwc[1] && reduced_dims_nhwc[2])
1146       return true;
1147
1148     std::vector<int32_t> post_trans_ind;
1149     // case 1: only N is reduced
1150     if (num_reduced_indices == 1 && reduced_dims_nhwc[0])
1151       post_trans_ind = {2, 0, 1};
1152
1153     // case 2: only H or W is reduced
1154     if (num_reduced_indices == 1 && (reduced_dims_nhwc[1] || reduced_dims_nhwc[2]))
1155       post_trans_ind = {0, 2, 1};
1156
1157     // case 3: N and either H or W are reduced
1158     if (num_reduced_indices == 2)
1159       post_trans_ind = {1, 0};
1160
1161     auto post_trans = create_Nd_transpose(node, post_trans_ind);
1162     loco::replace(node).with(post_trans);
1163
1164     post_trans->a(node);
1165
1166     return true;
1167   }
1168
1169   // TODO Reduce duplicate codes with CircleReduceMax
1170   bool visit(luci::CircleReduceMin *node)
1171   {
1172     auto input = loco::must_cast<luci::CircleNode *>(node->input());
1173     if (input->rank() != 4)
1174       return false;
1175
1176     auto rindices = dynamic_cast<luci::CircleConst *>(node->reduction_indices());
1177     if (not rindices)
1178       return false;
1179
1180     auto nhwc_rindices = create_NHWC_rindices(rindices);
1181     if (not nhwc_rindices)
1182       return false;
1183
1184     auto pre_trans = create_pre_transpose(node);
1185     pre_trans->a(input);
1186     node->input(pre_trans);
1187
1188     // Do shape inference for this node again.
1189     node->shape_status(luci::ShapeStatus::UNDEFINED);
1190
1191     node->reduction_indices(nhwc_rindices);
1192
1193     if (node->keep_dims())
1194     {
1195       auto post_trans = create_post_transpose(node);
1196       loco::replace(node).with(post_trans);
1197
1198       post_trans->a(node);
1199
1200       return true;
1201     }
1202
1203     // The below codes handle the cases where node->keep_dims() == false
1204     // 1D output never needs a transpose
1205     if (node->rank() <= 1)
1206       return true;
1207
1208     std::vector<bool> reduced_dims_nhwc(4, false);
1209     uint32_t num_reduced_indices = nhwc_rindices->size<loco::DataType::S32>();
1210
1211     for (uint32_t ri = 0; ri < num_reduced_indices; ++ri)
1212     {
1213       reduced_dims_nhwc[nhwc_rindices->at<loco::DataType::S32>(ri)] = true;
1214     }
1215
1216     // if channel dimension has been reduced, we don't need a transpose
1217     if (reduced_dims_nhwc[3])
1218       return true;
1219
1220     // likewise, if both space dimensions are reduced, no transpose is needed
1221     if (reduced_dims_nhwc[1] && reduced_dims_nhwc[2])
1222       return true;
1223
1224     std::vector<int32_t> post_trans_ind;
1225     // case 1: only N is reduced
1226     if (num_reduced_indices == 1 && reduced_dims_nhwc[0])
1227       post_trans_ind = {2, 0, 1};
1228
1229     // case 2: only H or W is reduced
1230     if (num_reduced_indices == 1 && (reduced_dims_nhwc[1] || reduced_dims_nhwc[2]))
1231       post_trans_ind = {0, 2, 1};
1232
1233     // case 3: N and either H or W are reduced
1234     if (num_reduced_indices == 2)
1235       post_trans_ind = {1, 0};
1236
1237     auto post_trans = create_Nd_transpose(node, post_trans_ind);
1238     loco::replace(node).with(post_trans);
1239
1240     post_trans->a(node);
1241
1242     return true;
1243   }
1244
1245   bool visit(luci::CircleRelu *node) { return convert_unary_features<luci::CircleRelu>(node); }
1246
1247   bool visit(luci::CircleRelu6 *node) { return convert_unary_features<luci::CircleRelu6>(node); }
1248
1249   bool visit(luci::CircleRsqrt *node) { return convert_unary_x<luci::CircleRsqrt>(node); }
1250
1251   bool visit(luci::CircleSplitV *node)
1252   {
1253     // Change split dimension
1254     auto axis = dynamic_cast<luci::CircleConst *>(node->split_dim());
1255     if (not axis)
1256       return false;
1257
1258     if (axis->dtype() != loco::DataType::S32)
1259       return false;
1260
1261     if (axis->size<loco::DataType::S32>() != 1)
1262       return false;
1263
1264     axis->at<loco::DataType::S32>(0) = nchw_axis_to_nhwc(axis->at<loco::DataType::S32>(0));
1265
1266     // Insert pre-transpose
1267     const auto pred_node = loco::must_cast<luci::CircleNode *>(node->input());
1268     auto pre_trans = create_pre_transpose(node);
1269     pre_trans->a(pred_node);
1270     node->input(pre_trans);
1271
1272     // Do shape inference for this node again.
1273     node->shape_status(luci::ShapeStatus::UNDEFINED);
1274
1275     // Insert post-transposes
1276     for (auto succ : loco::succs(node))
1277     {
1278       auto svo = loco::must_cast<luci::CircleSplitVOut *>(succ);
1279
1280       auto post_trans = create_post_transpose(svo);
1281       loco::replace(svo).with(post_trans);
1282       post_trans->a(svo);
1283     }
1284
1285     return true;
1286   }
1287
1288   bool visit(luci::CircleSquaredDifference *node)
1289   {
1290     // TODO support CircleConst input
1291     if (dynamic_cast<luci::CircleConst *>(node->x()) != nullptr)
1292       return false;
1293     if (dynamic_cast<luci::CircleConst *>(node->y()) != nullptr)
1294       return false;
1295
1296     auto input_x = loco::must_cast<luci::CircleNode *>(node->x());
1297     if (input_x->rank() != 4)
1298       return false;
1299     auto input_y = loco::must_cast<luci::CircleNode *>(node->y());
1300     if (input_y->rank() != 4)
1301       return false;
1302
1303     auto pre_trans_x = create_pre_transpose(node);
1304     pre_trans_x->a(input_x);
1305     node->x(pre_trans_x);
1306
1307     auto pre_trans_y = create_pre_transpose(node);
1308     pre_trans_y->a(input_y);
1309     node->y(pre_trans_y);
1310
1311     // Do shape inference for this node again.
1312     node->shape_status(luci::ShapeStatus::UNDEFINED);
1313
1314     auto post_trans = create_post_transpose(node);
1315     loco::replace(node).with(post_trans);
1316
1317     post_trans->a(node);
1318     return true;
1319   }
1320
1321   bool visit(luci::CircleSub *node)
1322   {
1323     luci::CircleNode *pred_node = nullptr;
1324     luci::CircleConst *subtract = nullptr;
1325
1326     auto const_x = dynamic_cast<luci::CircleConst *>(node->x());
1327     auto const_y = dynamic_cast<luci::CircleConst *>(node->y());
1328
1329     if (const_x != nullptr && const_y == nullptr)
1330     {
1331       // case of subtract - pred_node
1332       pred_node = loco::must_cast<luci::CircleNode *>(node->y());
1333       subtract = const_x;
1334
1335       if (!is_NCHW_with_const(node, pred_node, subtract))
1336         return false;
1337
1338       auto pre_trans = create_pre_transpose(node);
1339       pre_trans->a(pred_node);
1340
1341       if (subtract->rank() == 4)
1342       {
1343         auto nhwc_const = create_NHWC_from_NCHW(subtract);
1344         if (nhwc_const == nullptr)
1345           return false;
1346         node->x(nhwc_const);
1347       }
1348       node->y(pre_trans);
1349     }
1350     else if (const_x == nullptr && const_y != nullptr)
1351     {
1352       // case of pred_node - subtract
1353       pred_node = loco::must_cast<luci::CircleNode *>(node->x());
1354       subtract = const_y;
1355
1356       if (!is_NCHW_with_const(node, pred_node, subtract))
1357         return false;
1358
1359       auto pre_trans = create_pre_transpose(node);
1360       pre_trans->a(pred_node);
1361
1362       if (subtract->rank() == 4)
1363       {
1364         auto nhwc_const = create_NHWC_from_NCHW(subtract);
1365         if (nhwc_const == nullptr)
1366           return false;
1367         node->y(nhwc_const);
1368       }
1369
1370       node->x(pre_trans);
1371     }
1372     else if (const_x == nullptr && const_y == nullptr)
1373     {
1374       // Both inputs are not constant.
1375       // In this case, we cannot distinguish NCHW from NHWC,
1376       // so just insert Transpose Ops.
1377       // Only support for input rank 4.
1378       auto input_x = loco::must_cast<luci::CircleNode *>(node->x());
1379       if (input_x->rank() != 4)
1380         return false;
1381       auto input_y = loco::must_cast<luci::CircleNode *>(node->y());
1382       if (input_y->rank() != 4)
1383         return false;
1384
1385       auto pre_trans_x = create_pre_transpose(node);
1386       pre_trans_x->a(input_x);
1387       node->x(pre_trans_x);
1388
1389       auto pre_trans_y = create_pre_transpose(node);
1390       pre_trans_y->a(input_y);
1391       node->y(pre_trans_y);
1392     }
1393
1394     // Do shape inference for this node again.
1395     node->shape_status(luci::ShapeStatus::UNDEFINED);
1396
1397     auto post_trans = create_post_transpose(node);
1398     loco::replace(node).with(post_trans);
1399
1400     post_trans->a(node);
1401     return true;
1402   }
1403 };
1404
1405 } // namespace
1406
1407 namespace luci
1408 {
1409
1410 bool ConvertNCHWToNHWCPass::run(loco::Graph *g)
1411 {
1412   LOGGER(l);
1413   INFO(l) << "ConvertNCHWToNHWCPass Start" << std::endl;
1414
1415   // Annotate NHWC operators
1416   // NHWC operators are detected by pattern matching
1417   //
1418   // Pattern
1419   //    pre-Transose (or pre-Reshape) + [intermediate Ops] + post-Transpose (or post-Reshape)
1420   //
1421   // [intermediate Ops] are annotated as NHWC
1422   //
1423   // NOTE A single pre-Transpose/Reshape can have multiple post-Transpose/Reshape.
1424   // For example,
1425   // pre-Transpose --- [intermediate Ops] --- post-Transpose
1426   //                |
1427   //                +--[intermediate Ops] --- post-Transpose
1428   //
1429   // NOTE Intermediate Ops SHOULD NOT contain pre-Transpose/Reshape
1430   for (auto node : loco::postorder_traversal(loco::output_nodes(g)))
1431   {
1432     if (has_data_format(node))
1433       continue;
1434
1435     if (is_pre_transpose(node) || is_pre_reshape(node))
1436     {
1437       std::set<loco::Node *> intermediate;
1438
1439       // Variable to check intermediate Ops contain pre-Transpose/Reshape
1440       bool has_pre = false;
1441
1442       // Variable to check the pattern is closed with post-Transpose/Reshape
1443       bool is_closed = true;
1444
1445       // For recursive call of lambda
1446       std::function<void(loco::Node *)> collect_intermediate;
1447       collect_intermediate = [&](loco::Node *n) {
1448         for (auto succ : loco::succs(n))
1449         {
1450           // Skip unnecessary traversal
1451           if (intermediate.find(succ) != intermediate.end())
1452             continue;
1453
1454           // Exit condition
1455           if (is_post_transpose(succ) || is_post_reshape(succ))
1456             continue;
1457
1458           if (is_pre_transpose(succ) || is_pre_reshape(succ))
1459           {
1460             has_pre = true;
1461             break;
1462           }
1463
1464           if (is_output(succ))
1465           {
1466             is_closed = false;
1467             break;
1468           }
1469
1470           intermediate.emplace(succ);
1471
1472           collect_intermediate(succ);
1473         }
1474       };
1475
1476       collect_intermediate(node);
1477
1478       if (has_pre or not is_closed)
1479         continue;
1480
1481       for (auto inter : intermediate)
1482       {
1483         if (not has_data_format(inter))
1484           set_data_format(inter, DataFormat::NHWC);
1485       }
1486     }
1487   }
1488
1489   // Annotate NCHW operators
1490   for (auto node : loco::active_nodes(loco::output_nodes(g)))
1491   {
1492     auto circle_node = loco::must_cast<luci::CircleNode *>(node);
1493     switch (circle_node->opcode())
1494     {
1495       // List of supported Ops
1496       case luci::CircleOpcode::CIRCLEINPUT:
1497         if (!_preserve_input && !has_data_format(node))
1498         {
1499           set_data_format(node, DataFormat::NCHW);
1500         }
1501         break;
1502       case luci::CircleOpcode::CIRCLEOUTPUT:
1503         if (!_preserve_output && !has_data_format(node))
1504         {
1505           set_data_format(node, DataFormat::NCHW);
1506         }
1507         break;
1508       // SOFTMAX, LOG_SOFTMAX are not converted, because
1509       // tflite/circle assumes the last channel is always axis
1510       case luci::CircleOpcode::ADD:
1511       case luci::CircleOpcode::CONCATENATION:
1512       case luci::CircleOpcode::ELU:
1513       case luci::CircleOpcode::LEAKY_RELU:
1514       case luci::CircleOpcode::LOGISTIC:
1515       case luci::CircleOpcode::MAXIMUM:
1516       case luci::CircleOpcode::MEAN:
1517       case luci::CircleOpcode::MINIMUM:
1518       case luci::CircleOpcode::MUL:
1519       case luci::CircleOpcode::NEG:
1520       case luci::CircleOpcode::PAD:
1521       case luci::CircleOpcode::PADV2:
1522       case luci::CircleOpcode::REDUCE_MAX:
1523       case luci::CircleOpcode::REDUCE_MIN:
1524       case luci::CircleOpcode::RELU:
1525       case luci::CircleOpcode::RELU6:
1526       case luci::CircleOpcode::RSQRT:
1527       case luci::CircleOpcode::SPLIT_V:
1528       case luci::CircleOpcode::SQUARED_DIFFERENCE:
1529       case luci::CircleOpcode::SUB:
1530         if (!has_data_format(node))
1531         {
1532           set_data_format(node, DataFormat::NCHW);
1533         }
1534         break;
1535       default:
1536         break;
1537     }
1538   }
1539
1540   bool changed = false;
1541   for (auto node : loco::active_nodes(loco::output_nodes(g)))
1542   {
1543     if (!has_data_format(node))
1544     {
1545       // Unsupported Op
1546       continue;
1547     }
1548     else if (get_data_format(node) == DataFormat::NHWC)
1549     {
1550       // Already converted to NHWC
1551       continue;
1552     }
1553     else if (has_dynamic_shape(node))
1554     {
1555       // This pass only works for static-shaped node
1556       INFO(l) << "Skip the node with a dynamic shape." << std::endl;
1557       continue;
1558     }
1559     else
1560     {
1561       ConvertNCHWToNHWC converter;
1562       auto circle_node = loco::must_cast<luci::CircleNode *>(node);
1563       if (circle_node->rank() != 4)
1564       {
1565         // TODO replace the check above with the input rank check, and remove the condition below
1566         if (not dynamic_cast<luci::CircleMean *>(node) and
1567             not dynamic_cast<luci::CircleReduceMax *>(node) and
1568             not dynamic_cast<luci::CircleReduceMin *>(node))
1569           continue;
1570       }
1571
1572       if (circle_node->accept(&converter))
1573       {
1574         set_data_format(node, DataFormat::NHWC);
1575         changed = true;
1576       }
1577       else
1578       {
1579         continue;
1580       }
1581     }
1582   }
1583
1584   INFO(l) << "ConvertNCHWToNHWCPass End" << std::endl;
1585   return changed;
1586 }
1587
1588 } // namespace luci