2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/Pass/ConvertNCHWToNHWCPass.h"
18 #include "CircleOptimizerUtils.h"
20 #include <luci/IR/CircleNodes.h>
21 #include <luci/IR/CircleNodeVisitor.h>
22 #include <luci/Profile/CircleNodeOrigin.h>
23 #include <luci/Service/Nodes/CircleConst.h>
31 bool is_same_shape(const luci::CircleNode *node, const std::vector<loco::Dimension> &shape)
36 if (shape.size() != node->rank())
39 for (uint32_t i = 0; i < shape.size(); i++)
41 if (not(node->dim(i) == shape[i]))
54 * @brief Set annotation for DataFormat (NCHW, NHWC)
56 * @note DataFormatAnnotation will live longer than this Pass (until the
57 * annotated loco::Node is erased). So, do not use large data in the
58 * annotation to avoid excessive memory usage.
60 class DataFormatAnnotation final : public loco::NodeAnnotation
63 DataFormatAnnotation(const DataFormat &format) : _format{format}
69 const DataFormat &format(void) const { return _format; }
75 void set_data_format(loco::Node *node, const DataFormat &format)
77 node->annot(std::make_unique<DataFormatAnnotation>(format));
80 DataFormat get_data_format(loco::Node *node)
82 assert(node->annot<DataFormatAnnotation>() != nullptr);
83 return node->annot<DataFormatAnnotation>()->format();
86 bool has_data_format(loco::Node *node) { return node->annot<DataFormatAnnotation>() != nullptr; }
88 bool check_4d_transpose(loco::Node *node, const std::vector<int32_t> indices)
90 assert(indices.size() == 4);
92 auto trans = dynamic_cast<luci::CircleTranspose *>(node);
96 if (not trans->perm())
99 auto perm = dynamic_cast<luci::CircleConst *>(trans->perm());
100 // Only const perm is supported
104 if (perm->dtype() != loco::DataType::S32)
107 if (perm->size<loco::DataType::S32>() != 4)
110 for (uint32_t i = 0; i < 4; i++)
112 if (perm->at<loco::DataType::S32>(i) != indices[i])
119 luci::CircleTranspose *create_4d_transpose(luci::CircleNode *node,
120 const std::vector<int32_t> indices)
122 assert(indices.size() == 4);
124 auto name = node->name();
125 assert(name.length() > 0);
127 auto perm = node->graph()->nodes()->create<luci::CircleConst>();
128 perm->dtype(loco::DataType::S32);
129 perm->size<loco::DataType::S32>(4);
132 for (uint32_t i = 0; i < 4; i++)
133 perm->at<loco::DataType::S32>(i) = indices[i];
134 perm->shape_status(luci::ShapeStatus::VALID);
136 auto make_string = [](const std::vector<int32_t> &nums) {
138 for (auto num : nums)
140 if (str.length() > 0)
142 str += std::to_string(num);
147 auto str_indices = make_string(indices);
149 perm->name(name + "/Transpose_" + str_indices + "/perm");
151 auto trans = node->graph()->nodes()->create<luci::CircleTranspose>();
153 trans->name(name + "/Transpose_" + str_indices);
154 luci::add_origin(trans, luci::get_origin(node));
159 luci::CircleTranspose *create_Nd_transpose(luci::CircleNode *node,
160 const std::vector<int32_t> indices)
162 auto name = node->name();
163 assert(name.length() > 0);
165 auto perm = node->graph()->nodes()->create<luci::CircleConst>();
166 perm->dtype(loco::DataType::S32);
167 perm->size<loco::DataType::S32>(indices.size());
169 perm->dim(0) = indices.size();
170 for (uint32_t i = 0; i < indices.size(); i++)
171 perm->at<loco::DataType::S32>(i) = indices[i];
172 perm->shape_status(luci::ShapeStatus::VALID);
174 auto make_string = [](const std::vector<int32_t> &nums) {
176 for (auto num : nums)
178 if (str.length() > 0)
180 str += std::to_string(num);
185 auto str_indices = make_string(indices);
187 perm->name(name + "/Transpose_" + str_indices + "/perm");
189 auto trans = node->graph()->nodes()->create<luci::CircleTranspose>();
191 trans->name(name + "/Transpose_" + str_indices);
192 luci::add_origin(trans, luci::get_origin(node));
/**
 * @brief Map an axis of a 4D NCHW tensor to the same logical axis in NHWC.
 *
 * N -> 0, C -> 3, H -> 1, W -> 2. Negative axes count from the back (-1 is W).
 * Used for CONCATENATION axes and MEAN reduction indices.
 *
 * @throws std::runtime_error if @p axis is outside [-4, 4)
 */
int32_t nchw_axis_to_nhwc(int32_t axis)
{
  // A negative axis below -4 becomes negative here and wraps to a large
  // unsigned value, so the single range check below rejects it too.
  uint32_t pos_axis = axis >= 0 ? static_cast<uint32_t>(axis) : static_cast<uint32_t>(axis + 4);
  static constexpr uint32_t to_nhwc[4] = {0, 3, 1, 2};
  if (pos_axis > 3)
    throw std::runtime_error("Axis must be in range [-4, 4)");
  return static_cast<int32_t>(to_nhwc[pos_axis]);
}
206 luci::CircleTranspose *create_post_transpose(luci::CircleNode *node)
208 return create_4d_transpose(node, {0, 3, 1, 2});
211 luci::CircleTranspose *create_pre_transpose(luci::CircleNode *node)
213 return create_4d_transpose(node, {0, 2, 3, 1});
216 bool check_4d_reshape(loco::Node *node, const std::vector<int32_t> indices)
218 assert(indices.size() == 4); // FIX_CALLER_UNLESS
220 auto reshape = dynamic_cast<luci::CircleReshape *>(node);
224 if (reshape->rank() != 4)
227 auto input = loco::must_cast<luci::CircleNode *>(reshape->tensor());
228 if (input->shape_status() != luci::ShapeStatus::VALID)
231 if (reshape->shape_status() != luci::ShapeStatus::VALID)
234 if (!(input->dim(0) == reshape->dim(indices[0])) ||
235 !(input->dim(1) == reshape->dim(indices[1])) ||
236 !(input->dim(2) == reshape->dim(indices[2])) || !(input->dim(3) == reshape->dim(indices[3])))
242 // Check if Reshape that converts NCHW -> NHWC
243 bool is_pre_reshape(loco::Node *node) { return check_4d_reshape(node, {0, 3, 1, 2}); }
245 // Check if Reshape that converts NHWC -> NCHW
246 bool is_post_reshape(loco::Node *node) { return check_4d_reshape(node, {0, 2, 3, 1}); }
248 bool is_post_transpose(loco::Node *node) { return check_4d_transpose(node, {0, 3, 1, 2}); }
250 bool is_pre_transpose(loco::Node *node) { return check_4d_transpose(node, {0, 2, 3, 1}); }
252 uint32_t cal_offset(const loco::TensorShape &dimension, const uint32_t *indices)
254 return indices[0] * dimension.dim(1).value() * dimension.dim(2).value() *
255 dimension.dim(3).value() +
256 indices[1] * dimension.dim(2).value() * dimension.dim(3).value() +
257 indices[2] * dimension.dim(3).value() + indices[3];
260 luci::CircleConst *create_NHWC_paddings(luci::CircleConst *paddings)
262 // paddings shape is (4,2) (it was checked by is_NCHW)
263 assert(paddings != nullptr);
264 assert(paddings->rank() == 2);
265 assert(paddings->dim(0).value() == 4);
266 assert(paddings->dim(1).value() == 2);
268 // paddings for idx 0~3 are 0 (checked by is_NCHW)
269 assert(paddings->at<loco::DataType::S32>(0) == 0);
270 assert(paddings->at<loco::DataType::S32>(1) == 0);
271 assert(paddings->at<loco::DataType::S32>(2) == 0);
272 assert(paddings->at<loco::DataType::S32>(3) == 0);
274 auto name = paddings->name();
275 assert(name.length() > 0);
277 auto nhwc_paddings = paddings->graph()->nodes()->create<luci::CircleConst>();
278 nhwc_paddings->dtype(loco::DataType::S32);
279 nhwc_paddings->shape({4, 2});
280 nhwc_paddings->shape_status(luci::ShapeStatus::VALID);
281 nhwc_paddings->size<loco::DataType::S32>(4 * 2);
282 nhwc_paddings->name(name + "_NHWC");
284 for (uint32_t dim = 0; dim < 4; dim++)
286 for (uint32_t i = 0; i < 2; i++)
292 // get third dimension (H in NCHW)
293 data = paddings->at<loco::DataType::S32>(2 * 2 + i);
297 // get fourth dimension (W in NCHW)
298 data = paddings->at<loco::DataType::S32>(3 * 2 + i);
301 nhwc_paddings->at<loco::DataType::S32>(dim * 2 + i) = data;
304 return nhwc_paddings;
307 luci::CircleConst *create_NHWC_rindices(luci::CircleConst *rindices)
309 assert(rindices != nullptr); // FIX_CALLER_UNLESS
311 if (rindices->dtype() != loco::DataType::S32)
314 auto nhwc_rindices = luci::clone(rindices);
315 auto name = rindices->name();
316 assert(name.length() > 0); // FIX_CALLER_UNLESS
317 nhwc_rindices->name(name + "_NHWC");
319 auto size = nhwc_rindices->size<loco::DataType::S32>();
320 for (uint32_t i = 0; i < size; i++)
322 nhwc_rindices->at<loco::DataType::S32>(i) =
323 nchw_axis_to_nhwc(rindices->at<loco::DataType::S32>(i));
326 return nhwc_rindices;
329 luci::CircleConst *create_NHWC_from_NCHW(luci::CircleConst *constant)
332 assert(constant->rank() == 4);
334 // TODO: Support non-float types
335 if (constant->dtype() != loco::DataType::FLOAT32)
337 INFO(l) << "Non-float type constant: " << constant->name() << std::endl;
341 loco::TensorShape nchw_dimension{constant->dim(0), constant->dim(1), constant->dim(2),
343 loco::TensorShape nhwc_dimension{constant->dim(0), constant->dim(2), constant->dim(3),
346 auto name = constant->name();
347 assert(name.length() > 0);
349 auto nhwc_const = constant->graph()->nodes()->create<luci::CircleConst>();
350 nhwc_const->dtype(constant->dtype());
352 nhwc_const->dim(0).set(constant->dim(0).value());
353 nhwc_const->dim(1).set(constant->dim(2).value());
354 nhwc_const->dim(2).set(constant->dim(3).value());
355 nhwc_const->dim(3).set(constant->dim(1).value());
356 nhwc_const->shape_status(luci::ShapeStatus::VALID);
357 nhwc_const->size<loco::DataType::FLOAT32>(constant->size<loco::DataType::FLOAT32>());
358 nhwc_const->name(name + "_NHWC");
360 for (uint32_t n = 0; n < nchw_dimension.dim(0).value(); n++)
362 for (uint32_t c = 0; c < nchw_dimension.dim(1).value(); c++)
364 for (uint32_t h = 0; h < nchw_dimension.dim(2).value(); h++)
366 for (uint32_t w = 0; w < nchw_dimension.dim(3).value(); w++)
368 uint32_t nchw_indices[4] = {n, c, h, w};
369 uint32_t nhwc_indices[4] = {n, h, w, c};
371 constant->at<loco::DataType::FLOAT32>(cal_offset(nchw_dimension, nchw_indices));
372 nhwc_const->at<loco::DataType::FLOAT32>(cal_offset(nhwc_dimension, nhwc_indices)) = data;
380 // NOTE Following conditions can be extended later
382 // Find PAD with an NCHW pattern described below
383 // - Paddings shape : [4, 2]
384 // - Paddings value : [[0, 0], [0, 0], [h_t, h_b], [w_t, w_b]]]
385 bool is_NCHW(const luci::CirclePad *node)
387 const auto paddings = dynamic_cast<luci::CircleConst *>(node->paddings());
388 // Non-const paddings is not supported
389 if (paddings == nullptr)
392 if (paddings->rank() != 2)
395 if (paddings->dim(0).value() != 4 || paddings->dim(1).value() != 2)
398 // Only check the first two dimensions
399 for (uint32_t dim = 0; dim < 2; dim++)
401 for (uint32_t i = 0; i < 2; i++)
403 auto data = paddings->at<loco::DataType::S32>(dim * 2 + i);
412 // NOTE Copied from is_NCHW(CirclePad)
413 bool is_NCHW(const luci::CirclePadV2 *node)
415 const auto paddings = dynamic_cast<luci::CircleConst *>(node->paddings());
416 // Non-const paddings is not supported
417 if (paddings == nullptr)
420 if (paddings->rank() != 2)
423 if (paddings->dim(0).value() != 4 || paddings->dim(1).value() != 2)
426 // Only check the first two dimensions
427 for (uint32_t dim = 0; dim < 2; dim++)
429 for (uint32_t i = 0; i < 2; i++)
431 auto data = paddings->at<loco::DataType::S32>(dim * 2 + i);
440 // NOTE Following conditions can be extended later
441 // NOTE Used for Maximum, Miminum as ReLU/ReLU6
443 // Find T with an NCHW pattern described below
444 // - Input (non-constant) shape : [N, C, H, W]
445 // - Input (constant) shape : [1] or []
446 // - Output shape : [N, C, H, W]
448 bool is_NCHW_with_s_const(const T *node, luci::CircleNode *&pred_node,
449 luci::CircleConst *&comp_const)
451 auto x = dynamic_cast<luci::CircleConst *>(node->x());
452 auto y = dynamic_cast<luci::CircleConst *>(node->y());
454 if (x != nullptr && y == nullptr)
456 pred_node = loco::must_cast<luci::CircleNode *>(node->y());
459 else if (x == nullptr && y != nullptr)
461 pred_node = loco::must_cast<luci::CircleNode *>(node->x());
466 // Ignore if T does not have a comp_const input.
470 if (pred_node->rank() != 4)
474 const auto const_rank = comp_const->rank();
475 if (const_rank == 0 || (const_rank == 1 && comp_const->dim(0).value() == 1))
480 // NOTE Following conditions can be extended later
482 // Find MUL with an NCHW pattern described below
483 // - Input (non-constant) shape : [N, C, H, W]
484 // - Input (constant) shape : [1, C, 1, 1], [N, C, H, W] or a scalar (1)
485 // - Output shape : [N, C, H, W]
486 bool is_NCHW_with_const(const luci::CircleMul *node, luci::CircleNode *&pred_node,
487 luci::CircleConst *&multiplier)
489 auto x = dynamic_cast<luci::CircleConst *>(node->x());
490 auto y = dynamic_cast<luci::CircleConst *>(node->y());
492 if (x != nullptr && y == nullptr)
494 pred_node = loco::must_cast<luci::CircleNode *>(node->y());
497 else if (x == nullptr && y != nullptr)
499 pred_node = loco::must_cast<luci::CircleNode *>(node->x());
504 // Ignore if MUL does not have a multiplier input.
508 if (pred_node->rank() != 4)
511 const auto const_rank = multiplier->rank();
512 // Support Rank 4 or scalar (rank 0 or 1)
513 if (const_rank != 4 && const_rank != 0 && const_rank != 1)
516 const auto input_cdim = pred_node->dim(1);
517 const auto output_cdim = node->dim(1);
521 bool supported_shape = false;
523 // Check multiplier is (1, C, 1, 1)
524 if (is_same_shape(multiplier, {1, node->dim(1), 1, 1}))
525 supported_shape = true;
527 // Check multiplier is (N, C, H, W)
528 if (is_same_shape(multiplier, {node->dim(0), node->dim(1), node->dim(2), node->dim(3)}))
529 supported_shape = true;
531 return supported_shape;
533 if (input_cdim == output_cdim)
539 // We assume ADD with const input is NCHW if,
540 // Input shape: (N, C, H, W)
541 // Output shape: (N, C, H, W)
542 // 1. Const shape is (1, C, 1, 1), (N, C, H, W) or a scalar (1)
543 // 2. Input, Output, Const have the same C.
544 bool is_NCHW_with_const(const luci::CircleAdd *node, luci::CircleNode *&pred_node,
545 luci::CircleConst *&beta)
547 auto x = dynamic_cast<luci::CircleConst *>(node->x());
548 auto y = dynamic_cast<luci::CircleConst *>(node->y());
550 if (x != nullptr && y == nullptr)
552 pred_node = loco::must_cast<luci::CircleNode *>(node->y());
555 else if (x == nullptr && y != nullptr)
557 pred_node = loco::must_cast<luci::CircleNode *>(node->x());
562 // Ignore if ADD does not have a constant input.
566 if (pred_node->rank() != 4)
569 const auto const_rank = beta->rank();
570 // Support Rank 4 or scalar (rank 0 or 1)
571 if (const_rank != 4 && const_rank != 0 && const_rank != 1)
574 const auto input_cdim = pred_node->dim(1);
575 const auto output_cdim = node->dim(1);
579 bool supported_shape = false;
581 // Check beta is (1, C, 1, 1)
582 if (is_same_shape(beta, {1, node->dim(1), 1, 1}))
583 supported_shape = true;
585 // Check beta is (N, C, H, W)
586 if (is_same_shape(beta, {node->dim(0), node->dim(1), node->dim(2), node->dim(3)}))
587 supported_shape = true;
589 return supported_shape;
591 if (input_cdim == output_cdim)
597 // We assume SUB with const input is NCHW if,
598 // Input shape: (N, C, H, W)
599 // Output shape: (N, C, H, W)
600 // 1. Const shape is (1, C, 1, 1), (N, C, H, W) or a scalar (1)
601 // 2. Input, Output, Const have the same C.
602 bool is_NCHW_with_const(const luci::CircleSub *node, const luci::CircleNode *pred_node,
603 const luci::CircleConst *subtract)
605 assert(pred_node != nullptr);
606 assert(subtract != nullptr);
608 if (pred_node->rank() != 4)
611 const auto const_rank = subtract->rank();
612 // Support Rank 4 or scalar (rank 0 or 1)
613 if (const_rank != 4 && const_rank != 0 && const_rank != 1)
616 const auto input_cdim = pred_node->dim(1);
617 const auto output_cdim = node->dim(1);
621 bool supported_shape = false;
623 // Check subtract is (1, C, 1, 1)
624 if (is_same_shape(subtract, {1, node->dim(1), 1, 1}))
625 supported_shape = true;
627 // Check subtract is (N, C, H, W)
628 if (is_same_shape(subtract, {node->dim(0), node->dim(1), node->dim(2), node->dim(3)}))
629 supported_shape = true;
631 return supported_shape;
633 if (input_cdim == output_cdim)
639 template <class T> bool convert_unary_features(T *node)
641 const auto pred_node = loco::must_cast<luci::CircleNode *>(node->features());
642 auto pre_trans = create_pre_transpose(node);
643 pre_trans->a(pred_node);
644 node->features(pre_trans);
646 // Do shape inference for this node again.
647 node->shape_status(luci::ShapeStatus::UNDEFINED);
649 auto post_trans = create_post_transpose(node);
650 loco::replace(node).with(post_trans);
657 template <class T> bool convert_unary_x(T *node)
659 const auto pred_node = loco::must_cast<luci::CircleNode *>(node->x());
660 auto pre_trans = create_pre_transpose(node);
661 pre_trans->a(pred_node);
664 // Do shape inference for this node again.
665 node->shape_status(luci::ShapeStatus::UNDEFINED);
667 auto post_trans = create_post_transpose(node);
668 loco::replace(node).with(post_trans);
675 class ConvertNCHWToNHWC final : public luci::CircleNodeMutableVisitor<bool>
678 bool visit(luci::CircleNode *node)
680 throw std::runtime_error(node->name() + " is an unsupported operator.");
683 bool visit(luci::CircleInput *node)
685 const auto n = node->dim(0);
686 const auto c = node->dim(1);
687 const auto h = node->dim(2);
688 const auto w = node->dim(3);
694 // Do shape inference for this node again.
695 node->shape_status(luci::ShapeStatus::UNDEFINED);
697 // Insert post-tranpose
698 auto post_trans = create_post_transpose(node);
699 loco::replace(node).with(post_trans);
703 // Update graph input
704 auto graph_inputs = node->graph()->inputs();
705 auto graph_input = graph_inputs->at(node->index());
706 graph_input->shape({n, h, w, c});
711 bool visit(luci::CircleOutput *node)
713 // Insert pre-transpose
714 auto pre_trans = create_pre_transpose(node);
715 pre_trans->a(node->from());
717 node->from(pre_trans);
719 // Do shape inference for this node again.
720 node->shape_status(luci::ShapeStatus::UNDEFINED);
722 // Update graph output
723 const auto n = node->dim(0).value();
724 const auto c = node->dim(1).value();
725 const auto h = node->dim(2).value();
726 const auto w = node->dim(3).value();
728 auto graph_outputs = node->graph()->outputs();
729 auto graph_output = graph_outputs->at(node->index());
730 graph_output->shape({n, h, w, c});
735 bool visit(luci::CircleAdd *node)
737 luci::CircleNode *pred_node = nullptr;
738 luci::CircleConst *beta = nullptr;
740 if (is_NCHW_with_const(node, pred_node, beta))
742 auto pre_trans = create_pre_transpose(node);
743 pre_trans->a(pred_node);
745 if (beta->rank() == 4)
747 auto nhwc_const = create_NHWC_from_NCHW(beta);
748 if (nhwc_const == nullptr)
755 else if (beta == nullptr)
757 // Both inputs are not constant.
758 // In this case, we cannot distinguish NCHW from NHWC,
759 // so just insert Transpose Ops.
760 auto pre_trans_x = create_pre_transpose(node);
761 pre_trans_x->a(node->x());
762 node->x(pre_trans_x);
764 auto pre_trans_y = create_pre_transpose(node);
765 pre_trans_y->a(node->y());
766 node->y(pre_trans_y);
773 // Do shape inference for this node again.
774 node->shape_status(luci::ShapeStatus::UNDEFINED);
776 auto post_trans = create_post_transpose(node);
777 loco::replace(node).with(post_trans);
783 bool visit(luci::CircleConcatenation *node)
785 const auto num_values = node->numValues();
786 for (uint32_t i = 0; i < num_values; i++)
788 auto pred_node = loco::must_cast<luci::CircleNode *>(node->values(i));
789 auto pre_trans = create_pre_transpose(node);
790 pre_trans->a(pred_node);
791 node->values(i, pre_trans);
794 // Do shape inference for this node again.
795 node->shape_status(luci::ShapeStatus::UNDEFINED);
797 node->axis(nchw_axis_to_nhwc(node->axis()));
799 auto post_trans = create_post_transpose(node);
800 loco::replace(node).with(post_trans);
807 bool visit(luci::CircleLeakyRelu *node)
809 return convert_unary_features<luci::CircleLeakyRelu>(node);
812 bool visit(luci::CircleLogistic *node) { return convert_unary_x<luci::CircleLogistic>(node); }
814 bool visit(luci::CircleMaximum *node)
816 luci::CircleNode *pred_node = nullptr;
817 luci::CircleConst *comp_constant = nullptr;
819 if (is_NCHW_with_s_const<luci::CircleMaximum>(node, pred_node, comp_constant))
821 auto pre_trans = create_pre_transpose(node);
822 pre_trans->a(pred_node);
827 // TODO support other cases
831 // Do shape inference for this node again.
832 node->shape_status(luci::ShapeStatus::UNDEFINED);
834 auto post_trans = create_post_transpose(node);
835 loco::replace(node).with(post_trans);
841 bool visit(luci::CircleMean *node)
843 auto input = loco::must_cast<luci::CircleNode *>(node->input());
844 if (input->rank() != 4)
847 auto rindices = dynamic_cast<luci::CircleConst *>(node->reduction_indices());
851 auto nhwc_rindices = create_NHWC_rindices(rindices);
852 if (not nhwc_rindices)
855 auto pre_trans = create_pre_transpose(node);
857 node->input(pre_trans);
859 // Do shape inference for this node again.
860 node->shape_status(luci::ShapeStatus::UNDEFINED);
862 node->reduction_indices(nhwc_rindices);
864 if (node->keep_dims())
866 auto post_trans = create_post_transpose(node);
867 loco::replace(node).with(post_trans);
874 // node->keep_dims() == false
875 // 1D output never needs a transpose
876 if (node->rank() <= 1)
879 std::vector<bool> reduced_dims_nhwc(4, false);
880 uint32_t num_reduced_indices = nhwc_rindices->size<loco::DataType::S32>();
882 for (uint32_t ri = 0; ri < num_reduced_indices; ++ri)
884 reduced_dims_nhwc[nhwc_rindices->at<loco::DataType::S32>(ri)] = true;
887 // if channel dimension has been reduced, we don't need a transpose
888 if (reduced_dims_nhwc[3])
891 // likewise, if both space dimensions are reduced, no transpose is needed
892 if (reduced_dims_nhwc[1] && reduced_dims_nhwc[2])
895 std::vector<int32_t> post_trans_ind;
896 // case 1: only N is reduced
897 if (num_reduced_indices == 1 && reduced_dims_nhwc[0])
898 post_trans_ind = {2, 0, 1};
900 // case 2: only H or W is reduced
901 if (num_reduced_indices == 1 && (reduced_dims_nhwc[1] || reduced_dims_nhwc[2]))
902 post_trans_ind = {0, 2, 1};
904 // case 3: N and either H or W are reduced
905 if (num_reduced_indices == 2)
906 post_trans_ind = {1, 0};
908 auto post_trans = create_Nd_transpose(node, post_trans_ind);
909 loco::replace(node).with(post_trans);
916 bool visit(luci::CircleMinimum *node)
918 luci::CircleNode *pred_node = nullptr;
919 luci::CircleConst *comp_constant = nullptr;
921 if (is_NCHW_with_s_const<luci::CircleMinimum>(node, pred_node, comp_constant))
923 auto pre_trans = create_pre_transpose(node);
924 pre_trans->a(pred_node);
929 // TODO support other cases
933 // Do shape inference for this node again.
934 node->shape_status(luci::ShapeStatus::UNDEFINED);
936 auto post_trans = create_post_transpose(node);
937 loco::replace(node).with(post_trans);
943 bool visit(luci::CircleMul *node)
947 luci::CircleNode *pred_node = nullptr;
948 luci::CircleConst *multiplier = nullptr;
950 if (is_NCHW_with_const(node, pred_node, multiplier))
952 auto pre_trans = create_pre_transpose(node);
953 pre_trans->a(pred_node);
956 if (multiplier->rank() == 4)
958 auto nhwc_const = create_NHWC_from_NCHW(multiplier);
962 else if (multiplier == nullptr)
964 // Only support for input rank 4
965 auto input_x = loco::must_cast<luci::CircleNode *>(node->x());
966 if (input_x->rank() != 4)
968 auto input_y = loco::must_cast<luci::CircleNode *>(node->y());
969 if (input_y->rank() != 4)
972 auto pre_trans_x = create_pre_transpose(node);
973 pre_trans_x->a(input_x);
974 node->x(pre_trans_x);
976 auto pre_trans_y = create_pre_transpose(node);
977 pre_trans_y->a(input_y);
978 node->y(pre_trans_y);
985 // Do shape inference for this node again.
986 node->shape_status(luci::ShapeStatus::UNDEFINED);
988 auto post_trans = create_post_transpose(node);
989 loco::replace(node).with(post_trans);
995 bool visit(luci::CircleNeg *node) { return convert_unary_x<luci::CircleNeg>(node); }
997 bool visit(luci::CirclePad *node)
1002 const auto pred_node = loco::must_cast<luci::CircleNode *>(node->input());
1003 auto pre_trans = create_pre_transpose(node);
1004 pre_trans->a(pred_node);
1005 node->input(pre_trans);
1007 auto nchw_paddings = loco::must_cast<luci::CircleConst *>(node->paddings());
1008 const auto nhwc_paddings = create_NHWC_paddings(nchw_paddings);
1009 node->paddings(nhwc_paddings);
1011 // Do shape inference for this node again.
1012 node->shape_status(luci::ShapeStatus::UNDEFINED);
1014 auto post_trans = create_post_transpose(node);
1015 loco::replace(node).with(post_trans);
1017 post_trans->a(node);
1022 bool visit(luci::CirclePadV2 *node)
1027 const auto pred_node = loco::must_cast<luci::CircleNode *>(node->input());
1028 auto pre_trans = create_pre_transpose(node);
1029 pre_trans->a(pred_node);
1030 node->input(pre_trans);
1032 auto nchw_paddings = loco::must_cast<luci::CircleConst *>(node->paddings());
1033 const auto nhwc_paddings = create_NHWC_paddings(nchw_paddings);
1034 node->paddings(nhwc_paddings);
1036 // Do shape inference for this node again.
1037 node->shape_status(luci::ShapeStatus::UNDEFINED);
1039 auto post_trans = create_post_transpose(node);
1040 loco::replace(node).with(post_trans);
1042 post_trans->a(node);
1047 bool visit(luci::CircleRelu *node) { return convert_unary_features<luci::CircleRelu>(node); }
1049 bool visit(luci::CircleRelu6 *node) { return convert_unary_features<luci::CircleRelu6>(node); }
1051 bool visit(luci::CircleRsqrt *node) { return convert_unary_x<luci::CircleRsqrt>(node); }
1053 bool visit(luci::CircleSquaredDifference *node)
1055 // TODO support CircleConst input
1056 if (dynamic_cast<luci::CircleConst *>(node->x()) != nullptr)
1058 if (dynamic_cast<luci::CircleConst *>(node->y()) != nullptr)
1061 auto input_x = loco::must_cast<luci::CircleNode *>(node->x());
1062 if (input_x->rank() != 4)
1064 auto input_y = loco::must_cast<luci::CircleNode *>(node->y());
1065 if (input_y->rank() != 4)
1068 auto pre_trans_x = create_pre_transpose(node);
1069 pre_trans_x->a(input_x);
1070 node->x(pre_trans_x);
1072 auto pre_trans_y = create_pre_transpose(node);
1073 pre_trans_y->a(input_y);
1074 node->y(pre_trans_y);
1076 // Do shape inference for this node again.
1077 node->shape_status(luci::ShapeStatus::UNDEFINED);
1079 auto post_trans = create_post_transpose(node);
1080 loco::replace(node).with(post_trans);
1082 post_trans->a(node);
1086 bool visit(luci::CircleSub *node)
1088 luci::CircleNode *pred_node = nullptr;
1089 luci::CircleConst *subtract = nullptr;
1091 auto const_x = dynamic_cast<luci::CircleConst *>(node->x());
1092 auto const_y = dynamic_cast<luci::CircleConst *>(node->y());
1094 if (const_x != nullptr && const_y == nullptr)
1096 // case of subtract - pred_node
1097 pred_node = loco::must_cast<luci::CircleNode *>(node->y());
1100 if (!is_NCHW_with_const(node, pred_node, subtract))
1103 auto pre_trans = create_pre_transpose(node);
1104 pre_trans->a(pred_node);
1106 if (subtract->rank() == 4)
1108 auto nhwc_const = create_NHWC_from_NCHW(subtract);
1109 if (nhwc_const == nullptr)
1111 node->x(nhwc_const);
1115 else if (const_x == nullptr && const_y != nullptr)
1117 // case of pred_node - subtract
1118 pred_node = loco::must_cast<luci::CircleNode *>(node->x());
1121 if (!is_NCHW_with_const(node, pred_node, subtract))
1124 auto pre_trans = create_pre_transpose(node);
1125 pre_trans->a(pred_node);
1127 if (subtract->rank() == 4)
1129 auto nhwc_const = create_NHWC_from_NCHW(subtract);
1130 if (nhwc_const == nullptr)
1132 node->y(nhwc_const);
1137 else if (const_x == nullptr && const_y == nullptr)
1139 // Both inputs are not constant.
1140 // In this case, we cannot distinguish NCHW from NHWC,
1141 // so just insert Transpose Ops.
1142 // Only support for input rank 4.
1143 auto input_x = loco::must_cast<luci::CircleNode *>(node->x());
1144 if (input_x->rank() != 4)
1146 auto input_y = loco::must_cast<luci::CircleNode *>(node->y());
1147 if (input_y->rank() != 4)
1150 auto pre_trans_x = create_pre_transpose(node);
1151 pre_trans_x->a(input_x);
1152 node->x(pre_trans_x);
1154 auto pre_trans_y = create_pre_transpose(node);
1155 pre_trans_y->a(input_y);
1156 node->y(pre_trans_y);
1159 // Do shape inference for this node again.
1160 node->shape_status(luci::ShapeStatus::UNDEFINED);
1162 auto post_trans = create_post_transpose(node);
1163 loco::replace(node).with(post_trans);
1165 post_trans->a(node);
1175 bool ConvertNCHWToNHWCPass::run(loco::Graph *g)
1178 INFO(l) << "ConvertNCHWToNHWCPass Start" << std::endl;
1180 // Annotate NHWC operators
1181 // NHWC operators are detected by pattern matching
1184 // pre-Transose (or pre-Reshape) + [intermediate Ops] + post-Transpose (or post-Reshape)
1186 // [intermediate Ops] are annotated as NHWC
1188 // NOTE A single pre-Transpose/Reshape can have multiple post-Transpose/Reshape.
1190 // pre-Transpose --- [intermediate Ops] --- post-Transpose
1192 // +--[intermediate Ops] --- post-Transpose
1193 for (auto node : loco::postorder_traversal(loco::output_nodes(g)))
1195 if (has_data_format(node))
1198 if (is_pre_transpose(node) || is_pre_reshape(node))
1200 // For recursive call of lambda
1201 std::function<void(loco::Node *)> set_data_format_to_succs;
1202 set_data_format_to_succs = [&](loco::Node *n) {
1203 for (auto succ : loco::succs(n))
1206 if (is_post_transpose(succ) || is_post_reshape(succ))
1209 if (not has_data_format(succ))
1211 set_data_format(succ, DataFormat::NHWC);
1214 set_data_format_to_succs(succ);
1218 set_data_format_to_succs(node);
1222 // Annotate NCHW operators
1223 for (auto node : loco::active_nodes(loco::output_nodes(g)))
1225 auto circle_node = loco::must_cast<luci::CircleNode *>(node);
1226 switch (circle_node->opcode())
1228 // List of supported Ops
1229 case luci::CircleOpcode::CIRCLEINPUT:
1230 if (!_preserve_input && !has_data_format(node))
1232 set_data_format(node, DataFormat::NCHW);
1235 case luci::CircleOpcode::CIRCLEOUTPUT:
1236 if (!_preserve_output && !has_data_format(node))
1238 set_data_format(node, DataFormat::NCHW);
1241 case luci::CircleOpcode::ADD:
1242 case luci::CircleOpcode::CONCATENATION:
1243 case luci::CircleOpcode::LEAKY_RELU:
1244 case luci::CircleOpcode::LOGISTIC:
1245 case luci::CircleOpcode::MAXIMUM:
1246 case luci::CircleOpcode::MEAN:
1247 case luci::CircleOpcode::MINIMUM:
1248 case luci::CircleOpcode::MUL:
1249 case luci::CircleOpcode::NEG:
1250 case luci::CircleOpcode::PAD:
1251 case luci::CircleOpcode::PADV2:
1252 case luci::CircleOpcode::RELU:
1253 case luci::CircleOpcode::RELU6:
1254 case luci::CircleOpcode::RSQRT:
1255 case luci::CircleOpcode::SQUARED_DIFFERENCE:
1256 case luci::CircleOpcode::SUB:
1257 if (!has_data_format(node))
1259 set_data_format(node, DataFormat::NCHW);
1267 bool changed = false;
1268 for (auto node : loco::active_nodes(loco::output_nodes(g)))
1270 if (!has_data_format(node))
1275 else if (get_data_format(node) == DataFormat::NHWC)
1277 // Already converted to NHWC
1280 else if (has_dynamic_shape(node))
1282 // This pass only works for static-shaped node
1283 INFO(l) << "Skip the node with a dynamic shape." << std::endl;
1288 ConvertNCHWToNHWC converter;
1289 auto circle_node = loco::must_cast<luci::CircleNode *>(node);
1290 if (circle_node->rank() != 4)
1292 // TODO replace the check above with the input rank check, and remove the condition below
1293 if (not dynamic_cast<luci::CircleMean *>(node))
1297 if (circle_node->accept(&converter))
1299 set_data_format(node, DataFormat::NHWC);
1309 INFO(l) << "ConvertNCHWToNHWCPass End" << std::endl;