2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include <logo/Phase.h>
19 #include "luci/Pass/ConvertNCHWToNHWCPass.h"
20 #include "luci/Pass/CircleShapeInferencePass.h"
22 #include <luci/IR/CircleNodes.h>
24 #include <gtest/gtest.h>
30 * Graph with a single Op (example: Add).
33 * - All Ops including Input/Output are NCHW.
42 * - All Ops including Input/Output are NHWC.
// Base test fixture: builds Input -> <body> -> Output, where the graph
// input/output and the CircleInput/CircleOutput nodes are FLOAT32 with
// NCHW shape {1, 16, 4, 4}. Subclasses provide <body> via insertGraphBody().
61 SimpleGraph() = default;
// Create the Circle input/output nodes and bind them to graph-level I/O.
66 input = g.nodes()->create<luci::CircleInput>();
67 output = g.nodes()->create<luci::CircleOutput>();
69 output->name("output");
71 auto graph_input = g.inputs()->create();
72 input->index(graph_input->index());
73 auto graph_output = g.outputs()->create();
74 output->index(graph_output->index());
76 graph_input->dtype(loco::DataType::FLOAT32);
77 input->dtype(loco::DataType::FLOAT32);
78 output->dtype(loco::DataType::FLOAT32);
79 graph_output->dtype(loco::DataType::FLOAT32);
// All four shapes are NCHW with 16 channels.
81 uint32_t channel_size = 16;
82 graph_input->shape({1, channel_size, 4, 4});
83 input->shape({1, channel_size, 4, 4});
84 output->shape({1, channel_size, 4, 4});
85 graph_output->shape({1, channel_size, 4, 4});
// The subclass builds the op(s) under test on top of `input`; the node it
// returns feeds the output.
87 auto graph_body = insertGraphBody(input);
88 output->from(graph_body);
91 virtual ~SimpleGraph() = default;
// Build the graph body consuming `input`; return the node to be wired to output.
94 virtual loco::Node *insertGraphBody(loco::Node *input) = 0;
98 luci::CircleInput *input = nullptr;
99 luci::CircleOutput *output = nullptr;
// Graph body: CircleAdd of the input and a per-channel FLOAT32 constant `beta`
// of shape {1, 16, 1, 1} holding 0..15 (the Add tests read beta via add->y()).
102 class AddGraph final : public SimpleGraph
105 loco::Node *insertGraphBody(loco::Node *input) override
107 add = g.nodes()->create<luci::CircleAdd>();
108 beta = g.nodes()->create<luci::CircleConst>();
110 add->dtype(loco::DataType::FLOAT32);
111 beta->dtype(loco::DataType::FLOAT32);
113 uint32_t channel_size = 16;
114 add->shape({1, channel_size, 4, 4});
115 beta->shape({1, channel_size, 1, 1});
117 beta->size<loco::DataType::FLOAT32>(channel_size);
118 for (uint32_t i = 0; i < channel_size; i++)
120 beta->at<loco::DataType::FLOAT32>(i) = i;
// Reshape beta to the full NCHW shape {1, 16, 4, 4} (used by Add_NCHW_const).
133 void update_const_shape_to_nchw(void)
135 uint32_t channel_size = 16;
136 beta->shape({1, channel_size, 4, 4});
138 beta->size<loco::DataType::FLOAT32>(channel_size * 4 * 4);
// Fill every element after the resize, not just the first channel_size ones,
// so the constant's contents are fully defined by this fixture.
139 for (uint32_t i = 0; i < channel_size * 4 * 4; i++)
141 beta->at<loco::DataType::FLOAT32>(i) = i;
146 luci::CircleAdd *add = nullptr;
147 luci::CircleConst *beta = nullptr;
// Graph body: Reshape(input -> {1, 4, 4, 16}) -> Relu -> Reshape(-> {1, 16, 4, 4}).
// The Relu already operates on an NHWC-shaped tensor; the NHWC_Relu test
// expects the pass to leave this graph unchanged.
150 class NHWCReluGraph final : public SimpleGraph
153 loco::Node *insertGraphBody(loco::Node *input) override
155 relu = g.nodes()->create<luci::CircleRelu>();
156 pre_reshape = g.nodes()->create<luci::CircleReshape>();
157 post_reshape = g.nodes()->create<luci::CircleReshape>();
158 pre_shape = g.nodes()->create<luci::CircleConst>();
159 post_shape = g.nodes()->create<luci::CircleConst>();
161 pre_shape->dtype(loco::DataType::S32);
162 post_shape->dtype(loco::DataType::S32);
164 uint32_t channel_size = 16;
// Give the input node an explicit NCHW shape.
165 auto in = loco::must_cast<luci::CircleNode *>(input);
166 in->shape({1, channel_size, 4, 4});
167 pre_shape->shape({4});
168 post_shape->shape({4});
// Target shape of the first reshape: {1, 4, 4, C} (NHWC layout).
170 pre_shape->size<loco::DataType::S32>(4);
171 pre_shape->at<loco::DataType::S32>(0) = 1;
172 pre_shape->at<loco::DataType::S32>(1) = 4;
173 pre_shape->at<loco::DataType::S32>(2) = 4;
174 pre_shape->at<loco::DataType::S32>(3) = channel_size;
// Target shape of the second reshape: back to {1, C, 4, 4}.
176 post_shape->size<loco::DataType::S32>(4);
177 post_shape->at<loco::DataType::S32>(0) = 1;
178 post_shape->at<loco::DataType::S32>(1) = channel_size;
179 post_shape->at<loco::DataType::S32>(2) = 4;
180 post_shape->at<loco::DataType::S32>(3) = 4;
182 pre_reshape->tensor(input);
183 pre_reshape->shape(pre_shape);
185 relu->features(pre_reshape);
187 post_reshape->tensor(relu);
188 post_reshape->shape(post_shape);
191 pre_reshape->name("pre-reshape");
192 post_reshape->name("post-reshape");
198 luci::CircleRelu *relu = nullptr;
199 luci::CircleReshape *pre_reshape = nullptr;
200 luci::CircleReshape *post_reshape = nullptr;
201 luci::CircleConst *pre_shape = nullptr;
202 luci::CircleConst *post_shape = nullptr;
// Graph body: CircleAdd of the input and a single-element FLOAT32 constant
// beta (3.14). The AddScalar test expects beta as add->y().
205 class AddScalarGraph final : public SimpleGraph
208 loco::Node *insertGraphBody(loco::Node *input) override
210 add = g.nodes()->create<luci::CircleAdd>();
211 beta = g.nodes()->create<luci::CircleConst>();
213 add->dtype(loco::DataType::FLOAT32);
214 beta->dtype(loco::DataType::FLOAT32);
216 uint32_t channel_size = 16;
217 add->shape({1, channel_size, 4, 4});
// One element only: the pass is expected to keep this scalar-like const as is.
220 beta->size<loco::DataType::FLOAT32>(1);
221 beta->at<loco::DataType::FLOAT32>(0) = 3.14;
233 luci::CircleAdd *add = nullptr;
234 luci::CircleConst *beta = nullptr;
// Graph body: CircleConcatenation of two values: the input and a FLOAT32
// constant of shape {1, 16, 4, 4} filled with 0..255.
// NOTE(review): the concat axis is set outside the visible lines; the test
// expects NHWC axis 3 with 32 channels afterwards.
237 class ConcatenationGraph final : public SimpleGraph
240 loco::Node *insertGraphBody(loco::Node *input) override
242 concat = g.nodes()->create<luci::CircleConcatenation>(2);
243 concat->values(0, input);
246 input2 = g.nodes()->create<luci::CircleConst>();
247 input2->dtype(loco::DataType::FLOAT32);
248 input2->shape({1, 16, 4, 4});
249 input2->size<loco::DataType::FLOAT32>(16 * 4 * 4);
250 for (uint32_t i = 0; i < 16 * 4 * 4; i++)
252 input2->at<loco::DataType::FLOAT32>(i) = i;
254 concat->values(1, input2);
256 concat->name("concat");
257 input2->name("input2");
263 luci::CircleConcatenation *concat = nullptr;
264 luci::CircleConst *input2 = nullptr;
// Graph body: a single CircleLeakyRelu applied to the input.
267 class LeakyReluGraph final : public SimpleGraph
270 loco::Node *insertGraphBody(loco::Node *input) override
272 leakyrelu = g.nodes()->create<luci::CircleLeakyRelu>();
273 leakyrelu->features(input);
274 leakyrelu->name("leakyrelu");
280 luci::CircleLeakyRelu *leakyrelu = nullptr;
// Graph body: a single CircleLogistic named "logistic" (the Logistic test
// checks its x() operand).
283 class LogisticGraph final : public SimpleGraph
286 loco::Node *insertGraphBody(loco::Node *input) override
288 logistic = g.nodes()->create<luci::CircleLogistic>();
290 logistic->name("logistic");
296 luci::CircleLogistic *logistic = nullptr;
// Graph body: CircleMaximum of the input and a single-element FLOAT32
// constant `limit` (100).
299 class MaximumGraph final : public SimpleGraph
302 loco::Node *insertGraphBody(loco::Node *input) override
304 max = g.nodes()->create<luci::CircleMaximum>();
305 limit = g.nodes()->create<luci::CircleConst>();
307 max->dtype(loco::DataType::FLOAT32);
308 limit->dtype(loco::DataType::FLOAT32);
310 max->shape({1, 16, 4, 4});
313 limit->size<loco::DataType::FLOAT32>(1);
314 limit->at<loco::DataType::FLOAT32>(0) = 100;
320 limit->name("limit");
326 luci::CircleMaximum *max = nullptr;
327 luci::CircleConst *limit = nullptr;
// Graph body: CircleMean over the input with S32 constant reduction indices.
// Configurable through keep_dims()/axes()/shape(); these only take effect if
// called before insertGraphBody() reads the members.
// Defaults: axes {2, 3}, keep_dims true, _shape {1, 16, 1, 1}.
330 class MeanGraph final : public SimpleGraph
333 loco::Node *insertGraphBody(loco::Node *input) override
335 mean = g.nodes()->create<luci::CircleMean>();
336 rindices = g.nodes()->create<luci::CircleConst>();
338 mean->dtype(loco::DataType::FLOAT32);
339 rindices->dtype(loco::DataType::S32);
// rindices is a rank-1 tensor holding the configured axes (NCHW indices).
342 rindices->shape({static_cast<uint32_t>(_axes.size())});
344 rindices->size<loco::DataType::S32>(_axes.size());
345 for (uint32_t i = 0; i < _axes.size(); ++i)
347 rindices->at<loco::DataType::S32>(i) = _axes[i];
351 mean->reduction_indices(rindices);
352 mean->keep_dims(_keep_dims);
355 rindices->name("rindices");
361 void keep_dims(bool val) { _keep_dims = val; }
362 void axes(std::vector<int32_t> val) { _axes = val; }
// NOTE(review): std::initializer_list does not own its elements. Storing the
// argument keeps a view of the caller's temporary array, which presumably
// dangles once the caller's full-expression ends. Consider a
// std::vector<uint32_t> member if the consumer of _shape can accept it.
363 void shape(std::initializer_list<uint32_t> val) { _shape = val; }
366 luci::CircleMean *mean = nullptr;
367 luci::CircleConst *rindices = nullptr;
370 bool _keep_dims = true;
371 std::vector<int32_t> _axes = {2, 3};
// NOTE(review): same concern: the backing array of this default initializer
// is a temporary, so its lifetime may not extend past construction.
372 std::initializer_list<uint32_t> _shape = {1, 16, 1, 1};
// Graph body: CircleMinimum of the input and a single-element FLOAT32
// constant `limit` (100).
375 class MinimumGraph final : public SimpleGraph
378 loco::Node *insertGraphBody(loco::Node *input) override
380 min = g.nodes()->create<luci::CircleMinimum>();
381 limit = g.nodes()->create<luci::CircleConst>();
383 min->dtype(loco::DataType::FLOAT32);
384 limit->dtype(loco::DataType::FLOAT32);
386 min->shape({1, 16, 4, 4});
389 limit->size<loco::DataType::FLOAT32>(1);
390 limit->at<loco::DataType::FLOAT32>(0) = 100;
396 limit->name("limit");
402 luci::CircleMinimum *min = nullptr;
403 luci::CircleConst *limit = nullptr;
// Graph body: CircleMul of the input and a per-channel FLOAT32 constant
// `multiplier` of shape {1, 16, 1, 1} holding 0..15 (the Mul tests read it
// via mul->y()).
406 class MulGraph final : public SimpleGraph
409 loco::Node *insertGraphBody(loco::Node *input) override
411 mul = g.nodes()->create<luci::CircleMul>();
412 multiplier = g.nodes()->create<luci::CircleConst>();
414 mul->dtype(loco::DataType::FLOAT32);
415 multiplier->dtype(loco::DataType::FLOAT32);
417 uint32_t channel_size = 16;
418 mul->shape({1, channel_size, 4, 4});
419 multiplier->shape({1, channel_size, 1, 1});
421 multiplier->size<loco::DataType::FLOAT32>(channel_size);
422 for (uint32_t i = 0; i < channel_size; i++)
424 multiplier->at<loco::DataType::FLOAT32>(i) = i;
431 multiplier->name("multiplier");
// Reshape multiplier to the full NCHW shape {1, 16, 4, 4} (used by Mul_NCHW_const).
437 void update_const_shape_to_nchw(void)
439 uint32_t channel_size = 16;
440 multiplier->shape({1, channel_size, 4, 4});
442 multiplier->size<loco::DataType::FLOAT32>(channel_size * 4 * 4);
// Fill every element after the resize, not just the first channel_size ones.
443 for (uint32_t i = 0; i < channel_size * 4 * 4; i++)
445 multiplier->at<loco::DataType::FLOAT32>(i) = i;
450 luci::CircleMul *mul = nullptr;
451 luci::CircleConst *multiplier = nullptr;
// Graph body: CircleMul of the input and a FLOAT32 constant `multiplier` of
// shape {1} holding 2 (the MulScalar test reads it via mul->y()).
454 class MulScalarGraph final : public SimpleGraph
457 loco::Node *insertGraphBody(loco::Node *input) override
459 mul = g.nodes()->create<luci::CircleMul>();
460 multiplier = g.nodes()->create<luci::CircleConst>();
462 mul->dtype(loco::DataType::FLOAT32);
463 multiplier->dtype(loco::DataType::FLOAT32);
465 uint32_t channel_size = 16;
466 mul->shape({1, channel_size, 4, 4});
467 multiplier->shape({1});
469 multiplier->size<loco::DataType::FLOAT32>(1);
470 multiplier->at<loco::DataType::FLOAT32>(0) = 2;
476 multiplier->name("multiplier");
482 luci::CircleMul *mul = nullptr;
483 luci::CircleConst *multiplier = nullptr;
// Graph body: a single CircleMul with no constant operand.
// NOTE(review): the MulBothNorm test expects transposes on both mul->x() and
// mul->y(); presumably both are the graph input. TODO confirm.
486 class MulBothNormGraph final : public SimpleGraph
489 loco::Node *insertGraphBody(loco::Node *input) override
491 mul = g.nodes()->create<luci::CircleMul>();
493 mul->dtype(loco::DataType::FLOAT32);
495 uint32_t channel_size = 16;
496 mul->shape({1, channel_size, 4, 4});
507 luci::CircleMul *mul = nullptr;
// Graph body: a single CircleNeg (the Neg test checks its x() operand).
510 class NegGraph final : public SimpleGraph
513 loco::Node *insertGraphBody(loco::Node *input) override
515 neg = g.nodes()->create<luci::CircleNeg>();
523 luci::CircleNeg *neg = nullptr;
// Graph body: CirclePad of the input with an S32 paddings const of shape
// {4, 2} in NCHW order (see the data comment below).
526 class PadGraph final : public SimpleGraph
529 loco::Node *insertGraphBody(loco::Node *input) override
531 pad = g.nodes()->create<luci::CirclePad>();
532 paddings = g.nodes()->create<luci::CircleConst>();
534 pad->dtype(loco::DataType::FLOAT32);
535 paddings->dtype(loco::DataType::S32);
537 uint32_t channel_size = 16;
538 pad->shape({1, channel_size, 4, 4});
539 paddings->shape({4, 2});
541 // paddings data (NCHW)
542 // [[0,0], [0,0], [1,1], [2,2]]
543 paddings->size<loco::DataType::S32>(8);
// Row-major fill: element [dim][i] is stored at dim * 2 + i.
544 for (uint32_t dim = 0; dim < 4; dim++)
546 for (uint32_t i = 0; i < 2; i++)
555 paddings->at<loco::DataType::S32>(dim * 2 + i) = data;
560 pad->paddings(paddings);
563 paddings->name("paddings");
569 luci::CirclePad *pad = nullptr;
570 luci::CircleConst *paddings = nullptr;
// Graph body: CirclePadV2 of the input with an S32 paddings const of shape
// {4, 2} in NCHW order, and a single-element FLOAT32 constant_values (-3.4).
573 class PadV2Graph final : public SimpleGraph
576 loco::Node *insertGraphBody(loco::Node *input) override
578 pad = g.nodes()->create<luci::CirclePadV2>();
579 paddings = g.nodes()->create<luci::CircleConst>();
580 const_value = g.nodes()->create<luci::CircleConst>();
582 pad->dtype(loco::DataType::FLOAT32);
583 paddings->dtype(loco::DataType::S32);
584 const_value->dtype(loco::DataType::FLOAT32);
586 uint32_t channel_size = 16;
587 pad->shape({1, channel_size, 4, 4});
588 paddings->shape({4, 2});
589 const_value->shape({1});
591 // paddings data (NCHW)
592 // [[0,0], [0,0], [1,1], [2,2]]
593 paddings->size<loco::DataType::S32>(8);
// Row-major fill: element [dim][i] is stored at dim * 2 + i.
594 for (uint32_t dim = 0; dim < 4; dim++)
596 for (uint32_t i = 0; i < 2; i++)
605 paddings->at<loco::DataType::S32>(dim * 2 + i) = data;
609 const_value->size<loco::DataType::FLOAT32>(1);
610 const_value->at<loco::DataType::FLOAT32>(0) = -3.4;
613 pad->paddings(paddings);
// constant_values must be the FLOAT32 const_value (same dtype as pad), not
// the S32 paddings tensor; otherwise const_value is left unconnected.
614 pad->constant_values(const_value);
617 paddings->name("paddings");
618 const_value->name("constant_values");
624 luci::CirclePadV2 *pad = nullptr;
625 luci::CircleConst *paddings = nullptr;
626 luci::CircleConst *const_value = nullptr;
// Graph body: a single CircleRelu applied to the input.
629 class ReluGraph final : public SimpleGraph
632 loco::Node *insertGraphBody(loco::Node *input) override
634 relu = g.nodes()->create<luci::CircleRelu>();
635 relu->features(input);
642 luci::CircleRelu *relu = nullptr;
// Graph body: a single CircleRelu6 applied to the input.
645 class Relu6Graph final : public SimpleGraph
648 loco::Node *insertGraphBody(loco::Node *input) override
650 relu6 = g.nodes()->create<luci::CircleRelu6>();
651 relu6->features(input);
652 relu6->name("relu6");
658 luci::CircleRelu6 *relu6 = nullptr;
// Graph body: a single CircleRsqrt named "rsqrt" (the Rsqrt test checks its
// x() operand).
661 class RsqrtGraph final : public SimpleGraph
664 loco::Node *insertGraphBody(loco::Node *input) override
666 rsqrt = g.nodes()->create<luci::CircleRsqrt>();
668 rsqrt->name("rsqrt");
674 luci::CircleRsqrt *rsqrt = nullptr;
// Graph body: a single CircleSquaredDifference named "sqdiff".
// NOTE(review): the test expects transposes on both x() and y(); presumably
// both operands are the graph input. TODO confirm.
677 class SquaredDifferenceGraph final : public SimpleGraph
680 loco::Node *insertGraphBody(loco::Node *input) override
682 sqdiff = g.nodes()->create<luci::CircleSquaredDifference>();
685 sqdiff->name("sqdiff");
691 luci::CircleSquaredDifference *sqdiff = nullptr;
// Graph body: CircleSub of the input and a per-channel FLOAT32 constant `beta`
// of shape {1, 16, 1, 1} holding 0..15 (the Sub tests read beta via sub->y()).
694 class SubGraph final : public SimpleGraph
697 loco::Node *insertGraphBody(loco::Node *input) override
699 sub = g.nodes()->create<luci::CircleSub>();
700 beta = g.nodes()->create<luci::CircleConst>();
702 sub->dtype(loco::DataType::FLOAT32);
703 beta->dtype(loco::DataType::FLOAT32);
705 uint32_t channel_size = 16;
706 sub->shape({1, channel_size, 4, 4});
707 beta->shape({1, channel_size, 1, 1});
709 beta->size<loco::DataType::FLOAT32>(channel_size);
710 for (uint32_t i = 0; i < channel_size; i++)
712 beta->at<loco::DataType::FLOAT32>(i) = i;
// Reshape beta to the full NCHW shape {1, 16, 4, 4} (used by Sub_NCHW_const).
725 void update_const_shape_to_nchw(void)
727 uint32_t channel_size = 16;
728 beta->shape({1, channel_size, 4, 4});
730 beta->size<loco::DataType::FLOAT32>(channel_size * 4 * 4);
// Fill every element after the resize, not just the first channel_size ones.
731 for (uint32_t i = 0; i < channel_size * 4 * 4; i++)
733 beta->at<loco::DataType::FLOAT32>(i) = i;
738 luci::CircleSub *sub = nullptr;
739 luci::CircleConst *beta = nullptr;
// Graph body: CircleSub with a single-element FLOAT32 constant beta (5).
// The SubScalar test reads the constant from sub->x() and expects a
// transpose on sub->y().
742 class SubScalarGraph final : public SimpleGraph
745 loco::Node *insertGraphBody(loco::Node *input) override
747 sub = g.nodes()->create<luci::CircleSub>();
748 beta = g.nodes()->create<luci::CircleConst>();
750 sub->dtype(loco::DataType::FLOAT32);
751 beta->dtype(loco::DataType::FLOAT32);
753 uint32_t channel_size = 16;
754 sub->shape({1, channel_size, 4, 4});
757 beta->size<loco::DataType::FLOAT32>(1);
758 beta->at<loco::DataType::FLOAT32>(0) = 5;
770 luci::CircleSub *sub = nullptr;
771 luci::CircleConst *beta = nullptr;
// Checks that `node` is a CircleTranspose whose perm is a rank-1 S32 const of
// length 4 equal to {0, 2, 3, 1} (NCHW -> NHWC).
// Structural checks use ASSERT_* so a mismatch returns from this helper
// (recording a failure) instead of dereferencing nullptr or reading perm
// elements that do not exist.
774 void check_pre_trans(loco::Node *node)
776 auto pre_trans = dynamic_cast<luci::CircleTranspose *>(node);
777 ASSERT_NE(nullptr, pre_trans);
778 auto pre_trans_perm = dynamic_cast<luci::CircleConst *>(pre_trans->perm());
779 ASSERT_NE(nullptr, pre_trans_perm);
780 ASSERT_EQ(1, pre_trans_perm->rank());
781 ASSERT_EQ(4, pre_trans_perm->dim(0).value());
782 ASSERT_EQ(loco::DataType::S32, pre_trans_perm->dtype());
783 EXPECT_EQ(0, pre_trans_perm->at<loco::DataType::S32>(0));
784 EXPECT_EQ(2, pre_trans_perm->at<loco::DataType::S32>(1));
785 EXPECT_EQ(3, pre_trans_perm->at<loco::DataType::S32>(2));
786 EXPECT_EQ(1, pre_trans_perm->at<loco::DataType::S32>(3));
// Checks that `node` is a CircleTranspose whose perm is a rank-1 S32 const of
// length 4 equal to {0, 3, 1, 2} (NHWC -> NCHW).
// Structural checks use ASSERT_* so a mismatch returns from this helper
// (recording a failure) instead of dereferencing nullptr or reading perm
// elements that do not exist.
789 void check_post_trans(loco::Node *node)
791 auto post_trans = dynamic_cast<luci::CircleTranspose *>(node);
792 ASSERT_NE(nullptr, post_trans);
793 auto post_trans_perm = dynamic_cast<luci::CircleConst *>(post_trans->perm());
794 ASSERT_NE(nullptr, post_trans_perm);
795 ASSERT_EQ(1, post_trans_perm->rank());
796 ASSERT_EQ(4, post_trans_perm->dim(0).value());
797 ASSERT_EQ(loco::DataType::S32, post_trans_perm->dtype());
798 EXPECT_EQ(0, post_trans_perm->at<loco::DataType::S32>(0));
799 EXPECT_EQ(3, post_trans_perm->at<loco::DataType::S32>(1));
800 EXPECT_EQ(1, post_trans_perm->at<loco::DataType::S32>(2));
801 EXPECT_EQ(2, post_trans_perm->at<loco::DataType::S32>(3));
// Builds a logo phase that includes CircleShapeInferencePass and
// ConvertNCHWToNHWCPass, then runs it on `g` with the Restart strategy.
// preserve_input / preserve_output are forwarded to ConvertNCHWToNHWCPass.
804 void run_phase(loco::Graph *g, bool preserve_input, bool preserve_output)
809 phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
813 std::make_unique<luci::ConvertNCHWToNHWCPass>(preserve_input, preserve_output));
815 logo::PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{g};
816 phase_runner.run(phase);
// The pass must report a non-null name.
821 TEST(ConvertNCHWToNHWCPassTest, name)
823 luci::ConvertNCHWToNHWCPass pass(false, false);
824 auto const name = pass.name();
825 ASSERT_NE(nullptr, name);
// Add with a per-channel const, input/output not preserved: checks the
// transposes around the input, the Add and the output, and that beta is
// converted to {1, 1, 1, 16}.
// ASSERT_* is used wherever the next statement dereferences the checked
// value, so a failure ends the test instead of invoking UB.
828 TEST(ConvertNCHWToNHWC, Add)
833 run_phase(&g.g, false, false);
835 auto input_succs = loco::succs(g.input);
836 ASSERT_EQ(1, input_succs.size());
837 check_post_trans(*input_succs.begin());
839 check_pre_trans(g.add->x());
841 auto add_succs = loco::succs(g.add);
842 ASSERT_EQ(1, add_succs.size());
843 check_post_trans(*add_succs.begin());
845 uint32_t channel_size = 16;
846 auto new_beta = dynamic_cast<luci::CircleConst *>(g.add->y());
847 ASSERT_NE(nullptr, new_beta);
848 EXPECT_EQ(4, new_beta->rank());
849 EXPECT_EQ(1, new_beta->dim(0).value());
850 EXPECT_EQ(1, new_beta->dim(1).value());
851 EXPECT_EQ(1, new_beta->dim(2).value());
852 EXPECT_EQ(channel_size, new_beta->dim(3).value());
854 check_pre_trans(g.output->from());
// Add with a full NCHW const {1, 16, 4, 4}: beta must become NHWC {1, 4, 4, 16}.
// ASSERT_* guards the values dereferenced on the following line.
857 TEST(ConvertNCHWToNHWC, Add_NCHW_const)
861 g.update_const_shape_to_nchw();
863 run_phase(&g.g, false, false);
865 check_pre_trans(g.add->x());
867 auto add_succs = loco::succs(g.add);
868 ASSERT_EQ(1, add_succs.size());
869 check_post_trans(*add_succs.begin());
871 uint32_t channel_size = 16;
872 auto new_beta = dynamic_cast<luci::CircleConst *>(g.add->y());
873 ASSERT_NE(nullptr, new_beta);
874 EXPECT_EQ(4, new_beta->rank());
875 EXPECT_EQ(1, new_beta->dim(0).value());
876 EXPECT_EQ(4, new_beta->dim(1).value());
877 EXPECT_EQ(4, new_beta->dim(2).value());
878 EXPECT_EQ(channel_size, new_beta->dim(3).value());
881 TEST(ConvertNCHWToNHWC, NHWC_Relu)
883 // Relu is already NHWC, so it should not be converted
884 // i.e., the graph is not changed
888 run_phase(&g.g, false, false);
890 EXPECT_EQ(g.pre_reshape, g.relu->features());
892 auto relu_succs = loco::succs(g.relu);
// Fatal: the next line dereferences begin().
893 ASSERT_EQ(1, relu_succs.size());
894 EXPECT_EQ(g.post_reshape, *relu_succs.begin());
// Add with a single-element const: transposes are inserted around the Add, and
// beta stays rank 1 with one element.
// ASSERT_* guards the values dereferenced on the following line.
897 TEST(ConvertNCHWToNHWC, AddScalar)
902 run_phase(&g.g, false, false);
904 auto input_succs = loco::succs(g.input);
905 ASSERT_EQ(1, input_succs.size());
906 check_post_trans(*input_succs.begin());
908 check_pre_trans(g.add->x());
910 auto add_succs = loco::succs(g.add);
911 ASSERT_EQ(1, add_succs.size());
912 check_post_trans(*add_succs.begin());
914 auto new_beta = dynamic_cast<luci::CircleConst *>(g.add->y());
915 ASSERT_NE(nullptr, new_beta);
916 EXPECT_EQ(1, new_beta->rank());
917 EXPECT_EQ(1, new_beta->dim(0).value());
919 check_pre_trans(g.output->from());
// Concatenation with input/output preserved: both values get a {0,2,3,1}
// transpose, and concat becomes {1, 4, 4, 32} along axis 3.
922 TEST(ConvertNCHWToNHWC, Concatenation)
924 ConcatenationGraph g;
927 run_phase(&g.g, true, true);
929 check_pre_trans(g.concat->values(0));
930 check_pre_trans(g.concat->values(1));
932 auto concat_succs = loco::succs(g.concat);
// Fatal: the next line dereferences begin().
933 ASSERT_EQ(1, concat_succs.size());
934 check_post_trans(*concat_succs.begin());
936 // Check concat shape, axis
937 EXPECT_EQ(1, g.concat->dim(0).value());
938 EXPECT_EQ(4, g.concat->dim(1).value());
939 EXPECT_EQ(4, g.concat->dim(2).value());
940 EXPECT_EQ(32, g.concat->dim(3).value());
941 EXPECT_EQ(3, g.concat->axis());
// LeakyRelu with input/output preserved: wrapped in transposes, shape becomes
// NHWC {1, 4, 4, 16}.
944 TEST(ConvertNCHWToNHWC, LeakyRelu)
949 run_phase(&g.g, true, true);
951 check_pre_trans(g.leakyrelu->features());
953 auto leakyrelu_succs = loco::succs(g.leakyrelu);
// Fatal: the next line dereferences begin().
954 ASSERT_EQ(1, leakyrelu_succs.size());
955 check_post_trans(*leakyrelu_succs.begin());
957 // Check leakyrelu shape
958 EXPECT_EQ(1, g.leakyrelu->dim(0).value());
959 EXPECT_EQ(4, g.leakyrelu->dim(1).value());
960 EXPECT_EQ(4, g.leakyrelu->dim(2).value());
961 EXPECT_EQ(16, g.leakyrelu->dim(3).value());
// Logistic with input/output preserved: wrapped in transposes, shape becomes
// NHWC {1, 4, 4, 16}.
964 TEST(ConvertNCHWToNHWC, Logistic)
969 run_phase(&g.g, true, true);
971 check_pre_trans(g.logistic->x());
973 auto logistic_succs = loco::succs(g.logistic);
// Fatal: the next line dereferences begin().
974 ASSERT_EQ(1, logistic_succs.size());
975 check_post_trans(*logistic_succs.begin());
977 // Check logistic shape
978 EXPECT_EQ(1, g.logistic->dim(0).value());
979 EXPECT_EQ(4, g.logistic->dim(1).value());
980 EXPECT_EQ(4, g.logistic->dim(2).value());
981 EXPECT_EQ(16, g.logistic->dim(3).value());
// Maximum with a scalar-like const: transposes around input, Maximum and output.
// ASSERT_* guards the begin() dereferences that follow.
984 TEST(ConvertNCHWToNHWC, Maximum)
989 run_phase(&g.g, false, false);
991 auto input_succs = loco::succs(g.input);
992 ASSERT_EQ(1, input_succs.size());
993 check_post_trans(*input_succs.begin());
995 check_pre_trans(g.max->x());
997 auto max_succs = loco::succs(g.max);
998 ASSERT_EQ(1, max_succs.size());
999 check_post_trans(*max_succs.begin());
1001 check_pre_trans(g.output->from());
// Mean over NCHW axes {2, 3} (keep_dims) must be rewritten to NHWC axes {1, 2}.
// ASSERT_* guards the values dereferenced on the following line.
1004 TEST(ConvertNCHWToNHWC, Mean)
1009 run_phase(&g.g, false, false);
1011 check_pre_trans(g.mean->input());
1013 auto mean_succs = loco::succs(g.mean);
1014 ASSERT_EQ(1, mean_succs.size());
1015 check_post_trans(*mean_succs.begin());
1017 auto new_rindices = dynamic_cast<luci::CircleConst *>(g.mean->reduction_indices());
1018 ASSERT_NE(nullptr, new_rindices);
1019 EXPECT_EQ(1, new_rindices->rank());
1020 EXPECT_EQ(2, new_rindices->dim(0).value());
1021 EXPECT_EQ(2, new_rindices->size<loco::DataType::S32>());
1022 EXPECT_EQ(1, new_rindices->at<loco::DataType::S32>(0));
1023 EXPECT_EQ(2, new_rindices->at<loco::DataType::S32>(1));
// Table-driven Mean tests with keep_dims=false: for each NCHW axis set, checks
// the expected NHWC reduction indices and whether a Transpose (needs_transpose)
// or the Output directly follows the Mean.
// ASSERT_* guards the values dereferenced afterwards; a failure ends the test.
1026 TEST(ConvertNCHWToNHWC, Mean_keep_dims_false)
1030 std::vector<int32_t> nchw_ind;
1031 std::vector<int32_t> nhwc_ind;
// NOTE(review): std::initializer_list does not own its elements; the braced
// lists in test_cases are temporaries, so this member presumably dangles once
// test_cases is built. Consider std::vector<uint32_t> (requires the consumer,
// e.g. MeanGraph::shape, to accept it).
1032 std::initializer_list<uint32_t> shape;
1033 bool needs_transpose = false;
1041 std::vector<TC> test_cases{{{0}, {0}, {c, h, w}, true}, {{1}, {3}, {n, h, w}, false},
1042 {{2}, {1}, {n, c, w}, true}, {{3}, {2}, {n, c, h}, true},
1043 {{0, 1}, {0, 3}, {h, w}, false}, {{0, 2}, {0, 1}, {c, w}, true},
1044 {{0, 3}, {0, 2}, {c, h}, true}, {{1, 2}, {3, 1}, {n, w}, false},
1045 {{1, 3}, {3, 2}, {n, h}, false}, {{2, 3}, {1, 2}, {n, c}, false},
1046 {{0, 1, 2}, {0, 3, 1}, {w}, false}};
1048 for (auto &tc : test_cases)
1052 g.axes(tc.nchw_ind);
1056 run_phase(&g.g, false, true);
1058 check_pre_trans(g.mean->input());
1060 auto mean_succs = loco::succs(g.mean);
1061 ASSERT_EQ(1, mean_succs.size());
1062 if (tc.needs_transpose)
1064 EXPECT_NE(nullptr, dynamic_cast<luci::CircleTranspose *>(*mean_succs.begin()));
1068 EXPECT_NE(nullptr, dynamic_cast<luci::CircleOutput *>(*mean_succs.begin()));
1071 auto new_rindices = dynamic_cast<luci::CircleConst *>(g.mean->reduction_indices());
1072 ASSERT_NE(nullptr, new_rindices);
1073 EXPECT_EQ(1, new_rindices->rank());
1074 EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->dim(0).value());
1075 EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->size<loco::DataType::S32>());
1076 for (uint32_t i = 0; i < tc.nhwc_ind.size(); ++i)
1078 EXPECT_EQ(tc.nhwc_ind[i], new_rindices->at<loco::DataType::S32>(i));
// Negative case: a rank-3 input ({16, 4, 4}) with Mean over {1, 2} and
// keep_dims=false; the reduction indices are expected to stay {1, 2}.
1083 TEST(ConvertNCHWToNHWC, ConvertNCHWToNHWC_Mean_keep_dims_false_NEG)
1086 auto input = g.nodes()->create<luci::CircleInput>();
1087 auto output = g.nodes()->create<luci::CircleOutput>();
1088 input->name("input");
1089 output->name("output");
1091 auto graph_input = g.inputs()->create();
1092 input->index(graph_input->index());
1093 auto graph_output = g.outputs()->create();
1094 output->index(graph_output->index());
1096 graph_input->dtype(loco::DataType::FLOAT32);
1097 input->dtype(loco::DataType::FLOAT32);
1098 output->dtype(loco::DataType::FLOAT32);
1099 graph_output->dtype(loco::DataType::FLOAT32);
1101 uint32_t channel_size = 16;
1102 graph_input->shape({channel_size, 4, 4});
1103 input->shape({channel_size, 4, 4});
1104 output->shape({channel_size});
1105 graph_output->shape({channel_size});
1107 auto mean = g.nodes()->create<luci::CircleMean>();
1108 auto rindices = g.nodes()->create<luci::CircleConst>();
1110 mean->dtype(loco::DataType::FLOAT32);
1111 rindices->dtype(loco::DataType::S32);
1113 mean->shape({channel_size});
1114 rindices->shape({2});
1116 rindices->size<loco::DataType::S32>(2);
1117 rindices->at<loco::DataType::S32>(0) = 1;
1118 rindices->at<loco::DataType::S32>(1) = 2;
1121 mean->reduction_indices(rindices);
1122 mean->keep_dims(false);
1125 rindices->name("rindices");
1129 run_phase(&g, true, true);
1131 auto new_rindices = dynamic_cast<luci::CircleConst *>(mean->reduction_indices());
// Fatal: new_rindices is dereferenced below.
1132 ASSERT_NE(nullptr, new_rindices);
1133 EXPECT_EQ(1, new_rindices->rank());
1134 EXPECT_EQ(2, new_rindices->dim(0).value());
1135 EXPECT_EQ(2, new_rindices->size<loco::DataType::S32>());
1136 EXPECT_EQ(1, new_rindices->at<loco::DataType::S32>(0));
1137 EXPECT_EQ(2, new_rindices->at<loco::DataType::S32>(1));
// Minimum with a scalar-like const: transposes around input, Minimum and output.
// ASSERT_* guards the begin() dereferences that follow.
1140 TEST(ConvertNCHWToNHWC, Minimum)
1145 run_phase(&g.g, false, false);
1147 auto input_succs = loco::succs(g.input);
1148 ASSERT_EQ(1, input_succs.size());
1149 check_post_trans(*input_succs.begin());
1151 check_pre_trans(g.min->x());
1153 auto min_succs = loco::succs(g.min);
1154 ASSERT_EQ(1, min_succs.size());
1155 check_post_trans(*min_succs.begin());
1157 check_pre_trans(g.output->from());
// Mul with a per-channel const: transposes around input, Mul and output, and
// the multiplier becomes {1, 1, 1, 16}.
// ASSERT_* guards the values dereferenced on the following line.
1160 TEST(ConvertNCHWToNHWC, Mul)
1165 run_phase(&g.g, false, false);
1167 auto input_succs = loco::succs(g.input);
1168 ASSERT_EQ(1, input_succs.size());
1169 check_post_trans(*input_succs.begin());
1171 check_pre_trans(g.mul->x());
1173 auto mul_succs = loco::succs(g.mul);
1174 ASSERT_EQ(1, mul_succs.size());
1175 check_post_trans(*mul_succs.begin());
1177 uint32_t channel_size = 16;
1178 auto new_multiplier = dynamic_cast<luci::CircleConst *>(g.mul->y());
1179 ASSERT_NE(nullptr, new_multiplier);
1180 EXPECT_EQ(4, new_multiplier->rank());
1181 EXPECT_EQ(1, new_multiplier->dim(0).value());
1182 EXPECT_EQ(1, new_multiplier->dim(1).value());
1183 EXPECT_EQ(1, new_multiplier->dim(2).value());
1184 EXPECT_EQ(channel_size, new_multiplier->dim(3).value());
1186 check_pre_trans(g.output->from());
// Mul with a full NCHW const {1, 16, 4, 4}: multiplier must become {1, 4, 4, 16}.
// ASSERT_* guards the values dereferenced on the following line.
1189 TEST(ConvertNCHWToNHWC, Mul_NCHW_const)
1193 g.update_const_shape_to_nchw();
1195 run_phase(&g.g, false, false);
1197 check_pre_trans(g.mul->x());
1199 auto mul_succs = loco::succs(g.mul);
1200 ASSERT_EQ(1, mul_succs.size());
1201 check_post_trans(*mul_succs.begin());
1203 uint32_t channel_size = 16;
1204 auto new_multiplier = dynamic_cast<luci::CircleConst *>(g.mul->y());
1205 ASSERT_NE(nullptr, new_multiplier);
1206 EXPECT_EQ(4, new_multiplier->rank());
1207 EXPECT_EQ(1, new_multiplier->dim(0).value());
1208 EXPECT_EQ(4, new_multiplier->dim(1).value());
1209 EXPECT_EQ(4, new_multiplier->dim(2).value());
1210 EXPECT_EQ(channel_size, new_multiplier->dim(3).value());
// Mul with a {1} const: transposes around the Mul; multiplier stays rank 1.
// ASSERT_* guards the values dereferenced on the following line.
1213 TEST(ConvertNCHWToNHWC, MulScalar)
1218 run_phase(&g.g, false, false);
1220 auto input_succs = loco::succs(g.input);
1221 ASSERT_EQ(1, input_succs.size());
1222 check_post_trans(*input_succs.begin());
1224 check_pre_trans(g.mul->x());
1226 auto mul_succs = loco::succs(g.mul);
1227 ASSERT_EQ(1, mul_succs.size());
1228 check_post_trans(*mul_succs.begin());
1230 auto new_multiplier = dynamic_cast<luci::CircleConst *>(g.mul->y());
1231 ASSERT_NE(nullptr, new_multiplier);
1232 EXPECT_EQ(1, new_multiplier->rank());
1233 EXPECT_EQ(1, new_multiplier->dim(0).value());
1235 check_pre_trans(g.output->from());
// Mul with two non-const operands: both x() and y() get a {0,2,3,1} transpose.
// ASSERT_* guards the begin() dereferences that follow.
1238 TEST(ConvertNCHWToNHWC, MulBothNorm)
1243 run_phase(&g.g, false, false);
1245 auto input_succs = loco::succs(g.input);
1246 ASSERT_EQ(1, input_succs.size());
1247 check_post_trans(*input_succs.begin());
1249 check_pre_trans(g.mul->x());
1250 check_pre_trans(g.mul->y());
1252 auto mul_succs = loco::succs(g.mul);
1253 ASSERT_EQ(1, mul_succs.size());
1254 check_post_trans(*mul_succs.begin());
1256 check_pre_trans(g.output->from());
// Neg with input/output preserved: wrapped in transposes, shape becomes
// NHWC {1, 4, 4, 16}.
1259 TEST(ConvertNCHWToNHWC, Neg)
1264 run_phase(&g.g, true, true);
1266 check_pre_trans(g.neg->x());
1268 auto neg_succs = loco::succs(g.neg);
// Fatal: the next line dereferences begin().
1269 ASSERT_EQ(1, neg_succs.size());
1270 check_post_trans(*neg_succs.begin());
1272 // Check neg shape
1273 EXPECT_EQ(1, g.neg->dim(0).value());
1274 EXPECT_EQ(4, g.neg->dim(1).value());
1275 EXPECT_EQ(4, g.neg->dim(2).value());
1276 EXPECT_EQ(16, g.neg->dim(3).value());
// Pad: NCHW paddings [[0,0],[0,0],[1,1],[2,2]] must be permuted to NHWC order
// [[0,0],[1,1],[2,2],[0,0]], with transposes around input, Pad and output.
// ASSERT_* guards the values dereferenced on the following line.
1279 TEST(ConvertNCHWToNHWC, Pad)
1284 run_phase(&g.g, false, false);
1286 auto input_succs = loco::succs(g.input);
1287 ASSERT_EQ(1, input_succs.size());
1288 check_post_trans(*input_succs.begin());
1290 check_pre_trans(g.pad->input());
1292 auto pad_succs = loco::succs(g.pad);
1293 ASSERT_EQ(1, pad_succs.size());
1294 check_post_trans(*pad_succs.begin());
1296 auto new_paddings = dynamic_cast<luci::CircleConst *>(g.pad->paddings());
1297 ASSERT_NE(nullptr, new_paddings);
1298 EXPECT_EQ(2, new_paddings->rank());
1299 EXPECT_EQ(4, new_paddings->dim(0).value());
1300 EXPECT_EQ(2, new_paddings->dim(1).value());
1301 EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(0));
1302 EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(1));
1303 EXPECT_EQ(1, new_paddings->at<loco::DataType::S32>(2));
1304 EXPECT_EQ(1, new_paddings->at<loco::DataType::S32>(3));
1305 EXPECT_EQ(2, new_paddings->at<loco::DataType::S32>(4));
1306 EXPECT_EQ(2, new_paddings->at<loco::DataType::S32>(5));
1307 EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(6));
1308 EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(7));
1310 check_pre_trans(g.output->from());
// PadV2: paddings are permuted to NHWC order [[0,0],[1,1],[2,2],[0,0]] and the
// PadV2 is wrapped in transposes.
// ASSERT_* guards the values dereferenced on the following line.
1313 TEST(ConvertNCHWToNHWC, PadV2)
1318 run_phase(&g.g, false, false);
1320 check_pre_trans(g.pad->input());
1322 auto pad_succs = loco::succs(g.pad);
1323 ASSERT_EQ(1, pad_succs.size());
1324 check_post_trans(*pad_succs.begin());
1326 auto new_paddings = dynamic_cast<luci::CircleConst *>(g.pad->paddings());
1327 ASSERT_NE(nullptr, new_paddings);
1328 EXPECT_EQ(2, new_paddings->rank());
1329 EXPECT_EQ(4, new_paddings->dim(0).value());
1330 EXPECT_EQ(2, new_paddings->dim(1).value());
1331 EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(0));
1332 EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(1));
1333 EXPECT_EQ(1, new_paddings->at<loco::DataType::S32>(2));
1334 EXPECT_EQ(1, new_paddings->at<loco::DataType::S32>(3));
1335 EXPECT_EQ(2, new_paddings->at<loco::DataType::S32>(4));
1336 EXPECT_EQ(2, new_paddings->at<loco::DataType::S32>(5));
1337 EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(6));
1338 EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(7));
// Negative case: with dim(0) unset on input, add and output, a direct
// pass.run() is expected to return false.
1341 TEST(ConvertNCHWToNHWC, Unknown_Shape_NEG)
1347 g.input->dim(0).unset();
1348 g.add->dim(0).unset();
1349 g.output->dim(0).unset();
1351 luci::ConvertNCHWToNHWCPass pass(false, false);
1352 EXPECT_EQ(false, pass.run(&g.g));
// Checks the preserve_input / preserve_output flags: a preserved endpoint
// keeps NCHW {1, 16, 4, 4}; a non-preserved one becomes NHWC {1, 4, 4, 16}.
1355 TEST(ConvertNCHWToNHWC, Preserve_Input_Output)
// preserve_input=true, preserve_output=false
1362 run_phase(&g.g, true, false);
1364 // Check input shape
1365 EXPECT_EQ(1, g.input->dim(0).value());
1366 EXPECT_EQ(16, g.input->dim(1).value());
1367 EXPECT_EQ(4, g.input->dim(2).value());
1368 EXPECT_EQ(4, g.input->dim(3).value());
1370 // Check output shape
1371 EXPECT_EQ(1, g.output->dim(0).value());
1372 EXPECT_EQ(4, g.output->dim(1).value());
1373 EXPECT_EQ(4, g.output->dim(2).value());
1374 EXPECT_EQ(16, g.output->dim(3).value());
// preserve_input=false, preserve_output=true
1382 run_phase(&g.g, false, true);
1384 // Check input shape
1385 EXPECT_EQ(1, g.input->dim(0).value());
1386 EXPECT_EQ(4, g.input->dim(1).value());
1387 EXPECT_EQ(4, g.input->dim(2).value());
1388 EXPECT_EQ(16, g.input->dim(3).value());
1390 // Check output shape
1391 EXPECT_EQ(1, g.output->dim(0).value());
1392 EXPECT_EQ(16, g.output->dim(1).value());
1393 EXPECT_EQ(4, g.output->dim(2).value());
1394 EXPECT_EQ(4, g.output->dim(3).value());
1397 // Preserve both input and output
1402 run_phase(&g.g, true, true);
1404 // Check input shape
1405 EXPECT_EQ(1, g.input->dim(0).value());
1406 EXPECT_EQ(16, g.input->dim(1).value());
1407 EXPECT_EQ(4, g.input->dim(2).value());
1408 EXPECT_EQ(4, g.input->dim(3).value());
1410 // Check output shape
1411 EXPECT_EQ(1, g.output->dim(0).value());
1412 EXPECT_EQ(16, g.output->dim(1).value());
1413 EXPECT_EQ(4, g.output->dim(2).value());
1414 EXPECT_EQ(4, g.output->dim(3).value());
// Relu with input/output preserved: wrapped in transposes, shape becomes
// NHWC {1, 4, 4, 16}.
1418 TEST(ConvertNCHWToNHWC, Relu)
1423 run_phase(&g.g, true, true);
1425 check_pre_trans(g.relu->features());
1427 auto relu_succs = loco::succs(g.relu);
// Fatal: the next line dereferences begin().
1428 ASSERT_EQ(1, relu_succs.size());
1429 check_post_trans(*relu_succs.begin());
1432 EXPECT_EQ(1, g.relu->dim(0).value());
1433 EXPECT_EQ(4, g.relu->dim(1).value());
1434 EXPECT_EQ(4, g.relu->dim(2).value());
1435 EXPECT_EQ(16, g.relu->dim(3).value());
// Relu6 with input/output preserved: wrapped in transposes, shape becomes
// NHWC {1, 4, 4, 16}.
1438 TEST(ConvertNCHWToNHWC, Relu6)
1443 run_phase(&g.g, true, true);
1445 check_pre_trans(g.relu6->features());
1447 auto relu6_succs = loco::succs(g.relu6);
// Fatal: the next line dereferences begin().
1448 ASSERT_EQ(1, relu6_succs.size());
1449 check_post_trans(*relu6_succs.begin());
1451 // Check relu6 shape
1452 EXPECT_EQ(1, g.relu6->dim(0).value());
1453 EXPECT_EQ(4, g.relu6->dim(1).value());
1454 EXPECT_EQ(4, g.relu6->dim(2).value());
1455 EXPECT_EQ(16, g.relu6->dim(3).value());
// Rsqrt with input/output preserved: wrapped in transposes, shape becomes
// NHWC {1, 4, 4, 16}.
1458 TEST(ConvertNCHWToNHWC, Rsqrt)
1463 run_phase(&g.g, true, true);
1465 check_pre_trans(g.rsqrt->x());
1467 auto rsqrt_succs = loco::succs(g.rsqrt);
// Fatal: the next line dereferences begin().
1468 ASSERT_EQ(1, rsqrt_succs.size());
1469 check_post_trans(*rsqrt_succs.begin());
1471 // Check rsqrt shape
1472 EXPECT_EQ(1, g.rsqrt->dim(0).value());
1473 EXPECT_EQ(4, g.rsqrt->dim(1).value());
1474 EXPECT_EQ(4, g.rsqrt->dim(2).value());
1475 EXPECT_EQ(16, g.rsqrt->dim(3).value());
// SquaredDifference with input/output preserved: both operands get a
// {0,2,3,1} transpose and the result a {0,3,1,2} transpose.
1478 TEST(ConvertNCHWToNHWC, SquaredDifference)
1480 SquaredDifferenceGraph g;
1483 run_phase(&g.g, true, true);
1485 check_pre_trans(g.sqdiff->x());
1486 check_pre_trans(g.sqdiff->y());
1488 auto sqdiff_succs = loco::succs(g.sqdiff);
// Fatal: the next line dereferences begin().
1489 ASSERT_EQ(1, sqdiff_succs.size());
1490 check_post_trans(*sqdiff_succs.begin());
// Sub with a per-channel const: transposes around input, Sub and output, and
// beta becomes {1, 1, 1, 16}.
// ASSERT_* guards the values dereferenced on the following line.
1493 TEST(ConvertNCHWToNHWC, Sub)
1498 run_phase(&g.g, false, false);
1500 auto input_succs = loco::succs(g.input);
1501 ASSERT_EQ(1, input_succs.size());
1502 check_post_trans(*input_succs.begin());
1504 check_pre_trans(g.sub->x());
1506 auto sub_succs = loco::succs(g.sub);
1507 ASSERT_EQ(1, sub_succs.size());
1508 check_post_trans(*sub_succs.begin());
1510 uint32_t channel_size = 16;
1511 auto new_beta = dynamic_cast<luci::CircleConst *>(g.sub->y());
1512 ASSERT_NE(nullptr, new_beta);
1513 EXPECT_EQ(4, new_beta->rank());
1514 EXPECT_EQ(1, new_beta->dim(0).value());
1515 EXPECT_EQ(1, new_beta->dim(1).value());
1516 EXPECT_EQ(1, new_beta->dim(2).value());
1517 EXPECT_EQ(channel_size, new_beta->dim(3).value());
1519 check_pre_trans(g.output->from());
// Sub with a full NCHW const {1, 16, 4, 4}: beta must become {1, 4, 4, 16}.
// ASSERT_* guards the values dereferenced on the following line.
1522 TEST(ConvertNCHWToNHWC, Sub_NCHW_const)
1526 g.update_const_shape_to_nchw();
1528 run_phase(&g.g, false, false);
1530 check_pre_trans(g.sub->x());
1532 auto sub_succs = loco::succs(g.sub);
1533 ASSERT_EQ(1, sub_succs.size());
1534 check_post_trans(*sub_succs.begin());
1536 uint32_t channel_size = 16;
1537 auto new_beta = dynamic_cast<luci::CircleConst *>(g.sub->y());
1538 ASSERT_NE(nullptr, new_beta);
1539 EXPECT_EQ(4, new_beta->rank());
1540 EXPECT_EQ(1, new_beta->dim(0).value());
1541 EXPECT_EQ(4, new_beta->dim(1).value());
1542 EXPECT_EQ(4, new_beta->dim(2).value());
1543 EXPECT_EQ(channel_size, new_beta->dim(3).value());
1546 TEST(ConvertNCHWToNHWC, SubScalar)
1551 run_phase(&g.g, false, false);
1553 auto input_succs = loco::succs(g.input);
1554 EXPECT_EQ(1, input_succs.size());
1555 check_post_trans(*input_succs.begin());
1557 check_pre_trans(g.sub->y());
1559 auto add_succs = loco::succs(g.sub);
1560 EXPECT_EQ(1, add_succs.size());
1561 check_post_trans(*add_succs.begin());
1563 auto new_beta = dynamic_cast<luci::CircleConst *>(g.sub->x());
1564 EXPECT_NE(nullptr, new_beta);
1565 EXPECT_EQ(1, new_beta->rank());
1567 check_pre_trans(g.output->from());