Imported Upstream version 1.18.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / ConvertNCHWToNHWCPass.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/ConvertNCHWToNHWCPass.h"
18 #include "CircleOptimizerUtils.h"
19
20 #include <luci/IR/CircleNodes.h>
21 #include <luci/IR/CircleNodeVisitor.h>
22 #include <luci/Profile/CircleNodeOrigin.h>
23 #include <luci/Service/Nodes/CircleConst.h>
24 #include <luci/Log.h>
25
26 #include <functional>
27
28 namespace
29 {
30
31 bool is_same_shape(const luci::CircleNode *node, const std::vector<loco::Dimension> &shape)
32 {
33   if (not node)
34     return false;
35
36   if (shape.size() != node->rank())
37     return false;
38
39   for (uint32_t i = 0; i < shape.size(); i++)
40   {
41     if (not(node->dim(i) == shape[i]))
42       return false;
43   }
44   return true;
45 }
46
47 enum class DataFormat
48 {
49   NCHW,
50   NHWC
51 };
52
53 /**
54  * @brief Set annotation for DataFormat (NCHW, NHWC)
55  *
56  * @note DataFormatAnnotation will live longer than this Pass (until the
57  *       annotated loco::Node is erased). So, do not use large data in the
58  *       annotation to avoid excessive memory usage.
59  */
60 class DataFormatAnnotation final : public loco::NodeAnnotation
61 {
62 public:
63   DataFormatAnnotation(const DataFormat &format) : _format{format}
64   {
65     // DO NOTHING
66   }
67
68 public:
69   const DataFormat &format(void) const { return _format; }
70
71 private:
72   DataFormat _format;
73 };
74
75 void set_data_format(loco::Node *node, const DataFormat &format)
76 {
77   node->annot(std::make_unique<DataFormatAnnotation>(format));
78 }
79
80 DataFormat get_data_format(loco::Node *node)
81 {
82   assert(node->annot<DataFormatAnnotation>() != nullptr);
83   return node->annot<DataFormatAnnotation>()->format();
84 }
85
86 bool has_data_format(loco::Node *node) { return node->annot<DataFormatAnnotation>() != nullptr; }
87
88 bool check_4d_transpose(loco::Node *node, const std::vector<int32_t> indices)
89 {
90   assert(indices.size() == 4);
91
92   auto trans = dynamic_cast<luci::CircleTranspose *>(node);
93   if (not trans)
94     return false;
95
96   if (not trans->perm())
97     return false;
98
99   auto perm = dynamic_cast<luci::CircleConst *>(trans->perm());
100   // Only const perm is supported
101   if (not perm)
102     return false;
103
104   if (perm->dtype() != loco::DataType::S32)
105     return false;
106
107   if (perm->size<loco::DataType::S32>() != 4)
108     return false;
109
110   for (uint32_t i = 0; i < 4; i++)
111   {
112     if (perm->at<loco::DataType::S32>(i) != indices[i])
113       return false;
114   }
115
116   return true;
117 }
118
119 luci::CircleTranspose *create_4d_transpose(luci::CircleNode *node,
120                                            const std::vector<int32_t> indices)
121 {
122   assert(indices.size() == 4);
123
124   auto name = node->name();
125   assert(name.length() > 0);
126
127   auto perm = node->graph()->nodes()->create<luci::CircleConst>();
128   perm->dtype(loco::DataType::S32);
129   perm->size<loco::DataType::S32>(4);
130   perm->rank(1);
131   perm->dim(0) = 4;
132   for (uint32_t i = 0; i < 4; i++)
133     perm->at<loco::DataType::S32>(i) = indices[i];
134   perm->shape_status(luci::ShapeStatus::VALID);
135
136   auto make_string = [](const std::vector<int32_t> &nums) {
137     std::string str;
138     for (auto num : nums)
139     {
140       if (str.length() > 0)
141         str += ".";
142       str += std::to_string(num);
143     }
144     return str;
145   };
146
147   auto str_indices = make_string(indices);
148
149   perm->name(name + "/Transpose_" + str_indices + "/perm");
150
151   auto trans = node->graph()->nodes()->create<luci::CircleTranspose>();
152   trans->perm(perm);
153   trans->name(name + "/Transpose_" + str_indices);
154   luci::add_origin(trans, luci::get_origin(node));
155
156   return trans;
157 }
158
159 luci::CircleTranspose *create_Nd_transpose(luci::CircleNode *node,
160                                            const std::vector<int32_t> indices)
161 {
162   auto name = node->name();
163   assert(name.length() > 0);
164
165   auto perm = node->graph()->nodes()->create<luci::CircleConst>();
166   perm->dtype(loco::DataType::S32);
167   perm->size<loco::DataType::S32>(indices.size());
168   perm->rank(1);
169   perm->dim(0) = indices.size();
170   for (uint32_t i = 0; i < indices.size(); i++)
171     perm->at<loco::DataType::S32>(i) = indices[i];
172   perm->shape_status(luci::ShapeStatus::VALID);
173
174   auto make_string = [](const std::vector<int32_t> &nums) {
175     std::string str;
176     for (auto num : nums)
177     {
178       if (str.length() > 0)
179         str += ".";
180       str += std::to_string(num);
181     }
182     return str;
183   };
184
185   auto str_indices = make_string(indices);
186
187   perm->name(name + "/Transpose_" + str_indices + "/perm");
188
189   auto trans = node->graph()->nodes()->create<luci::CircleTranspose>();
190   trans->perm(perm);
191   trans->name(name + "/Transpose_" + str_indices);
192   luci::add_origin(trans, luci::get_origin(node));
193
194   return trans;
195 }
196
197 int32_t nchw_axis_to_nhwc(int32_t axis)
198 {
199   uint32_t pos_axis = axis >= 0 ? static_cast<uint32_t>(axis) : static_cast<uint32_t>(axis + 4);
200   static const uint32_t to_nhwc[4] = {0, 3, 1, 2};
201   if (pos_axis > 3)
202     throw std::runtime_error("Concat axis must be in range [-4, 4)");
203   return to_nhwc[pos_axis];
204 }
205
206 luci::CircleTranspose *create_post_transpose(luci::CircleNode *node)
207 {
208   return create_4d_transpose(node, {0, 3, 1, 2});
209 }
210
211 luci::CircleTranspose *create_pre_transpose(luci::CircleNode *node)
212 {
213   return create_4d_transpose(node, {0, 2, 3, 1});
214 }
215
216 bool check_4d_reshape(loco::Node *node, const std::vector<int32_t> indices)
217 {
218   assert(indices.size() == 4); // FIX_CALLER_UNLESS
219
220   auto reshape = dynamic_cast<luci::CircleReshape *>(node);
221   if (not reshape)
222     return false;
223
224   if (reshape->rank() != 4)
225     return false;
226
227   auto input = loco::must_cast<luci::CircleNode *>(reshape->tensor());
228   if (input->shape_status() != luci::ShapeStatus::VALID)
229     return false;
230
231   if (reshape->shape_status() != luci::ShapeStatus::VALID)
232     return false;
233
234   if (!(input->dim(0) == reshape->dim(indices[0])) ||
235       !(input->dim(1) == reshape->dim(indices[1])) ||
236       !(input->dim(2) == reshape->dim(indices[2])) || !(input->dim(3) == reshape->dim(indices[3])))
237     return false;
238
239   return true;
240 }
241
242 // Check if Reshape that converts NCHW -> NHWC
243 bool is_pre_reshape(loco::Node *node) { return check_4d_reshape(node, {0, 3, 1, 2}); }
244
245 // Check if Reshape that converts NHWC -> NCHW
246 bool is_post_reshape(loco::Node *node) { return check_4d_reshape(node, {0, 2, 3, 1}); }
247
248 bool is_post_transpose(loco::Node *node) { return check_4d_transpose(node, {0, 3, 1, 2}); }
249
250 bool is_pre_transpose(loco::Node *node) { return check_4d_transpose(node, {0, 2, 3, 1}); }
251
252 uint32_t cal_offset(const loco::TensorShape &dimension, const uint32_t *indices)
253 {
254   return indices[0] * dimension.dim(1).value() * dimension.dim(2).value() *
255            dimension.dim(3).value() +
256          indices[1] * dimension.dim(2).value() * dimension.dim(3).value() +
257          indices[2] * dimension.dim(3).value() + indices[3];
258 }
259
260 luci::CircleConst *create_NHWC_paddings(luci::CircleConst *paddings)
261 {
262   // paddings shape is (4,2) (it was checked by is_NCHW)
263   assert(paddings != nullptr);
264   assert(paddings->rank() == 2);
265   assert(paddings->dim(0).value() == 4);
266   assert(paddings->dim(1).value() == 2);
267
268   // paddings for idx 0~3 are 0 (checked by is_NCHW)
269   assert(paddings->at<loco::DataType::S32>(0) == 0);
270   assert(paddings->at<loco::DataType::S32>(1) == 0);
271   assert(paddings->at<loco::DataType::S32>(2) == 0);
272   assert(paddings->at<loco::DataType::S32>(3) == 0);
273
274   auto name = paddings->name();
275   assert(name.length() > 0);
276
277   auto nhwc_paddings = paddings->graph()->nodes()->create<luci::CircleConst>();
278   nhwc_paddings->dtype(loco::DataType::S32);
279   nhwc_paddings->shape({4, 2});
280   nhwc_paddings->shape_status(luci::ShapeStatus::VALID);
281   nhwc_paddings->size<loco::DataType::S32>(4 * 2);
282   nhwc_paddings->name(name + "_NHWC");
283
284   for (uint32_t dim = 0; dim < 4; dim++)
285   {
286     for (uint32_t i = 0; i < 2; i++)
287     {
288       int32_t data = 0;
289
290       if (dim == 1)
291       {
292         // get third dimension (H in NCHW)
293         data = paddings->at<loco::DataType::S32>(2 * 2 + i);
294       }
295       else if (dim == 2)
296       {
297         // get fourth dimension (W in NCHW)
298         data = paddings->at<loco::DataType::S32>(3 * 2 + i);
299       }
300
301       nhwc_paddings->at<loco::DataType::S32>(dim * 2 + i) = data;
302     }
303   }
304   return nhwc_paddings;
305 }
306
307 luci::CircleConst *create_NHWC_rindices(luci::CircleConst *rindices)
308 {
309   assert(rindices != nullptr); // FIX_CALLER_UNLESS
310
311   if (rindices->dtype() != loco::DataType::S32)
312     return nullptr;
313
314   auto nhwc_rindices = luci::clone(rindices);
315   auto name = rindices->name();
316   assert(name.length() > 0); // FIX_CALLER_UNLESS
317   nhwc_rindices->name(name + "_NHWC");
318
319   auto size = nhwc_rindices->size<loco::DataType::S32>();
320   for (uint32_t i = 0; i < size; i++)
321   {
322     nhwc_rindices->at<loco::DataType::S32>(i) =
323       nchw_axis_to_nhwc(rindices->at<loco::DataType::S32>(i));
324   }
325
326   return nhwc_rindices;
327 }
328
329 luci::CircleConst *create_NHWC_from_NCHW(luci::CircleConst *constant)
330 {
331   LOGGER(l);
332   assert(constant->rank() == 4);
333
334   // TODO: Support non-float types
335   if (constant->dtype() != loco::DataType::FLOAT32)
336   {
337     INFO(l) << "Non-float type constant: " << constant->name() << std::endl;
338     return nullptr;
339   }
340
341   loco::TensorShape nchw_dimension{constant->dim(0), constant->dim(1), constant->dim(2),
342                                    constant->dim(3)};
343   loco::TensorShape nhwc_dimension{constant->dim(0), constant->dim(2), constant->dim(3),
344                                    constant->dim(1)};
345
346   auto name = constant->name();
347   assert(name.length() > 0);
348
349   auto nhwc_const = constant->graph()->nodes()->create<luci::CircleConst>();
350   nhwc_const->dtype(constant->dtype());
351   nhwc_const->rank(4);
352   nhwc_const->dim(0).set(constant->dim(0).value());
353   nhwc_const->dim(1).set(constant->dim(2).value());
354   nhwc_const->dim(2).set(constant->dim(3).value());
355   nhwc_const->dim(3).set(constant->dim(1).value());
356   nhwc_const->shape_status(luci::ShapeStatus::VALID);
357   nhwc_const->size<loco::DataType::FLOAT32>(constant->size<loco::DataType::FLOAT32>());
358   nhwc_const->name(name + "_NHWC");
359
360   for (uint32_t n = 0; n < nchw_dimension.dim(0).value(); n++)
361   {
362     for (uint32_t c = 0; c < nchw_dimension.dim(1).value(); c++)
363     {
364       for (uint32_t h = 0; h < nchw_dimension.dim(2).value(); h++)
365       {
366         for (uint32_t w = 0; w < nchw_dimension.dim(3).value(); w++)
367         {
368           uint32_t nchw_indices[4] = {n, c, h, w};
369           uint32_t nhwc_indices[4] = {n, h, w, c};
370           auto data =
371             constant->at<loco::DataType::FLOAT32>(cal_offset(nchw_dimension, nchw_indices));
372           nhwc_const->at<loco::DataType::FLOAT32>(cal_offset(nhwc_dimension, nhwc_indices)) = data;
373         }
374       }
375     }
376   }
377   return nhwc_const;
378 }
379
380 // NOTE Following conditions can be extended later
381 //
382 // Find PAD with an NCHW pattern described below
383 //   - Paddings shape : [4, 2]
384 //   - Paddings value : [[0, 0], [0, 0], [h_t, h_b], [w_t, w_b]]]
385 bool is_NCHW(const luci::CirclePad *node)
386 {
387   const auto paddings = dynamic_cast<luci::CircleConst *>(node->paddings());
388   // Non-const paddings is not supported
389   if (paddings == nullptr)
390     return false;
391
392   if (paddings->rank() != 2)
393     return false;
394
395   if (paddings->dim(0).value() != 4 || paddings->dim(1).value() != 2)
396     return false;
397
398   // Only check the first two dimensions
399   for (uint32_t dim = 0; dim < 2; dim++)
400   {
401     for (uint32_t i = 0; i < 2; i++)
402     {
403       auto data = paddings->at<loco::DataType::S32>(dim * 2 + i);
404       if (data != 0)
405         return false;
406     }
407   }
408
409   return true;
410 }
411
412 // NOTE Copied from is_NCHW(CirclePad)
413 bool is_NCHW(const luci::CirclePadV2 *node)
414 {
415   const auto paddings = dynamic_cast<luci::CircleConst *>(node->paddings());
416   // Non-const paddings is not supported
417   if (paddings == nullptr)
418     return false;
419
420   if (paddings->rank() != 2)
421     return false;
422
423   if (paddings->dim(0).value() != 4 || paddings->dim(1).value() != 2)
424     return false;
425
426   // Only check the first two dimensions
427   for (uint32_t dim = 0; dim < 2; dim++)
428   {
429     for (uint32_t i = 0; i < 2; i++)
430     {
431       auto data = paddings->at<loco::DataType::S32>(dim * 2 + i);
432       if (data != 0)
433         return false;
434     }
435   }
436
437   return true;
438 }
439
440 // NOTE Following conditions can be extended later
441 // NOTE Used for Maximum, Miminum as ReLU/ReLU6
442 //
443 // Find T with an NCHW pattern described below
444 //   - Input (non-constant) shape : [N, C, H, W]
445 //   - Input (constant) shape : [1] or []
446 //   - Output shape : [N, C, H, W]
447 template <class T>
448 bool is_NCHW_with_s_const(const T *node, luci::CircleNode *&pred_node,
449                           luci::CircleConst *&comp_const)
450 {
451   auto x = dynamic_cast<luci::CircleConst *>(node->x());
452   auto y = dynamic_cast<luci::CircleConst *>(node->y());
453
454   if (x != nullptr && y == nullptr)
455   {
456     pred_node = loco::must_cast<luci::CircleNode *>(node->y());
457     comp_const = x;
458   }
459   else if (x == nullptr && y != nullptr)
460   {
461     pred_node = loco::must_cast<luci::CircleNode *>(node->x());
462     comp_const = y;
463   }
464   else
465   {
466     // Ignore if T does not have a comp_const input.
467     return false;
468   }
469
470   if (pred_node->rank() != 4)
471     return false;
472
473   // Check if scalar
474   const auto const_rank = comp_const->rank();
475   if (const_rank == 0 || (const_rank == 1 && comp_const->dim(0).value() == 1))
476     return true;
477   return false;
478 }
479
480 // NOTE Following conditions can be extended later
481 //
482 // Find MUL with an NCHW pattern described below
483 //   - Input (non-constant) shape : [N, C, H, W]
484 //   - Input (constant) shape : [1, C, 1, 1], [N, C, H, W] or a scalar (1)
485 //   - Output shape : [N, C, H, W]
486 bool is_NCHW_with_const(const luci::CircleMul *node, luci::CircleNode *&pred_node,
487                         luci::CircleConst *&multiplier)
488 {
489   auto x = dynamic_cast<luci::CircleConst *>(node->x());
490   auto y = dynamic_cast<luci::CircleConst *>(node->y());
491
492   if (x != nullptr && y == nullptr)
493   {
494     pred_node = loco::must_cast<luci::CircleNode *>(node->y());
495     multiplier = x;
496   }
497   else if (x == nullptr && y != nullptr)
498   {
499     pred_node = loco::must_cast<luci::CircleNode *>(node->x());
500     multiplier = y;
501   }
502   else
503   {
504     // Ignore if MUL does not have a multiplier input.
505     return false;
506   }
507
508   if (pred_node->rank() != 4)
509     return false;
510
511   const auto const_rank = multiplier->rank();
512   // Support Rank 4 or scalar (rank 0 or 1)
513   if (const_rank != 4 && const_rank != 0 && const_rank != 1)
514     return false;
515
516   const auto input_cdim = pred_node->dim(1);
517   const auto output_cdim = node->dim(1);
518
519   if (const_rank == 4)
520   {
521     bool supported_shape = false;
522
523     // Check multiplier is (1, C, 1, 1)
524     if (is_same_shape(multiplier, {1, node->dim(1), 1, 1}))
525       supported_shape = true;
526
527     // Check multiplier is (N, C, H, W)
528     if (is_same_shape(multiplier, {node->dim(0), node->dim(1), node->dim(2), node->dim(3)}))
529       supported_shape = true;
530
531     return supported_shape;
532   }
533   if (input_cdim == output_cdim)
534     return true;
535   else
536     return false;
537 }
538
539 // We assume ADD with const input is NCHW if,
540 // Input shape: (N, C, H, W)
541 // Output shape: (N, C, H, W)
542 // 1. Const shape is (1, C, 1, 1), (N, C, H, W) or a scalar (1)
543 // 2. Input, Output, Const have the same C.
544 bool is_NCHW_with_const(const luci::CircleAdd *node, luci::CircleNode *&pred_node,
545                         luci::CircleConst *&beta)
546 {
547   auto x = dynamic_cast<luci::CircleConst *>(node->x());
548   auto y = dynamic_cast<luci::CircleConst *>(node->y());
549
550   if (x != nullptr && y == nullptr)
551   {
552     pred_node = loco::must_cast<luci::CircleNode *>(node->y());
553     beta = x;
554   }
555   else if (x == nullptr && y != nullptr)
556   {
557     pred_node = loco::must_cast<luci::CircleNode *>(node->x());
558     beta = y;
559   }
560   else
561   {
562     // Ignore if ADD does not have a constant input.
563     return false;
564   }
565
566   if (pred_node->rank() != 4)
567     return false;
568
569   const auto const_rank = beta->rank();
570   // Support Rank 4 or scalar (rank 0 or 1)
571   if (const_rank != 4 && const_rank != 0 && const_rank != 1)
572     return false;
573
574   const auto input_cdim = pred_node->dim(1);
575   const auto output_cdim = node->dim(1);
576
577   if (const_rank == 4)
578   {
579     bool supported_shape = false;
580
581     // Check beta is (1, C, 1, 1)
582     if (is_same_shape(beta, {1, node->dim(1), 1, 1}))
583       supported_shape = true;
584
585     // Check beta is (N, C, H, W)
586     if (is_same_shape(beta, {node->dim(0), node->dim(1), node->dim(2), node->dim(3)}))
587       supported_shape = true;
588
589     return supported_shape;
590   }
591   if (input_cdim == output_cdim)
592     return true;
593   else
594     return false;
595 }
596
597 // We assume SUB with const input is NCHW if,
598 // Input shape: (N, C, H, W)
599 // Output shape: (N, C, H, W)
600 // 1. Const shape is (1, C, 1, 1), (N, C, H, W) or a scalar (1)
601 // 2. Input, Output, Const have the same C.
602 bool is_NCHW_with_const(const luci::CircleSub *node, const luci::CircleNode *pred_node,
603                         const luci::CircleConst *subtract)
604 {
605   assert(pred_node != nullptr);
606   assert(subtract != nullptr);
607
608   if (pred_node->rank() != 4)
609     return false;
610
611   const auto const_rank = subtract->rank();
612   // Support Rank 4 or scalar (rank 0 or 1)
613   if (const_rank != 4 && const_rank != 0 && const_rank != 1)
614     return false;
615
616   const auto input_cdim = pred_node->dim(1);
617   const auto output_cdim = node->dim(1);
618
619   if (const_rank == 4)
620   {
621     bool supported_shape = false;
622
623     // Check subtract is (1, C, 1, 1)
624     if (is_same_shape(subtract, {1, node->dim(1), 1, 1}))
625       supported_shape = true;
626
627     // Check subtract is (N, C, H, W)
628     if (is_same_shape(subtract, {node->dim(0), node->dim(1), node->dim(2), node->dim(3)}))
629       supported_shape = true;
630
631     return supported_shape;
632   }
633   if (input_cdim == output_cdim)
634     return true;
635   else
636     return false;
637 }
638
639 template <class T> bool convert_unary_features(T *node)
640 {
641   const auto pred_node = loco::must_cast<luci::CircleNode *>(node->features());
642   auto pre_trans = create_pre_transpose(node);
643   pre_trans->a(pred_node);
644   node->features(pre_trans);
645
646   // Do shape inference for this node again.
647   node->shape_status(luci::ShapeStatus::UNDEFINED);
648
649   auto post_trans = create_post_transpose(node);
650   loco::replace(node).with(post_trans);
651
652   post_trans->a(node);
653
654   return true;
655 }
656
657 template <class T> bool convert_unary_x(T *node)
658 {
659   const auto pred_node = loco::must_cast<luci::CircleNode *>(node->x());
660   auto pre_trans = create_pre_transpose(node);
661   pre_trans->a(pred_node);
662   node->x(pre_trans);
663
664   // Do shape inference for this node again.
665   node->shape_status(luci::ShapeStatus::UNDEFINED);
666
667   auto post_trans = create_post_transpose(node);
668   loco::replace(node).with(post_trans);
669
670   post_trans->a(node);
671
672   return true;
673 }
674
675 class ConvertNCHWToNHWC final : public luci::CircleNodeMutableVisitor<bool>
676 {
677   // Default
678   bool visit(luci::CircleNode *node)
679   {
680     throw std::runtime_error(node->name() + " is an unsupported operator.");
681   }
682
683   bool visit(luci::CircleInput *node)
684   {
685     const auto n = node->dim(0);
686     const auto c = node->dim(1);
687     const auto h = node->dim(2);
688     const auto w = node->dim(3);
689
690     node->dim(1) = h;
691     node->dim(2) = w;
692     node->dim(3) = c;
693
694     // Do shape inference for this node again.
695     node->shape_status(luci::ShapeStatus::UNDEFINED);
696
697     // Insert post-tranpose
698     auto post_trans = create_post_transpose(node);
699     loco::replace(node).with(post_trans);
700
701     post_trans->a(node);
702
703     // Update graph input
704     auto graph_inputs = node->graph()->inputs();
705     auto graph_input = graph_inputs->at(node->index());
706     graph_input->shape({n, h, w, c});
707
708     return true;
709   }
710
711   bool visit(luci::CircleOutput *node)
712   {
713     // Insert pre-transpose
714     auto pre_trans = create_pre_transpose(node);
715     pre_trans->a(node->from());
716
717     node->from(pre_trans);
718
719     // Do shape inference for this node again.
720     node->shape_status(luci::ShapeStatus::UNDEFINED);
721
722     // Update graph output
723     const auto n = node->dim(0).value();
724     const auto c = node->dim(1).value();
725     const auto h = node->dim(2).value();
726     const auto w = node->dim(3).value();
727
728     auto graph_outputs = node->graph()->outputs();
729     auto graph_output = graph_outputs->at(node->index());
730     graph_output->shape({n, h, w, c});
731
732     return true;
733   }
734
735   bool visit(luci::CircleAdd *node)
736   {
737     luci::CircleNode *pred_node = nullptr;
738     luci::CircleConst *beta = nullptr;
739
740     if (is_NCHW_with_const(node, pred_node, beta))
741     {
742       auto pre_trans = create_pre_transpose(node);
743       pre_trans->a(pred_node);
744
745       if (beta->rank() == 4)
746       {
747         auto nhwc_const = create_NHWC_from_NCHW(beta);
748         if (nhwc_const == nullptr)
749           return false;
750         node->y(nhwc_const);
751       }
752
753       node->x(pre_trans);
754     }
755     else if (beta == nullptr)
756     {
757       // Both inputs are not constant.
758       // In this case, we cannot distinguish NCHW from NHWC,
759       // so just insert Transpose Ops.
760       auto pre_trans_x = create_pre_transpose(node);
761       pre_trans_x->a(node->x());
762       node->x(pre_trans_x);
763
764       auto pre_trans_y = create_pre_transpose(node);
765       pre_trans_y->a(node->y());
766       node->y(pre_trans_y);
767     }
768     else
769     {
770       return false;
771     }
772
773     // Do shape inference for this node again.
774     node->shape_status(luci::ShapeStatus::UNDEFINED);
775
776     auto post_trans = create_post_transpose(node);
777     loco::replace(node).with(post_trans);
778
779     post_trans->a(node);
780     return true;
781   }
782
783   bool visit(luci::CircleConcatenation *node)
784   {
785     const auto num_values = node->numValues();
786     for (uint32_t i = 0; i < num_values; i++)
787     {
788       auto pred_node = loco::must_cast<luci::CircleNode *>(node->values(i));
789       auto pre_trans = create_pre_transpose(node);
790       pre_trans->a(pred_node);
791       node->values(i, pre_trans);
792     }
793
794     // Do shape inference for this node again.
795     node->shape_status(luci::ShapeStatus::UNDEFINED);
796
797     node->axis(nchw_axis_to_nhwc(node->axis()));
798
799     auto post_trans = create_post_transpose(node);
800     loco::replace(node).with(post_trans);
801
802     post_trans->a(node);
803
804     return true;
805   }
806
807   bool visit(luci::CircleLeakyRelu *node)
808   {
809     return convert_unary_features<luci::CircleLeakyRelu>(node);
810   }
811
812   bool visit(luci::CircleLogistic *node) { return convert_unary_x<luci::CircleLogistic>(node); }
813
814   bool visit(luci::CircleMaximum *node)
815   {
816     luci::CircleNode *pred_node = nullptr;
817     luci::CircleConst *comp_constant = nullptr;
818
819     if (is_NCHW_with_s_const<luci::CircleMaximum>(node, pred_node, comp_constant))
820     {
821       auto pre_trans = create_pre_transpose(node);
822       pre_trans->a(pred_node);
823       node->x(pre_trans);
824     }
825     else
826     {
827       // TODO support other cases
828       return false;
829     }
830
831     // Do shape inference for this node again.
832     node->shape_status(luci::ShapeStatus::UNDEFINED);
833
834     auto post_trans = create_post_transpose(node);
835     loco::replace(node).with(post_trans);
836
837     post_trans->a(node);
838     return true;
839   }
840
841   bool visit(luci::CircleMean *node)
842   {
843     auto input = loco::must_cast<luci::CircleNode *>(node->input());
844     if (input->rank() != 4)
845       return false;
846
847     auto rindices = dynamic_cast<luci::CircleConst *>(node->reduction_indices());
848     if (not rindices)
849       return false;
850
851     auto nhwc_rindices = create_NHWC_rindices(rindices);
852     if (not nhwc_rindices)
853       return false;
854
855     auto pre_trans = create_pre_transpose(node);
856     pre_trans->a(input);
857     node->input(pre_trans);
858
859     // Do shape inference for this node again.
860     node->shape_status(luci::ShapeStatus::UNDEFINED);
861
862     node->reduction_indices(nhwc_rindices);
863
864     if (node->keep_dims())
865     {
866       auto post_trans = create_post_transpose(node);
867       loco::replace(node).with(post_trans);
868
869       post_trans->a(node);
870
871       return true;
872     }
873
874     // node->keep_dims() == false
875     // 1D output never needs a transpose
876     if (node->rank() <= 1)
877       return true;
878
879     std::vector<bool> reduced_dims_nhwc(4, false);
880     uint32_t num_reduced_indices = nhwc_rindices->size<loco::DataType::S32>();
881
882     for (uint32_t ri = 0; ri < num_reduced_indices; ++ri)
883     {
884       reduced_dims_nhwc[nhwc_rindices->at<loco::DataType::S32>(ri)] = true;
885     }
886
887     // if channel dimension has been reduced, we don't need a transpose
888     if (reduced_dims_nhwc[3])
889       return true;
890
891     // likewise, if both space dimensions are reduced, no transpose is needed
892     if (reduced_dims_nhwc[1] && reduced_dims_nhwc[2])
893       return true;
894
895     std::vector<int32_t> post_trans_ind;
896     // case 1: only N is reduced
897     if (num_reduced_indices == 1 && reduced_dims_nhwc[0])
898       post_trans_ind = {2, 0, 1};
899
900     // case 2: only H or W is reduced
901     if (num_reduced_indices == 1 && (reduced_dims_nhwc[1] || reduced_dims_nhwc[2]))
902       post_trans_ind = {0, 2, 1};
903
904     // case 3: N and either H or W are reduced
905     if (num_reduced_indices == 2)
906       post_trans_ind = {1, 0};
907
908     auto post_trans = create_Nd_transpose(node, post_trans_ind);
909     loco::replace(node).with(post_trans);
910
911     post_trans->a(node);
912
913     return true;
914   }
915
916   bool visit(luci::CircleMinimum *node)
917   {
918     luci::CircleNode *pred_node = nullptr;
919     luci::CircleConst *comp_constant = nullptr;
920
921     if (is_NCHW_with_s_const<luci::CircleMinimum>(node, pred_node, comp_constant))
922     {
923       auto pre_trans = create_pre_transpose(node);
924       pre_trans->a(pred_node);
925       node->x(pre_trans);
926     }
927     else
928     {
929       // TODO support other cases
930       return false;
931     }
932
933     // Do shape inference for this node again.
934     node->shape_status(luci::ShapeStatus::UNDEFINED);
935
936     auto post_trans = create_post_transpose(node);
937     loco::replace(node).with(post_trans);
938
939     post_trans->a(node);
940     return true;
941   }
942
943   bool visit(luci::CircleMul *node)
944   {
945     LOGGER(l);
946
947     luci::CircleNode *pred_node = nullptr;
948     luci::CircleConst *multiplier = nullptr;
949
950     if (is_NCHW_with_const(node, pred_node, multiplier))
951     {
952       auto pre_trans = create_pre_transpose(node);
953       pre_trans->a(pred_node);
954       node->x(pre_trans);
955
956       if (multiplier->rank() == 4)
957       {
958         auto nhwc_const = create_NHWC_from_NCHW(multiplier);
959         node->y(nhwc_const);
960       }
961     }
962     else if (multiplier == nullptr)
963     {
964       // Only support for input rank 4
965       auto input_x = loco::must_cast<luci::CircleNode *>(node->x());
966       if (input_x->rank() != 4)
967         return false;
968       auto input_y = loco::must_cast<luci::CircleNode *>(node->y());
969       if (input_y->rank() != 4)
970         return false;
971
972       auto pre_trans_x = create_pre_transpose(node);
973       pre_trans_x->a(input_x);
974       node->x(pre_trans_x);
975
976       auto pre_trans_y = create_pre_transpose(node);
977       pre_trans_y->a(input_y);
978       node->y(pre_trans_y);
979     }
980     else
981     {
982       return false;
983     }
984
985     // Do shape inference for this node again.
986     node->shape_status(luci::ShapeStatus::UNDEFINED);
987
988     auto post_trans = create_post_transpose(node);
989     loco::replace(node).with(post_trans);
990
991     post_trans->a(node);
992     return true;
993   }
994
995   bool visit(luci::CircleNeg *node) { return convert_unary_x<luci::CircleNeg>(node); }
996
997   bool visit(luci::CirclePad *node)
998   {
999     if (!is_NCHW(node))
1000       return false;
1001
1002     const auto pred_node = loco::must_cast<luci::CircleNode *>(node->input());
1003     auto pre_trans = create_pre_transpose(node);
1004     pre_trans->a(pred_node);
1005     node->input(pre_trans);
1006
1007     auto nchw_paddings = loco::must_cast<luci::CircleConst *>(node->paddings());
1008     const auto nhwc_paddings = create_NHWC_paddings(nchw_paddings);
1009     node->paddings(nhwc_paddings);
1010
1011     // Do shape inference for this node again.
1012     node->shape_status(luci::ShapeStatus::UNDEFINED);
1013
1014     auto post_trans = create_post_transpose(node);
1015     loco::replace(node).with(post_trans);
1016
1017     post_trans->a(node);
1018
1019     return true;
1020   }
1021
1022   bool visit(luci::CirclePadV2 *node)
1023   {
1024     if (!is_NCHW(node))
1025       return false;
1026
1027     const auto pred_node = loco::must_cast<luci::CircleNode *>(node->input());
1028     auto pre_trans = create_pre_transpose(node);
1029     pre_trans->a(pred_node);
1030     node->input(pre_trans);
1031
1032     auto nchw_paddings = loco::must_cast<luci::CircleConst *>(node->paddings());
1033     const auto nhwc_paddings = create_NHWC_paddings(nchw_paddings);
1034     node->paddings(nhwc_paddings);
1035
1036     // Do shape inference for this node again.
1037     node->shape_status(luci::ShapeStatus::UNDEFINED);
1038
1039     auto post_trans = create_post_transpose(node);
1040     loco::replace(node).with(post_trans);
1041
1042     post_trans->a(node);
1043
1044     return true;
1045   }
1046
1047   bool visit(luci::CircleRelu *node) { return convert_unary_features<luci::CircleRelu>(node); }
1048
1049   bool visit(luci::CircleRelu6 *node) { return convert_unary_features<luci::CircleRelu6>(node); }
1050
1051   bool visit(luci::CircleRsqrt *node) { return convert_unary_x<luci::CircleRsqrt>(node); }
1052
1053   bool visit(luci::CircleSquaredDifference *node)
1054   {
1055     // TODO support CircleConst input
1056     if (dynamic_cast<luci::CircleConst *>(node->x()) != nullptr)
1057       return false;
1058     if (dynamic_cast<luci::CircleConst *>(node->y()) != nullptr)
1059       return false;
1060
1061     auto input_x = loco::must_cast<luci::CircleNode *>(node->x());
1062     if (input_x->rank() != 4)
1063       return false;
1064     auto input_y = loco::must_cast<luci::CircleNode *>(node->y());
1065     if (input_y->rank() != 4)
1066       return false;
1067
1068     auto pre_trans_x = create_pre_transpose(node);
1069     pre_trans_x->a(input_x);
1070     node->x(pre_trans_x);
1071
1072     auto pre_trans_y = create_pre_transpose(node);
1073     pre_trans_y->a(input_y);
1074     node->y(pre_trans_y);
1075
1076     // Do shape inference for this node again.
1077     node->shape_status(luci::ShapeStatus::UNDEFINED);
1078
1079     auto post_trans = create_post_transpose(node);
1080     loco::replace(node).with(post_trans);
1081
1082     post_trans->a(node);
1083     return true;
1084   }
1085
1086   bool visit(luci::CircleSub *node)
1087   {
1088     luci::CircleNode *pred_node = nullptr;
1089     luci::CircleConst *subtract = nullptr;
1090
1091     auto const_x = dynamic_cast<luci::CircleConst *>(node->x());
1092     auto const_y = dynamic_cast<luci::CircleConst *>(node->y());
1093
1094     if (const_x != nullptr && const_y == nullptr)
1095     {
1096       // case of subtract - pred_node
1097       pred_node = loco::must_cast<luci::CircleNode *>(node->y());
1098       subtract = const_x;
1099
1100       if (!is_NCHW_with_const(node, pred_node, subtract))
1101         return false;
1102
1103       auto pre_trans = create_pre_transpose(node);
1104       pre_trans->a(pred_node);
1105
1106       if (subtract->rank() == 4)
1107       {
1108         auto nhwc_const = create_NHWC_from_NCHW(subtract);
1109         if (nhwc_const == nullptr)
1110           return false;
1111         node->x(nhwc_const);
1112       }
1113       node->y(pre_trans);
1114     }
1115     else if (const_x == nullptr && const_y != nullptr)
1116     {
1117       // case of pred_node - subtract
1118       pred_node = loco::must_cast<luci::CircleNode *>(node->x());
1119       subtract = const_y;
1120
1121       if (!is_NCHW_with_const(node, pred_node, subtract))
1122         return false;
1123
1124       auto pre_trans = create_pre_transpose(node);
1125       pre_trans->a(pred_node);
1126
1127       if (subtract->rank() == 4)
1128       {
1129         auto nhwc_const = create_NHWC_from_NCHW(subtract);
1130         if (nhwc_const == nullptr)
1131           return false;
1132         node->y(nhwc_const);
1133       }
1134
1135       node->x(pre_trans);
1136     }
1137     else if (const_x == nullptr && const_y == nullptr)
1138     {
1139       // Both inputs are not constant.
1140       // In this case, we cannot distinguish NCHW from NHWC,
1141       // so just insert Transpose Ops.
1142       // Only support for input rank 4.
1143       auto input_x = loco::must_cast<luci::CircleNode *>(node->x());
1144       if (input_x->rank() != 4)
1145         return false;
1146       auto input_y = loco::must_cast<luci::CircleNode *>(node->y());
1147       if (input_y->rank() != 4)
1148         return false;
1149
1150       auto pre_trans_x = create_pre_transpose(node);
1151       pre_trans_x->a(input_x);
1152       node->x(pre_trans_x);
1153
1154       auto pre_trans_y = create_pre_transpose(node);
1155       pre_trans_y->a(input_y);
1156       node->y(pre_trans_y);
1157     }
1158
1159     // Do shape inference for this node again.
1160     node->shape_status(luci::ShapeStatus::UNDEFINED);
1161
1162     auto post_trans = create_post_transpose(node);
1163     loco::replace(node).with(post_trans);
1164
1165     post_trans->a(node);
1166     return true;
1167   }
1168 };
1169
1170 } // namespace
1171
1172 namespace luci
1173 {
1174
1175 bool ConvertNCHWToNHWCPass::run(loco::Graph *g)
1176 {
1177   LOGGER(l);
1178   INFO(l) << "ConvertNCHWToNHWCPass Start" << std::endl;
1179
1180   // Annotate NHWC operators
1181   // NHWC operators are detected by pattern matching
1182   //
1183   // Pattern
1184   //    pre-Transose (or pre-Reshape) + [intermediate Ops] + post-Transpose (or post-Reshape)
1185   //
1186   // [intermediate Ops] are annotated as NHWC
1187   //
1188   // NOTE A single pre-Transpose/Reshape can have multiple post-Transpose/Reshape.
1189   // For example,
1190   // pre-Transpose --- [intermediate Ops] --- post-Transpose
1191   //                |
1192   //                +--[intermediate Ops] --- post-Transpose
1193   for (auto node : loco::postorder_traversal(loco::output_nodes(g)))
1194   {
1195     if (has_data_format(node))
1196       continue;
1197
1198     if (is_pre_transpose(node) || is_pre_reshape(node))
1199     {
1200       // For recursive call of lambda
1201       std::function<void(loco::Node *)> set_data_format_to_succs;
1202       set_data_format_to_succs = [&](loco::Node *n) {
1203         for (auto succ : loco::succs(n))
1204         {
1205           // Exit condition
1206           if (is_post_transpose(succ) || is_post_reshape(succ))
1207             continue;
1208
1209           if (not has_data_format(succ))
1210           {
1211             set_data_format(succ, DataFormat::NHWC);
1212           }
1213
1214           set_data_format_to_succs(succ);
1215         }
1216       };
1217
1218       set_data_format_to_succs(node);
1219     }
1220   }
1221
1222   // Annotate NCHW operators
1223   for (auto node : loco::active_nodes(loco::output_nodes(g)))
1224   {
1225     auto circle_node = loco::must_cast<luci::CircleNode *>(node);
1226     switch (circle_node->opcode())
1227     {
1228       // List of supported Ops
1229       case luci::CircleOpcode::CIRCLEINPUT:
1230         if (!_preserve_input && !has_data_format(node))
1231         {
1232           set_data_format(node, DataFormat::NCHW);
1233         }
1234         break;
1235       case luci::CircleOpcode::CIRCLEOUTPUT:
1236         if (!_preserve_output && !has_data_format(node))
1237         {
1238           set_data_format(node, DataFormat::NCHW);
1239         }
1240         break;
1241       case luci::CircleOpcode::ADD:
1242       case luci::CircleOpcode::CONCATENATION:
1243       case luci::CircleOpcode::LEAKY_RELU:
1244       case luci::CircleOpcode::LOGISTIC:
1245       case luci::CircleOpcode::MAXIMUM:
1246       case luci::CircleOpcode::MEAN:
1247       case luci::CircleOpcode::MINIMUM:
1248       case luci::CircleOpcode::MUL:
1249       case luci::CircleOpcode::NEG:
1250       case luci::CircleOpcode::PAD:
1251       case luci::CircleOpcode::PADV2:
1252       case luci::CircleOpcode::RELU:
1253       case luci::CircleOpcode::RELU6:
1254       case luci::CircleOpcode::RSQRT:
1255       case luci::CircleOpcode::SQUARED_DIFFERENCE:
1256       case luci::CircleOpcode::SUB:
1257         if (!has_data_format(node))
1258         {
1259           set_data_format(node, DataFormat::NCHW);
1260         }
1261         break;
1262       default:
1263         break;
1264     }
1265   }
1266
1267   bool changed = false;
1268   for (auto node : loco::active_nodes(loco::output_nodes(g)))
1269   {
1270     if (!has_data_format(node))
1271     {
1272       // Unsupported Op
1273       continue;
1274     }
1275     else if (get_data_format(node) == DataFormat::NHWC)
1276     {
1277       // Already converted to NHWC
1278       continue;
1279     }
1280     else if (has_dynamic_shape(node))
1281     {
1282       // This pass only works for static-shaped node
1283       INFO(l) << "Skip the node with a dynamic shape." << std::endl;
1284       continue;
1285     }
1286     else
1287     {
1288       ConvertNCHWToNHWC converter;
1289       auto circle_node = loco::must_cast<luci::CircleNode *>(node);
1290       if (circle_node->rank() != 4)
1291       {
1292         // TODO replace the check above with the input rank check, and remove the condition below
1293         if (not dynamic_cast<luci::CircleMean *>(node))
1294           continue;
1295       }
1296
1297       if (circle_node->accept(&converter))
1298       {
1299         set_data_format(node, DataFormat::NHWC);
1300         changed = true;
1301       }
1302       else
1303       {
1304         continue;
1305       }
1306     }
1307   }
1308
1309   INFO(l) << "ConvertNCHWToNHWCPass End" << std::endl;
1310   return changed;
1311 }
1312
1313 } // namespace luci