/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "luci/Pass/ConvertNCHWToNHWCPass.h"
#include "CircleOptimizerUtils.h"

#include <luci/IR/CircleNodes.h>
#include <luci/IR/CircleNodeVisitor.h>
#include <luci/Profile/CircleNodeOrigin.h>
#include <luci/Service/Nodes/CircleConst.h>
#include <luci/Log.h>

#include <cassert>
#include <functional>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
31 // Return true if from can be broadcasted to to
32 // to's shape is [N, C, H, W]
33 bool broadcastable(const luci::CircleConst *from, const luci::CircleNode *to)
35 assert(to->rank() == 4); // FIX_CALLER_UNLESS
37 const auto from_rank = from->rank();
41 // Scalar is always broadcastable
45 for (uint32_t i = 1; i <= from_rank; i++)
47 auto to_index = 4 - i;
48 auto from_index = from_rank - i;
50 if (from->dim(from_index).value() != to->dim(to_index).value() and
51 from->dim(from_index).value() != 1)
58 // Return node with rank 4
59 // node should have rank less than or equal to 4
60 // 1 is inserted to the front of shape if rank is less than 4
61 // For example, [2] -> [1, 1, 1, 2]
62 luci::CircleConst *expand_to_rank_4(luci::CircleConst *node)
64 auto original_rank = node->rank();
66 assert(original_rank <= 4); // FIX_CALLER_UNLESS
68 if (original_rank == 4)
71 std::vector<uint32_t> original_shape;
72 for (uint32_t i = 0; i < original_rank; i++)
74 original_shape.emplace_back(node->dim(i).value());
77 auto cloned = luci::clone(node);
78 cloned->name(cloned->name() + "_rank4");
81 for (uint32_t i = 0; i < (4 - original_rank); i++)
84 for (uint32_t i = 0; i < original_rank; i++)
85 cloned->dim(i + (4 - original_rank)) = original_shape.at(i);
90 bool is_output(const loco::Node *node)
92 auto cnode = loco::must_cast<const luci::CircleNode *>(node);
93 auto opcode = cnode->opcode();
94 if (opcode == luci::CircleOpcode::CIRCLEOUTPUT ||
95 opcode == luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE)
101 bool is_same_shape(const luci::CircleNode *node, const std::vector<loco::Dimension> &shape)
106 if (shape.size() != node->rank())
109 for (uint32_t i = 0; i < shape.size(); i++)
111 if (not(node->dim(i) == shape[i]))
// Memory layout of a 4-D tensor, as recorded by DataFormatAnnotation.
enum class DataFormat
{
  NCHW, // [batch, channel, height, width]
  NHWC  // [batch, height, width, channel]
};
124 * @brief Set annotation for DataFormat (NCHW, NHWC)
126 * @note DataFormatAnnotation will live longer than this Pass (until the
127 * annotated loco::Node is erased). So, do not use large data in the
128 * annotation to avoid excessive memory usage.
130 class DataFormatAnnotation final : public loco::NodeAnnotation
133 DataFormatAnnotation(const DataFormat &format) : _format{format}
139 const DataFormat &format(void) const { return _format; }
145 void set_data_format(loco::Node *node, const DataFormat &format)
147 node->annot(std::make_unique<DataFormatAnnotation>(format));
150 DataFormat get_data_format(loco::Node *node)
152 assert(node->annot<DataFormatAnnotation>() != nullptr);
153 return node->annot<DataFormatAnnotation>()->format();
156 bool has_data_format(loco::Node *node) { return node->annot<DataFormatAnnotation>() != nullptr; }
158 bool check_4d_transpose(loco::Node *node, const std::vector<int32_t> indices)
160 assert(indices.size() == 4);
162 auto trans = dynamic_cast<luci::CircleTranspose *>(node);
166 if (not trans->perm())
169 auto perm = dynamic_cast<luci::CircleConst *>(trans->perm());
170 // Only const perm is supported
174 if (perm->dtype() != loco::DataType::S32)
177 if (perm->size<loco::DataType::S32>() != 4)
180 for (uint32_t i = 0; i < 4; i++)
182 if (perm->at<loco::DataType::S32>(i) != indices[i])
// Create, in node's graph, a CircleTranspose together with an S32 CircleConst
// "perm" holding the four values of `indices`.
// Names: "<name>/Transpose_<indices>" for the Transpose and
// "<name>/Transpose_<indices>/perm" for the perm, where <indices> is the text
// produced by make_string below. The origin of `node` is attached to the
// Transpose. Callers connect the Transpose's input themselves via a().
// `indices` must hold exactly 4 values and `node` must have a non-empty name.
// NOTE(review): `indices` is taken by value (copied per call); a const
//       reference would do.
189 luci::CircleTranspose *create_4d_transpose(luci::CircleNode *node,
190 const std::vector<int32_t> indices)
192 assert(indices.size() == 4);
194 auto name = node->name();
195 assert(name.length() > 0);
// perm constant: S32, 4 elements, element i = indices[i]
197 auto perm = node->graph()->nodes()->create<luci::CircleConst>();
198 perm->dtype(loco::DataType::S32);
199 perm->size<loco::DataType::S32>(4);
202 for (uint32_t i = 0; i < 4; i++)
203 perm->at<loco::DataType::S32>(i) = indices[i];
204 perm->shape_status(luci::ShapeStatus::VALID);
// Render the permutation as text for node names; a separator is placed
// between consecutive numbers (only once str is non-empty).
206 auto make_string = [](const std::vector<int32_t> &nums) {
208 for (auto num : nums)
210 if (str.length() > 0)
212 str += std::to_string(num);
217 auto str_indices = make_string(indices);
219 perm->name(name + "/Transpose_" + str_indices + "/perm");
221 auto trans = node->graph()->nodes()->create<luci::CircleTranspose>();
223 trans->name(name + "/Transpose_" + str_indices);
// Keep origin tracking so the new node maps back to the node it was made for.
224 luci::add_origin(trans, luci::get_origin(node));
// Create, in node's graph, a CircleTranspose together with an S32 CircleConst
// "perm" of length indices.size() holding `indices`. Same naming and origin
// handling as create_4d_transpose, but for any permutation length (used for
// the rank-reduced outputs of Reduce ops).
// `node` must have a non-empty name.
// NOTE(review): an empty `indices` is not rejected here and would produce a
//       zero-length perm; consider asserting !indices.empty().
// NOTE(review): duplicates create_4d_transpose apart from the size check.
229 luci::CircleTranspose *create_Nd_transpose(luci::CircleNode *node,
230 const std::vector<int32_t> indices)
232 auto name = node->name();
233 assert(name.length() > 0);
235 auto perm = node->graph()->nodes()->create<luci::CircleConst>();
236 perm->dtype(loco::DataType::S32);
237 perm->size<loco::DataType::S32>(indices.size());
// perm's single dimension is the permutation length
239 perm->dim(0) = indices.size();
240 for (uint32_t i = 0; i < indices.size(); i++)
241 perm->at<loco::DataType::S32>(i) = indices[i];
242 perm->shape_status(luci::ShapeStatus::VALID);
// Render the permutation as text for node names; a separator is placed
// between consecutive numbers (only once str is non-empty).
244 auto make_string = [](const std::vector<int32_t> &nums) {
246 for (auto num : nums)
248 if (str.length() > 0)
250 str += std::to_string(num);
255 auto str_indices = make_string(indices);
257 perm->name(name + "/Transpose_" + str_indices + "/perm");
259 auto trans = node->graph()->nodes()->create<luci::CircleTranspose>();
261 trans->name(name + "/Transpose_" + str_indices);
262 luci::add_origin(trans, luci::get_origin(node));
// Map an axis of an NCHW tensor to the corresponding axis of the NHWC tensor:
//   NCHW axis : 0 (N)  1 (C)  2 (H)  3 (W)
//   NHWC axis : 0      3      1      2
// Negative axes count from the end, so the accepted range is [-4, 4).
// Throws std::runtime_error for an axis outside that range.
//
// This helper is shared by Concatenation, SplitV and the Reduce* ops, so the
// error message must not name a specific operator; it reports the bad value.
int32_t nchw_axis_to_nhwc(int32_t axis)
{
  static constexpr int32_t to_nhwc[4] = {0, 3, 1, 2};

  const int32_t pos_axis = axis >= 0 ? axis : axis + 4;
  if (pos_axis < 0 || pos_axis > 3)
    throw std::runtime_error("Axis must be in range [-4, 4), but got " + std::to_string(axis));

  return to_nhwc[pos_axis];
}
276 luci::CircleTranspose *create_post_transpose(luci::CircleNode *node)
278 return create_4d_transpose(node, {0, 3, 1, 2});
281 luci::CircleTranspose *create_pre_transpose(luci::CircleNode *node)
283 return create_4d_transpose(node, {0, 2, 3, 1});
286 bool check_4d_reshape(loco::Node *node, const std::vector<int32_t> indices)
288 assert(indices.size() == 4); // FIX_CALLER_UNLESS
290 auto reshape = dynamic_cast<luci::CircleReshape *>(node);
294 if (reshape->rank() != 4)
297 auto input = loco::must_cast<luci::CircleNode *>(reshape->tensor());
298 if (input->shape_status() != luci::ShapeStatus::VALID)
301 if (input->rank() != 4)
304 if (reshape->shape_status() != luci::ShapeStatus::VALID)
307 if (!(input->dim(0) == reshape->dim(indices[0])) ||
308 !(input->dim(1) == reshape->dim(indices[1])) ||
309 !(input->dim(2) == reshape->dim(indices[2])) || !(input->dim(3) == reshape->dim(indices[3])))
315 // Check if Reshape that converts NCHW -> NHWC
316 bool is_pre_reshape(loco::Node *node) { return check_4d_reshape(node, {0, 3, 1, 2}); }
318 // Check if Reshape that converts NHWC -> NCHW
319 bool is_post_reshape(loco::Node *node) { return check_4d_reshape(node, {0, 2, 3, 1}); }
321 bool is_post_transpose(loco::Node *node) { return check_4d_transpose(node, {0, 3, 1, 2}); }
323 bool is_pre_transpose(loco::Node *node) { return check_4d_transpose(node, {0, 2, 3, 1}); }
325 uint32_t cal_offset(const loco::TensorShape &dimension, const uint32_t *indices)
327 return indices[0] * dimension.dim(1).value() * dimension.dim(2).value() *
328 dimension.dim(3).value() +
329 indices[1] * dimension.dim(2).value() * dimension.dim(3).value() +
330 indices[2] * dimension.dim(3).value() + indices[3];
333 luci::CircleConst *create_NHWC_paddings(luci::CircleConst *paddings)
335 // paddings shape is (4,2) (it was checked by is_NCHW)
336 assert(paddings != nullptr);
337 assert(paddings->rank() == 2);
338 assert(paddings->dim(0).value() == 4);
339 assert(paddings->dim(1).value() == 2);
341 // paddings for idx 0~3 are 0 (checked by is_NCHW)
342 assert(paddings->at<loco::DataType::S32>(0) == 0);
343 assert(paddings->at<loco::DataType::S32>(1) == 0);
344 assert(paddings->at<loco::DataType::S32>(2) == 0);
345 assert(paddings->at<loco::DataType::S32>(3) == 0);
347 auto name = paddings->name();
348 assert(name.length() > 0);
350 auto nhwc_paddings = paddings->graph()->nodes()->create<luci::CircleConst>();
351 nhwc_paddings->dtype(loco::DataType::S32);
352 nhwc_paddings->shape({4, 2});
353 nhwc_paddings->shape_status(luci::ShapeStatus::VALID);
354 nhwc_paddings->size<loco::DataType::S32>(4 * 2);
355 nhwc_paddings->name(name + "_NHWC");
357 for (uint32_t dim = 0; dim < 4; dim++)
359 for (uint32_t i = 0; i < 2; i++)
365 // get third dimension (H in NCHW)
366 data = paddings->at<loco::DataType::S32>(2 * 2 + i);
370 // get fourth dimension (W in NCHW)
371 data = paddings->at<loco::DataType::S32>(3 * 2 + i);
374 nhwc_paddings->at<loco::DataType::S32>(dim * 2 + i) = data;
377 return nhwc_paddings;
380 luci::CircleConst *create_NHWC_rindices(luci::CircleConst *rindices)
382 assert(rindices != nullptr); // FIX_CALLER_UNLESS
384 if (rindices->dtype() != loco::DataType::S32)
387 auto nhwc_rindices = luci::clone(rindices);
388 auto name = rindices->name();
389 assert(name.length() > 0); // FIX_CALLER_UNLESS
390 nhwc_rindices->name(name + "_NHWC");
392 auto size = nhwc_rindices->size<loco::DataType::S32>();
393 for (uint32_t i = 0; i < size; i++)
395 nhwc_rindices->at<loco::DataType::S32>(i) =
396 nchw_axis_to_nhwc(rindices->at<loco::DataType::S32>(i));
399 return nhwc_rindices;
402 luci::CircleConst *create_NHWC_from_NCHW(luci::CircleConst *constant)
405 assert(constant->rank() == 4);
407 // TODO: Support non-float types
408 if (constant->dtype() != loco::DataType::FLOAT32)
410 INFO(l) << "Non-float type constant: " << constant->name() << std::endl;
414 loco::TensorShape nchw_dimension{constant->dim(0), constant->dim(1), constant->dim(2),
416 loco::TensorShape nhwc_dimension{constant->dim(0), constant->dim(2), constant->dim(3),
419 auto name = constant->name();
420 assert(name.length() > 0);
422 auto nhwc_const = constant->graph()->nodes()->create<luci::CircleConst>();
423 nhwc_const->dtype(constant->dtype());
425 nhwc_const->dim(0).set(constant->dim(0).value());
426 nhwc_const->dim(1).set(constant->dim(2).value());
427 nhwc_const->dim(2).set(constant->dim(3).value());
428 nhwc_const->dim(3).set(constant->dim(1).value());
429 nhwc_const->shape_status(luci::ShapeStatus::VALID);
430 nhwc_const->size<loco::DataType::FLOAT32>(constant->size<loco::DataType::FLOAT32>());
431 nhwc_const->name(name + "_NHWC");
433 for (uint32_t n = 0; n < nchw_dimension.dim(0).value(); n++)
435 for (uint32_t c = 0; c < nchw_dimension.dim(1).value(); c++)
437 for (uint32_t h = 0; h < nchw_dimension.dim(2).value(); h++)
439 for (uint32_t w = 0; w < nchw_dimension.dim(3).value(); w++)
441 uint32_t nchw_indices[4] = {n, c, h, w};
442 uint32_t nhwc_indices[4] = {n, h, w, c};
444 constant->at<loco::DataType::FLOAT32>(cal_offset(nchw_dimension, nchw_indices));
445 nhwc_const->at<loco::DataType::FLOAT32>(cal_offset(nhwc_dimension, nhwc_indices)) = data;
453 // NOTE Following conditions can be extended later
455 // Find PAD with an NCHW pattern described below
456 // - Paddings shape : [4, 2]
457 // - Paddings value : [[0, 0], [0, 0], [h_t, h_b], [w_t, w_b]]]
458 bool is_NCHW(const luci::CirclePad *node)
460 const auto paddings = dynamic_cast<luci::CircleConst *>(node->paddings());
461 // Non-const paddings is not supported
462 if (paddings == nullptr)
465 if (paddings->rank() != 2)
468 if (paddings->dim(0).value() != 4 || paddings->dim(1).value() != 2)
471 // Only check the first two dimensions
472 for (uint32_t dim = 0; dim < 2; dim++)
474 for (uint32_t i = 0; i < 2; i++)
476 auto data = paddings->at<loco::DataType::S32>(dim * 2 + i);
485 // NOTE Copied from is_NCHW(CirclePad)
486 bool is_NCHW(const luci::CirclePadV2 *node)
488 const auto paddings = dynamic_cast<luci::CircleConst *>(node->paddings());
489 // Non-const paddings is not supported
490 if (paddings == nullptr)
493 if (paddings->rank() != 2)
496 if (paddings->dim(0).value() != 4 || paddings->dim(1).value() != 2)
499 // Only check the first two dimensions
500 for (uint32_t dim = 0; dim < 2; dim++)
502 for (uint32_t i = 0; i < 2; i++)
504 auto data = paddings->at<loco::DataType::S32>(dim * 2 + i);
513 bool is_const(const loco::Node *node)
515 if (not dynamic_cast<const luci::CircleConst *>(node))
521 bool is_scalar_const(const loco::Node *node)
523 auto const_node = dynamic_cast<const luci::CircleConst *>(node);
527 const auto const_rank = const_node->rank();
530 // 2. rank = 1, dimension = 1
534 if (const_rank == 1 && const_node->dim(0).value() == 1)
540 // NOTE Following conditions can be extended later
542 // Find MUL with an NCHW pattern described below
543 // - Input (non-constant) shape : [N, C, H, W]
544 // - Input (constant) shape : broadcastable to [N, C, H, W]
545 // - Output shape : [N, C, H, W]
546 bool is_NCHW_with_const(const luci::CircleMul *node, luci::CircleNode *&pred_node,
547 luci::CircleConst *&multiplier)
549 auto x = dynamic_cast<luci::CircleConst *>(node->x());
550 auto y = dynamic_cast<luci::CircleConst *>(node->y());
552 if (x != nullptr && y == nullptr)
554 pred_node = loco::must_cast<luci::CircleNode *>(node->y());
557 else if (x == nullptr && y != nullptr)
559 pred_node = loco::must_cast<luci::CircleNode *>(node->x());
564 // Ignore if MUL does not have a multiplier input.
568 if (pred_node->rank() != 4)
571 if (not broadcastable(multiplier, node))
574 multiplier = expand_to_rank_4(multiplier);
579 // We assume ADD with const input is NCHW if,
580 // Input shape: (N, C, H, W)
581 // Output shape: (N, C, H, W)
582 // 1. Const shape is (1, C, 1, 1), (N, C, H, W) or a scalar (1)
583 // 2. Input, Output, Const have the same C.
584 bool is_NCHW_with_const(const luci::CircleAdd *node, luci::CircleNode *&pred_node,
585 luci::CircleConst *&beta)
587 auto x = dynamic_cast<luci::CircleConst *>(node->x());
588 auto y = dynamic_cast<luci::CircleConst *>(node->y());
590 if (x != nullptr && y == nullptr)
592 pred_node = loco::must_cast<luci::CircleNode *>(node->y());
595 else if (x == nullptr && y != nullptr)
597 pred_node = loco::must_cast<luci::CircleNode *>(node->x());
602 // Ignore if ADD does not have a constant input.
606 if (pred_node->rank() != 4)
609 if (not broadcastable(beta, node))
612 beta = expand_to_rank_4(beta);
617 // We assume SUB with const input is NCHW if,
618 // Input shape: (N, C, H, W)
619 // Output shape: (N, C, H, W)
620 // 1. Const shape is (1, C, 1, 1), (N, C, H, W) or a scalar (1)
621 // 2. Input, Output, Const have the same C.
622 bool is_NCHW_with_const(const luci::CircleSub *node, const luci::CircleNode *pred_node,
623 const luci::CircleConst *subtract)
625 assert(pred_node != nullptr);
626 assert(subtract != nullptr);
628 if (pred_node->rank() != 4)
631 const auto const_rank = subtract->rank();
632 // Support Rank 4 or scalar (rank 0 or 1)
633 if (const_rank != 4 && const_rank != 0 && const_rank != 1)
636 const auto input_cdim = pred_node->dim(1);
637 const auto output_cdim = node->dim(1);
641 bool supported_shape = false;
643 // Check subtract is (1, C, 1, 1)
644 if (is_same_shape(subtract, {1, node->dim(1), 1, 1}))
645 supported_shape = true;
647 // Check subtract is (N, C, H, W)
648 if (is_same_shape(subtract, {node->dim(0), node->dim(1), node->dim(2), node->dim(3)}))
649 supported_shape = true;
651 return supported_shape;
653 if (input_cdim == output_cdim)
659 template <class T> bool convert_unary_features(T *node)
661 const auto pred_node = loco::must_cast<luci::CircleNode *>(node->features());
662 auto pre_trans = create_pre_transpose(node);
663 pre_trans->a(pred_node);
664 node->features(pre_trans);
666 // Do shape inference for this node again.
667 node->shape_status(luci::ShapeStatus::UNDEFINED);
669 auto post_trans = create_post_transpose(node);
670 loco::replace(node).with(post_trans);
677 template <class T> bool convert_unary_x(T *node)
679 const auto pred_node = loco::must_cast<luci::CircleNode *>(node->x());
680 auto pre_trans = create_pre_transpose(node);
681 pre_trans->a(pred_node);
684 // Do shape inference for this node again.
685 node->shape_status(luci::ShapeStatus::UNDEFINED);
687 auto post_trans = create_post_transpose(node);
688 loco::replace(node).with(post_trans);
695 template <class T> bool convert_unary_logits(T *node)
697 const auto pred_node = loco::must_cast<luci::CircleNode *>(node->logits());
698 auto pre_trans = create_pre_transpose(node);
699 pre_trans->a(pred_node);
700 node->logits(pre_trans);
702 // Do shape inference for this node again.
703 node->shape_status(luci::ShapeStatus::UNDEFINED);
705 auto post_trans = create_post_transpose(node);
706 loco::replace(node).with(post_trans);
713 class ConvertNCHWToNHWC final : public luci::CircleNodeMutableVisitor<bool>
716 bool visit(luci::CircleNode *node)
718 throw std::runtime_error(node->name() + " is an unsupported operator.");
// Convert a graph input from NCHW to NHWC.
// The NCHW dims are captured first. The node's shape is then marked for
// re-inference, a post-Transpose (NHWC -> NCHW) replaces the node in the graph
// and is fed from it, and the graph-level input shape is set to [n, h, w, c].
// Dims 0..3 are read unconditionally, so the input is assumed to be rank 4.
// NOTE(review): presumably the node's own dims are rewritten to NHWC before
//       shape inference runs again — confirm.
721 bool visit(luci::CircleInput *node)
723 const auto n = node->dim(0);
724 const auto c = node->dim(1);
725 const auto h = node->dim(2);
726 const auto w = node->dim(3);
732 // Do shape inference for this node again.
733 node->shape_status(luci::ShapeStatus::UNDEFINED);
735 // Insert post-transpose
736 auto post_trans = create_post_transpose(node);
737 loco::replace(node).with(post_trans);
741 // Update graph input
742 auto graph_inputs = node->graph()->inputs();
743 auto graph_input = graph_inputs->at(node->index());
744 graph_input->shape({n, h, w, c});
// Convert a graph output from NCHW to NHWC.
// A pre-Transpose (NCHW -> NHWC) is placed between the producer and this
// output node, the output's shape is marked for re-inference, and the
// graph-level output shape is set to [n, h, w, c].
749 bool visit(luci::CircleOutput *node)
751 // Insert pre-transpose
752 auto pre_trans = create_pre_transpose(node);
753 pre_trans->a(node->from());
755 node->from(pre_trans);
757 // Do shape inference for this node again.
758 node->shape_status(luci::ShapeStatus::UNDEFINED);
760 // Update graph output
// Only the status was reset above; the dims read here are assumed to still
// be the NCHW ones, hence the n, c, h, w naming.
761 const auto n = node->dim(0).value();
762 const auto c = node->dim(1).value();
763 const auto h = node->dim(2).value();
764 const auto w = node->dim(3).value();
766 auto graph_outputs = node->graph()->outputs();
767 auto graph_output = graph_outputs->at(node->index());
768 graph_output->shape({n, h, w, c});
// Convert an NCHW Add to NHWC.
// - One constant operand accepted by is_NCHW_with_const: the constant is
//   replaced by an NHWC copy (create_NHWC_from_NCHW) and the other operand
//   gets a pre-Transpose.
// - beta left nullptr by is_NCHW_with_const: both operands get a
//   pre-Transpose, because their layout cannot be told apart.
// - Any other outcome is presumably rejected — confirm.
// When converted, the node's shape is marked for re-inference and a
// post-Transpose replaces the node in the graph.
// NOTE(review): is_NCHW_with_const appears to leave beta nullptr also when
//       BOTH operands are constant, so that case takes the "not constant"
//       branch and inserts Transposes on constants (correct, but wasteful).
// NOTE(review): unlike visit(CircleMul), that branch does not check that x
//       and y are rank 4 before inserting 4-D Transposes.
773 bool visit(luci::CircleAdd *node)
775 luci::CircleNode *pred_node = nullptr;
776 luci::CircleConst *beta = nullptr;
778 if (is_NCHW_with_const(node, pred_node, beta))
// is_NCHW_with_const expands beta to rank 4 on success
780 assert(beta->rank() == 4); // FIX is_NCHW_with_const unless
// nullptr means the constant could not be converted (non-FLOAT32 dtype)
781 auto nhwc_const = create_NHWC_from_NCHW(beta);
782 if (nhwc_const == nullptr)
786 auto pre_trans = create_pre_transpose(node);
787 pre_trans->a(pred_node);
790 else if (beta == nullptr)
792 // Both inputs are not constant.
793 // In this case, we cannot distinguish NCHW from NHWC,
794 // so just insert Transpose Ops.
795 auto pre_trans_x = create_pre_transpose(node);
796 pre_trans_x->a(node->x());
797 node->x(pre_trans_x);
799 auto pre_trans_y = create_pre_transpose(node);
800 pre_trans_y->a(node->y());
801 node->y(pre_trans_y);
808 // Do shape inference for this node again.
809 node->shape_status(luci::ShapeStatus::UNDEFINED);
811 auto post_trans = create_post_transpose(node);
812 loco::replace(node).with(post_trans);
// Convert an NCHW Concatenation: every input gets its own pre-Transpose, the
// concat axis is remapped with nchw_axis_to_nhwc (which throws
// std::runtime_error for an axis outside [-4, 4)), the shape is marked for
// re-inference, and a post-Transpose replaces the node in the graph.
// NOTE(review): confirm all inputs are guaranteed to be rank 4 before this
//       runs, since 4-D Transposes are inserted on each of them.
818 bool visit(luci::CircleConcatenation *node)
820 const auto num_values = node->numValues();
821 for (uint32_t i = 0; i < num_values; i++)
823 auto pred_node = loco::must_cast<luci::CircleNode *>(node->values(i));
824 auto pre_trans = create_pre_transpose(node);
825 pre_trans->a(pred_node);
826 node->values(i, pre_trans);
829 // Do shape inference for this node again.
830 node->shape_status(luci::ShapeStatus::UNDEFINED);
832 node->axis(nchw_axis_to_nhwc(node->axis()));
834 auto post_trans = create_post_transpose(node);
835 loco::replace(node).with(post_trans);
842 bool visit(luci::CircleElu *node) { return convert_unary_features<luci::CircleElu>(node); }
844 bool visit(luci::CircleGelu *node) { return convert_unary_features<luci::CircleGelu>(node); }
846 bool visit(luci::CircleLeakyRelu *node)
848 return convert_unary_features<luci::CircleLeakyRelu>(node);
851 bool visit(luci::CircleLogistic *node) { return convert_unary_x<luci::CircleLogistic>(node); }
// Convert an NCHW Maximum. Handled operand combinations:
// - x non-constant, y scalar constant : only x gets a pre-Transpose
// - x scalar constant, y non-constant : only y gets a pre-Transpose
// - both non-constant                 : both get a pre-Transpose
// Other combinations are left for later (see TODO).
// A scalar-shaped constant (see is_scalar_const) broadcasts identically in
// either layout, so it needs no conversion.
853 bool visit(luci::CircleMaximum *node)
855 if ((not is_const(node->x())) and is_scalar_const(node->y()))
857 auto pre_trans = create_pre_transpose(node);
858 pre_trans->a(node->x());
861 else if (is_scalar_const(node->x()) and (not is_const(node->y())))
863 auto pre_trans = create_pre_transpose(node);
864 pre_trans->a(node->y());
867 else if ((not is_const(node->x())) and (not is_const(node->y())))
869 auto pre_trans_x = create_pre_transpose(node);
870 pre_trans_x->a(node->x());
871 node->x(pre_trans_x);
873 auto pre_trans_y = create_pre_transpose(node);
874 pre_trans_y->a(node->y());
875 node->y(pre_trans_y);
879 // TODO support other cases
883 // Do shape inference for this node again.
884 node->shape_status(luci::ShapeStatus::UNDEFINED);
886 auto post_trans = create_post_transpose(node);
887 loco::replace(node).with(post_trans);
// Convert an NCHW Mean with constant S32 reduction indices.
// The rank-4 input gets a pre-Transpose and the reduction indices are replaced
// by an NHWC copy (create_NHWC_rindices). With keep_dims the output stays rank
// 4 and a regular post-Transpose restores NCHW. Without keep_dims the output
// rank shrinks, so the post-Transpose (if any) depends on which NHWC axes
// were reduced:
//   - rank <= 1 output, C reduced, or both H and W reduced: NCHW and NHWC
//     orders of the remaining axes already agree, no post-Transpose
//   - only N reduced         : [H, W, C] -> [C, H, W]  perm {2, 0, 1}
//   - only H or only W       : [N, X, C] -> [N, C, X]  perm {0, 2, 1}
//   - N and one of H / W     : [X, C]    -> [C, X]     perm {1, 0}
// NOTE(review): num_reduced_indices counts entries of the indices tensor, not
//       distinct axes, so a repeated axis (e.g. NCHW [2, 2]) selects a perm of
//       the wrong rank. Consider counting reduced_dims_nhwc instead.
893 bool visit(luci::CircleMean *node)
895 auto input = loco::must_cast<luci::CircleNode *>(node->input());
896 if (input->rank() != 4)
899 auto rindices = dynamic_cast<luci::CircleConst *>(node->reduction_indices());
// nullptr here means the indices are not S32
903 auto nhwc_rindices = create_NHWC_rindices(rindices);
904 if (not nhwc_rindices)
907 auto pre_trans = create_pre_transpose(node);
909 node->input(pre_trans);
911 // Do shape inference for this node again.
912 node->shape_status(luci::ShapeStatus::UNDEFINED);
914 node->reduction_indices(nhwc_rindices);
916 if (node->keep_dims())
918 auto post_trans = create_post_transpose(node);
919 loco::replace(node).with(post_trans);
926 // node->keep_dims() == false
927 // 1D output never needs a transpose
928 if (node->rank() <= 1)
// Mark which NHWC axes are reduced (indices were made non-negative above).
931 std::vector<bool> reduced_dims_nhwc(4, false);
932 uint32_t num_reduced_indices = nhwc_rindices->size<loco::DataType::S32>();
934 for (uint32_t ri = 0; ri < num_reduced_indices; ++ri)
936 reduced_dims_nhwc[nhwc_rindices->at<loco::DataType::S32>(ri)] = true;
939 // if channel dimension has been reduced, we don't need a transpose
940 if (reduced_dims_nhwc[3])
943 // likewise, if both space dimensions are reduced, no transpose is needed
944 if (reduced_dims_nhwc[1] && reduced_dims_nhwc[2])
947 std::vector<int32_t> post_trans_ind;
948 // case 1: only N is reduced
949 if (num_reduced_indices == 1 && reduced_dims_nhwc[0])
950 post_trans_ind = {2, 0, 1};
952 // case 2: only H or W is reduced
953 if (num_reduced_indices == 1 && (reduced_dims_nhwc[1] || reduced_dims_nhwc[2]))
954 post_trans_ind = {0, 2, 1};
956 // case 3: N and either H or W are reduced
957 if (num_reduced_indices == 2)
958 post_trans_ind = {1, 0};
960 auto post_trans = create_Nd_transpose(node, post_trans_ind);
961 loco::replace(node).with(post_trans);
// Convert an NCHW Minimum whose one operand is non-constant and whose other
// operand is a scalar-shaped constant (see is_scalar_const). Only the
// non-constant operand gets a pre-Transpose; other cases are left for later.
// NOTE(review): unlike visit(CircleMaximum), the case where both operands are
//       non-constant does not appear to be handled here; consider aligning
//       the two.
968 bool visit(luci::CircleMinimum *node)
970 if ((not is_const(node->x())) and is_scalar_const(node->y()))
972 auto pre_trans = create_pre_transpose(node);
973 pre_trans->a(node->x());
976 else if (is_scalar_const(node->x()) and (not is_const(node->y())))
978 auto pre_trans = create_pre_transpose(node);
979 pre_trans->a(node->y());
984 // TODO support other cases
988 // Do shape inference for this node again.
989 node->shape_status(luci::ShapeStatus::UNDEFINED);
991 auto post_trans = create_post_transpose(node);
992 loco::replace(node).with(post_trans);
// Convert an NCHW Mul to NHWC.
// - One constant operand accepted by is_NCHW_with_const: the constant is
//   replaced by an NHWC copy and the other operand gets a pre-Transpose.
// - multiplier left nullptr: both operands must be rank 4, and each gets a
//   pre-Transpose.
// When converted, the node's shape is marked for re-inference and a
// post-Transpose replaces the node in the graph, fed from the node.
// The NHWC constant is always stored as y, whichever operand was constant
// originally; this relies on Mul being commutative.
998 bool visit(luci::CircleMul *node)
1002 luci::CircleNode *pred_node = nullptr;
1003 luci::CircleConst *multiplier = nullptr;
1005 if (is_NCHW_with_const(node, pred_node, multiplier))
// is_NCHW_with_const expands multiplier to rank 4 on success
1007 assert(multiplier->rank() == 4); // FIX is_NCHW_with_const unless
// nullptr means the constant could not be converted (non-FLOAT32 dtype)
1008 auto nhwc_const = create_NHWC_from_NCHW(multiplier);
1009 if (nhwc_const == nullptr)
1011 node->y(nhwc_const);
1013 auto pre_trans = create_pre_transpose(node);
1014 pre_trans->a(pred_node);
1017 else if (multiplier == nullptr)
1019 // Only support for input rank 4
1020 auto input_x = loco::must_cast<luci::CircleNode *>(node->x());
1021 if (input_x->rank() != 4)
1023 auto input_y = loco::must_cast<luci::CircleNode *>(node->y());
1024 if (input_y->rank() != 4)
1027 auto pre_trans_x = create_pre_transpose(node);
1028 pre_trans_x->a(input_x);
1029 node->x(pre_trans_x);
1031 auto pre_trans_y = create_pre_transpose(node);
1032 pre_trans_y->a(input_y);
1033 node->y(pre_trans_y);
1040 // Do shape inference for this node again.
1041 node->shape_status(luci::ShapeStatus::UNDEFINED);
1043 auto post_trans = create_post_transpose(node);
1044 loco::replace(node).with(post_trans);
// Connect post_trans only after the replace, so it is not rewired itself.
1046 post_trans->a(node);
1050 bool visit(luci::CircleNeg *node) { return convert_unary_x<luci::CircleNeg>(node); }
// Convert an NCHW Pad: the input gets a pre-Transpose, the paddings constant
// is replaced by an NHWC copy (create_NHWC_paddings), the shape is marked for
// re-inference, and a post-Transpose replaces the node, fed from it.
// NOTE(review): create_NHWC_paddings asserts the [[0,0],[0,0],...] pattern
//       checked by is_NCHW(CirclePad) and must_cast requires const paddings;
//       presumably this method checks is_NCHW(node) first — confirm.
1052 bool visit(luci::CirclePad *node)
1057 const auto pred_node = loco::must_cast<luci::CircleNode *>(node->input());
1058 auto pre_trans = create_pre_transpose(node);
1059 pre_trans->a(pred_node);
1060 node->input(pre_trans);
1062 auto nchw_paddings = loco::must_cast<luci::CircleConst *>(node->paddings());
1063 const auto nhwc_paddings = create_NHWC_paddings(nchw_paddings);
1064 node->paddings(nhwc_paddings);
1066 // Do shape inference for this node again.
1067 node->shape_status(luci::ShapeStatus::UNDEFINED);
1069 auto post_trans = create_post_transpose(node);
1070 loco::replace(node).with(post_trans);
1072 post_trans->a(node);
// Convert an NCHW PadV2: same rewiring as for CirclePad. The input gets a
// pre-Transpose, the paddings constant is replaced by an NHWC copy, and a
// post-Transpose replaces the node, fed from it.
// NOTE(review): relies on the paddings pattern checked by
//       is_NCHW(CirclePadV2); presumably verified first in this method —
//       confirm.
1077 bool visit(luci::CirclePadV2 *node)
1082 const auto pred_node = loco::must_cast<luci::CircleNode *>(node->input());
1083 auto pre_trans = create_pre_transpose(node);
1084 pre_trans->a(pred_node);
1085 node->input(pre_trans);
1087 auto nchw_paddings = loco::must_cast<luci::CircleConst *>(node->paddings());
1088 const auto nhwc_paddings = create_NHWC_paddings(nchw_paddings);
1089 node->paddings(nhwc_paddings);
1091 // Do shape inference for this node again.
1092 node->shape_status(luci::ShapeStatus::UNDEFINED);
1094 auto post_trans = create_post_transpose(node);
1095 loco::replace(node).with(post_trans);
1097 post_trans->a(node);
// Convert an NCHW ReduceMax with constant S32 reduction indices, the same way
// visit(CircleMean) converts Mean: pre-Transpose on the rank-4 input,
// NHWC-remapped indices, then a post-Transpose chosen by keep_dims and by
// which NHWC axes were reduced:
//   - keep_dims                                 : perm {0, 3, 1, 2}
//   - rank <= 1 output, C reduced, or H and W   : no post-Transpose
//   - only N reduced                            : perm {2, 0, 1}
//   - only H or only W                          : perm {0, 2, 1}
//   - N and one of H / W                        : perm {1, 0}
// NOTE(review): repeated axes in the indices tensor are counted twice by
//       num_reduced_indices, which can select a perm of the wrong rank.
1102 // TODO Reduce duplicate code with CircleMean
1103 bool visit(luci::CircleReduceMax *node)
1105 auto input = loco::must_cast<luci::CircleNode *>(node->input());
1106 if (input->rank() != 4)
1109 auto rindices = dynamic_cast<luci::CircleConst *>(node->reduction_indices());
// nullptr here means the indices are not S32
1113 auto nhwc_rindices = create_NHWC_rindices(rindices);
1114 if (not nhwc_rindices)
1117 auto pre_trans = create_pre_transpose(node);
1118 pre_trans->a(input);
1119 node->input(pre_trans);
1121 // Do shape inference for this node again.
1122 node->shape_status(luci::ShapeStatus::UNDEFINED);
1124 node->reduction_indices(nhwc_rindices);
1126 if (node->keep_dims())
1128 auto post_trans = create_post_transpose(node);
1129 loco::replace(node).with(post_trans);
1131 post_trans->a(node);
1136 // The below codes handle the cases where node->keep_dims() == false
1137 // 1D output never needs a transpose
1138 if (node->rank() <= 1)
// Mark which NHWC axes are reduced (indices were made non-negative above).
1141 std::vector<bool> reduced_dims_nhwc(4, false);
1142 uint32_t num_reduced_indices = nhwc_rindices->size<loco::DataType::S32>();
1144 for (uint32_t ri = 0; ri < num_reduced_indices; ++ri)
1146 reduced_dims_nhwc[nhwc_rindices->at<loco::DataType::S32>(ri)] = true;
1149 // if channel dimension has been reduced, we don't need a transpose
1150 if (reduced_dims_nhwc[3])
1153 // likewise, if both space dimensions are reduced, no transpose is needed
1154 if (reduced_dims_nhwc[1] && reduced_dims_nhwc[2])
1157 std::vector<int32_t> post_trans_ind;
1158 // case 1: only N is reduced
1159 if (num_reduced_indices == 1 && reduced_dims_nhwc[0])
1160 post_trans_ind = {2, 0, 1};
1162 // case 2: only H or W is reduced
1163 if (num_reduced_indices == 1 && (reduced_dims_nhwc[1] || reduced_dims_nhwc[2]))
1164 post_trans_ind = {0, 2, 1};
1166 // case 3: N and either H or W are reduced
1167 if (num_reduced_indices == 2)
1168 post_trans_ind = {1, 0};
1170 auto post_trans = create_Nd_transpose(node, post_trans_ind);
1171 loco::replace(node).with(post_trans);
1173 post_trans->a(node);
// Convert an NCHW ReduceMin with constant S32 reduction indices; the body
// mirrors visit(CircleReduceMax): pre-Transpose on the rank-4 input,
// NHWC-remapped indices, then a post-Transpose chosen by keep_dims and by
// which NHWC axes were reduced:
//   - keep_dims                                 : perm {0, 3, 1, 2}
//   - rank <= 1 output, C reduced, or H and W   : no post-Transpose
//   - only N reduced                            : perm {2, 0, 1}
//   - only H or only W                          : perm {0, 2, 1}
//   - N and one of H / W                        : perm {1, 0}
// NOTE(review): repeated axes in the indices tensor are counted twice by
//       num_reduced_indices, which can select a perm of the wrong rank.
1178 // TODO Reduce duplicate codes with CircleReduceMax
1179 bool visit(luci::CircleReduceMin *node)
1181 auto input = loco::must_cast<luci::CircleNode *>(node->input());
1182 if (input->rank() != 4)
1185 auto rindices = dynamic_cast<luci::CircleConst *>(node->reduction_indices());
// nullptr here means the indices are not S32
1189 auto nhwc_rindices = create_NHWC_rindices(rindices);
1190 if (not nhwc_rindices)
1193 auto pre_trans = create_pre_transpose(node);
1194 pre_trans->a(input);
1195 node->input(pre_trans);
1197 // Do shape inference for this node again.
1198 node->shape_status(luci::ShapeStatus::UNDEFINED);
1200 node->reduction_indices(nhwc_rindices);
1202 if (node->keep_dims())
1204 auto post_trans = create_post_transpose(node);
1205 loco::replace(node).with(post_trans);
1207 post_trans->a(node);
1212 // The below codes handle the cases where node->keep_dims() == false
1213 // 1D output never needs a transpose
1214 if (node->rank() <= 1)
// Mark which NHWC axes are reduced (indices were made non-negative above).
1217 std::vector<bool> reduced_dims_nhwc(4, false);
1218 uint32_t num_reduced_indices = nhwc_rindices->size<loco::DataType::S32>();
1220 for (uint32_t ri = 0; ri < num_reduced_indices; ++ri)
1222 reduced_dims_nhwc[nhwc_rindices->at<loco::DataType::S32>(ri)] = true;
1225 // if channel dimension has been reduced, we don't need a transpose
1226 if (reduced_dims_nhwc[3])
1229 // likewise, if both space dimensions are reduced, no transpose is needed
1230 if (reduced_dims_nhwc[1] && reduced_dims_nhwc[2])
1233 std::vector<int32_t> post_trans_ind;
1234 // case 1: only N is reduced
1235 if (num_reduced_indices == 1 && reduced_dims_nhwc[0])
1236 post_trans_ind = {2, 0, 1};
1238 // case 2: only H or W is reduced
1239 if (num_reduced_indices == 1 && (reduced_dims_nhwc[1] || reduced_dims_nhwc[2]))
1240 post_trans_ind = {0, 2, 1};
1242 // case 3: N and either H or W are reduced
1243 if (num_reduced_indices == 2)
1244 post_trans_ind = {1, 0};
1246 auto post_trans = create_Nd_transpose(node, post_trans_ind);
1247 loco::replace(node).with(post_trans);
1249 post_trans->a(node);
1254 bool visit(luci::CircleRelu *node) { return convert_unary_features<luci::CircleRelu>(node); }
1256 bool visit(luci::CircleRelu6 *node) { return convert_unary_features<luci::CircleRelu6>(node); }
1258 bool visit(luci::CircleRsqrt *node) { return convert_unary_x<luci::CircleRsqrt>(node); }
// Convert an NCHW SplitV whose split_dim is a constant S32 tensor with
// exactly one element. The axis value is remapped to its NHWC position, the
// input gets a pre-Transpose, the shape is marked for re-inference, and every
// CircleSplitVOut successor is replaced by its own post-Transpose.
// NOTE(review): the split_dim constant is modified in place; if it is shared
//       with other nodes they will see the remapped axis too. Consider cloning
//       it, as create_NHWC_rindices does for reduction indices.
1260 bool visit(luci::CircleSplitV *node)
1262 // Change split dimension
1263 auto axis = dynamic_cast<luci::CircleConst *>(node->split_dim());
1267 if (axis->dtype() != loco::DataType::S32)
1270 if (axis->size<loco::DataType::S32>() != 1)
// nchw_axis_to_nhwc throws std::runtime_error for an axis outside [-4, 4)
1273 axis->at<loco::DataType::S32>(0) = nchw_axis_to_nhwc(axis->at<loco::DataType::S32>(0));
1275 // Insert pre-transpose
1276 const auto pred_node = loco::must_cast<luci::CircleNode *>(node->input());
1277 auto pre_trans = create_pre_transpose(node);
1278 pre_trans->a(pred_node);
1279 node->input(pre_trans);
1281 // Do shape inference for this node again.
1282 node->shape_status(luci::ShapeStatus::UNDEFINED);
1284 // Insert post-transposes
// Each SplitV output is a separate CircleSplitVOut node; each gets its own
// post-Transpose.
1285 for (auto succ : loco::succs(node))
1287 auto svo = loco::must_cast<luci::CircleSplitVOut *>(succ);
1289 auto post_trans = create_post_transpose(svo);
1290 loco::replace(svo).with(post_trans);
// Convert an NCHW SquaredDifference whose inputs are both non-constant and
// rank 4. Both inputs get a pre-Transpose, the shape is marked for
// re-inference, and a post-Transpose replaces the node, fed from it.
1297 bool visit(luci::CircleSquaredDifference *node)
1299 // TODO support CircleConst input
1300 if (dynamic_cast<luci::CircleConst *>(node->x()) != nullptr)
1302 if (dynamic_cast<luci::CircleConst *>(node->y()) != nullptr)
1305 auto input_x = loco::must_cast<luci::CircleNode *>(node->x());
1306 if (input_x->rank() != 4)
1308 auto input_y = loco::must_cast<luci::CircleNode *>(node->y());
1309 if (input_y->rank() != 4)
1312 auto pre_trans_x = create_pre_transpose(node);
1313 pre_trans_x->a(input_x);
1314 node->x(pre_trans_x);
1316 auto pre_trans_y = create_pre_transpose(node);
1317 pre_trans_y->a(input_y);
1318 node->y(pre_trans_y);
1320 // Do shape inference for this node again.
1321 node->shape_status(luci::ShapeStatus::UNDEFINED);
1323 auto post_trans = create_post_transpose(node);
1324 loco::replace(node).with(post_trans);
1326 post_trans->a(node);
// Convert an NCHW Sub to NHWC. Three input patterns are visible below:
//   * constant - non-constant: the non-constant (y) input gets a pre-transpose,
//     and a rank-4 constant x is replaced by create_NHWC_from_NCHW's copy;
//   * non-constant - constant: the mirror image, with the constant on y;
//   * non-constant - non-constant: both inputs must be rank 4, and each gets
//     its own pre-transpose.
// Afterwards the node's shape is invalidated, and a post-transpose replaces
// the node for its users.
// NOTE(review): if BOTH x and y are CircleConst, none of the three branches
// matches and control appears to fall through to the post-transpose insertion
// without any pre-transpose or constant conversion. Confirm that an early
// `return false` exists for that case (short lines are elided in this listing).
1330 bool visit(luci::CircleSub *node)
1332 luci::CircleNode *pred_node = nullptr;
1333 luci::CircleConst *subtract = nullptr;
1335 auto const_x = dynamic_cast<luci::CircleConst *>(node->x());
1336 auto const_y = dynamic_cast<luci::CircleConst *>(node->y());
1338 if (const_x != nullptr && const_y == nullptr)
1340 // case of subtract - pred_node
1341 pred_node = loco::must_cast<luci::CircleNode *>(node->y());
// NOTE(review): `subtract` is dereferenced below, so it is presumably assigned
// (to const_x) on an elided line or filled in by is_NCHW_with_const. Confirm.
1344 if (!is_NCHW_with_const(node, pred_node, subtract))
1347 auto pre_trans = create_pre_transpose(node);
1348 pre_trans->a(pred_node);
// A constant whose rank is not 4 is left unchanged. Presumably
// is_NCHW_with_const only accepts layout-independent shapes (e.g. scalars)
// at lower rank. Confirm.
1350 if (subtract->rank() == 4)
1352 auto nhwc_const = create_NHWC_from_NCHW(subtract);
1353 if (nhwc_const == nullptr)
1355 node->x(nhwc_const);
// NOTE(review): pre_trans is not attached to `node` on any visible line.
// Presumably node->y(pre_trans) is on an elided line.
1359 else if (const_x == nullptr && const_y != nullptr)
1361 // case of pred_node - subtract
1362 pred_node = loco::must_cast<luci::CircleNode *>(node->x());
1365 if (!is_NCHW_with_const(node, pred_node, subtract))
1368 auto pre_trans = create_pre_transpose(node);
1369 pre_trans->a(pred_node);
1371 if (subtract->rank() == 4)
1373 auto nhwc_const = create_NHWC_from_NCHW(subtract);
1374 if (nhwc_const == nullptr)
1376 node->y(nhwc_const);
// NOTE(review): as above, node->x(pre_trans) is presumably on an elided line.
1381 else if (const_x == nullptr && const_y == nullptr)
1383 // Both inputs are not constant.
1384 // In this case, we cannot distinguish NCHW from NHWC,
1385 // so just insert Transpose Ops.
1386 // Only support for input rank 4.
1387 auto input_x = loco::must_cast<luci::CircleNode *>(node->x());
1388 if (input_x->rank() != 4)
1390 auto input_y = loco::must_cast<luci::CircleNode *>(node->y());
1391 if (input_y->rank() != 4)
1394 auto pre_trans_x = create_pre_transpose(node);
1395 pre_trans_x->a(input_x);
1396 node->x(pre_trans_x);
1398 auto pre_trans_y = create_pre_transpose(node);
1399 pre_trans_y->a(input_y);
1400 node->y(pre_trans_y);
1403 // Do shape inference for this node again.
1404 node->shape_status(luci::ShapeStatus::UNDEFINED);
// Redirect users of `node` first, then make `node` the post-transpose input.
1406 auto post_trans = create_post_transpose(node);
1407 loco::replace(node).with(post_trans);
1409 post_trans->a(node);
// Convert NCHW-annotated operators of graph `g` to NHWC. Three phases:
//   1. Annotate NHWC regions. Nodes enclosed between a pre-Transpose/Reshape
//      and post-Transpose/Reshape nodes are marked DataFormat::NHWC.
//   2. Annotate NCHW operators. Graph inputs/outputs (unless _preserve_input /
//      _preserve_output is set) and a fixed list of supported opcodes are
//      marked DataFormat::NCHW if they are not already annotated.
//   3. Convert. Each NCHW-annotated, statically shaped node is visited by
//      ConvertNCHWToNHWC; if the visitor accepts it, the node is re-annotated
//      as NHWC.
// NOTE(review): the updates to `changed` and the final return are on lines
// elided from this listing. Presumably the pass returns `changed`. Confirm.
1419 bool ConvertNCHWToNHWCPass::run(loco::Graph *g)
1422 INFO(l) << "ConvertNCHWToNHWCPass Start" << std::endl;
1424 // Annotate NHWC operators
1425 // NHWC operators are detected by pattern matching
1428 // pre-Transpose (or pre-Reshape) + [intermediate Ops] + post-Transpose (or post-Reshape)
1430 // [intermediate Ops] are annotated as NHWC
1432 // NOTE A single pre-Transpose/Reshape can have multiple post-Transpose/Reshape.
1434 // pre-Transpose --- [intermediate Ops] --- post-Transpose
1436 // +--[intermediate Ops] --- post-Transpose
1438 // NOTE Intermediate Ops SHOULD NOT contain pre-Transpose/Reshape
// Phase 1. Nodes that already carry a data format are presumably skipped on
// an elided `continue`.
1439 for (auto node : loco::postorder_traversal(loco::output_nodes(g)))
1441 if (has_data_format(node))
1444 if (is_pre_transpose(node) || is_pre_reshape(node))
1446 std::set<loco::Node *> intermediate;
1448 // Variable to check intermediate Ops contain pre-Transpose/Reshape
1449 bool has_pre = false;
1451 // Variable to check the pattern is closed with post-Transpose/Reshape
1452 bool is_closed = true;
1454 // For recursive call of lambda
1455 std::function<void(loco::Node *)> collect_intermediate;
// Depth-first walk over the successors of `n`. Nodes already collected are
// skipped, and post-Transpose/Reshape nodes end a path. Judging by the flag
// names, reaching another pre-Transpose/Reshape presumably sets has_pre, and
// reaching a graph output presumably clears is_closed (those assignments are
// on elided lines). Every other successor is collected and recursed into.
1456 collect_intermediate = [&](loco::Node *n) {
1457 for (auto succ : loco::succs(n))
1459 // Skip unnecessary traversal
1460 if (intermediate.find(succ) != intermediate.end())
1464 if (is_post_transpose(succ) || is_post_reshape(succ))
1467 if (is_pre_transpose(succ) || is_pre_reshape(succ))
1473 if (is_output(succ))
1479 intermediate.emplace(succ);
1481 collect_intermediate(succ);
1485 collect_intermediate(node);
// The region is annotated only if it holds no nested pre-Transpose/Reshape
// and is closed by post-Transpose/Reshape nodes. Existing annotations are
// never overwritten.
1487 if (has_pre or not is_closed)
1490 for (auto inter : intermediate)
1492 if (not has_data_format(inter))
1493 set_data_format(inter, DataFormat::NHWC);
1498 // Annotate NCHW operators
// Phase 2. Annotations from phase 1 take precedence, since every case below
// checks has_data_format first.
1499 for (auto node : loco::active_nodes(loco::output_nodes(g)))
1501 auto circle_node = loco::must_cast<luci::CircleNode *>(node);
1502 switch (circle_node->opcode())
1504 // List of supported Ops
1505 case luci::CircleOpcode::CIRCLEINPUT:
1506 if (!_preserve_input && !has_data_format(node))
1508 set_data_format(node, DataFormat::NCHW);
1511 case luci::CircleOpcode::CIRCLEOUTPUT:
1512 if (!_preserve_output && !has_data_format(node))
1514 set_data_format(node, DataFormat::NCHW);
1517 // SOFTMAX, LOG_SOFTMAX are not converted, because
1518 // tflite/circle always apply them along the last axis
// Every opcode below shares the same body: annotate as NCHW if unannotated.
// Each one needs a matching visit() overload in ConvertNCHWToNHWC.
1519 case luci::CircleOpcode::ADD:
1520 case luci::CircleOpcode::CONCATENATION:
1521 case luci::CircleOpcode::ELU:
1522 case luci::CircleOpcode::GELU:
1523 case luci::CircleOpcode::LEAKY_RELU:
1524 case luci::CircleOpcode::LOGISTIC:
1525 case luci::CircleOpcode::MAXIMUM:
1526 case luci::CircleOpcode::MEAN:
1527 case luci::CircleOpcode::MINIMUM:
1528 case luci::CircleOpcode::MUL:
1529 case luci::CircleOpcode::NEG:
1530 case luci::CircleOpcode::PAD:
1531 case luci::CircleOpcode::PADV2:
1532 case luci::CircleOpcode::REDUCE_MAX:
1533 case luci::CircleOpcode::REDUCE_MIN:
1534 case luci::CircleOpcode::RELU:
1535 case luci::CircleOpcode::RELU6:
1536 case luci::CircleOpcode::RSQRT:
1537 case luci::CircleOpcode::SPLIT_V:
1538 case luci::CircleOpcode::SQUARED_DIFFERENCE:
1539 case luci::CircleOpcode::SUB:
1540 if (!has_data_format(node))
1542 set_data_format(node, DataFormat::NCHW);
// Phase 3: convert every node annotated NCHW.
1550 bool changed = false;
1551 for (auto node : loco::active_nodes(loco::output_nodes(g)))
// Unannotated nodes, nodes already NHWC, and dynamically shaped nodes are
// presumably skipped on elided `continue` lines.
1553 if (!has_data_format(node))
1558 else if (get_data_format(node) == DataFormat::NHWC)
1560 // Already converted to NHWC
1563 else if (has_dynamic_shape(node))
1565 // This pass only works for static-shaped node
1566 INFO(l) << "Skip the node with a dynamic shape." << std::endl;
1571 ConvertNCHWToNHWC converter;
1572 auto circle_node = loco::must_cast<luci::CircleNode *>(node);
// Nodes whose rank is not 4 are skipped, except Mean/ReduceMax/ReduceMin.
// Presumably this is because a reduction can yield a lower-rank output from
// a rank-4 input (see the TODO).
1573 if (circle_node->rank() != 4)
1575 // TODO replace the check above with the input rank check, and remove the condition below
1576 if (not dynamic_cast<luci::CircleMean *>(node) and
1577 not dynamic_cast<luci::CircleReduceMax *>(node) and
1578 not dynamic_cast<luci::CircleReduceMin *>(node))
// The visitor returns true only when it actually rewrote the node.
1582 if (circle_node->accept(&converter))
1584 set_data_format(node, DataFormat::NHWC);
1594 INFO(l) << "ConvertNCHWToNHWCPass End" << std::endl;