Imported Upstream version 1.12.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / ShuffleWeightTo16x1Float32Pass.test.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/ShuffleWeightTo16x1Float32Pass.h"
18
19 #include <luci/IR/CircleNodes.h>
20
21 #include <gtest/gtest.h>
22
23 void create_fc_net(loco::Graph *g)
24 {
25   assert(g);
26
27   const uint32_t ROW = 16;
28   const uint32_t COL = 2;
29   const uint32_t elements_num = ROW * COL;
30
31   // input
32   auto input = g->nodes()->create<luci::CircleInput>();
33   auto graph_input = g->inputs()->create();
34   input->index(graph_input->index());
35
36   // fc weights
37   auto weights = g->nodes()->create<luci::CircleConst>();
38   weights->dtype(loco::DataType::FLOAT32);
39   weights->size<loco::DataType::FLOAT32>(elements_num);
40   weights->rank(2);
41   weights->dim(0).set(ROW);
42   weights->dim(1).set(COL);
43   for (uint32_t idx = 0; idx < elements_num; idx++)
44   {
45     weights->at<loco::DataType::FLOAT32>(idx) = idx;
46   }
47
48   // fc
49   auto fc = g->nodes()->create<luci::CircleFullyConnected>();
50   fc->dtype(loco::DataType::FLOAT32);
51   fc->input(input);
52   fc->weights(weights);
53
54   // output
55   auto output = g->nodes()->create<luci::CircleOutput>();
56   output->from(fc);
57   auto graph_output = g->outputs()->create();
58   output->index(graph_output->index());
59 }
60
61 TEST(ShuffleWeightTo16x1Float32PassTest, SimpleTest1)
62 {
63   auto graph = loco::make_graph();
64   create_fc_net(graph.get());
65
66   luci::CircleFullyConnected *fc_node = nullptr;
67   for (auto node : loco::active_nodes(loco::output_nodes(graph.get())))
68   {
69     auto fc = dynamic_cast<luci::CircleFullyConnected *>(node);
70     if (not fc)
71       continue;
72
73     fc_node = fc;
74     break;
75   }
76   ASSERT_NE(fc_node, nullptr);
77   auto weights = loco::must_cast<luci::CircleConst *>(fc_node->weights());
78   // before
79   ASSERT_EQ(0, weights->at<loco::DataType::FLOAT32>(0));
80   ASSERT_EQ(1, weights->at<loco::DataType::FLOAT32>(1));
81   ASSERT_EQ(2, weights->at<loco::DataType::FLOAT32>(2));
82   ASSERT_EQ(3, weights->at<loco::DataType::FLOAT32>(3));
83   ASSERT_EQ(4, weights->at<loco::DataType::FLOAT32>(4));
84   ASSERT_EQ(5, weights->at<loco::DataType::FLOAT32>(5));
85   ASSERT_EQ(6, weights->at<loco::DataType::FLOAT32>(6));
86   ASSERT_EQ(7, weights->at<loco::DataType::FLOAT32>(7));
87   ASSERT_EQ(8, weights->at<loco::DataType::FLOAT32>(8));
88   ASSERT_EQ(9, weights->at<loco::DataType::FLOAT32>(9));
89   ASSERT_EQ(10, weights->at<loco::DataType::FLOAT32>(10));
90   ASSERT_EQ(11, weights->at<loco::DataType::FLOAT32>(11));
91   ASSERT_EQ(12, weights->at<loco::DataType::FLOAT32>(12));
92   ASSERT_EQ(13, weights->at<loco::DataType::FLOAT32>(13));
93   ASSERT_EQ(14, weights->at<loco::DataType::FLOAT32>(14));
94   ASSERT_EQ(15, weights->at<loco::DataType::FLOAT32>(15));
95
96   luci::ShuffleWeightTo16x1Float32Pass pass;
97   while (pass.run(graph.get()))
98     ;
99
100   weights = loco::must_cast<luci::CircleConst *>(fc_node->weights());
101   // after
102   ASSERT_EQ(0, weights->at<loco::DataType::FLOAT32>(0));
103   ASSERT_EQ(2, weights->at<loco::DataType::FLOAT32>(1));
104   ASSERT_EQ(4, weights->at<loco::DataType::FLOAT32>(2));
105   ASSERT_EQ(6, weights->at<loco::DataType::FLOAT32>(3));
106   ASSERT_EQ(8, weights->at<loco::DataType::FLOAT32>(4));
107   ASSERT_EQ(10, weights->at<loco::DataType::FLOAT32>(5));
108   ASSERT_EQ(12, weights->at<loco::DataType::FLOAT32>(6));
109   ASSERT_EQ(14, weights->at<loco::DataType::FLOAT32>(7));
110   ASSERT_EQ(16, weights->at<loco::DataType::FLOAT32>(8));
111   ASSERT_EQ(18, weights->at<loco::DataType::FLOAT32>(9));
112   ASSERT_EQ(20, weights->at<loco::DataType::FLOAT32>(10));
113   ASSERT_EQ(22, weights->at<loco::DataType::FLOAT32>(11));
114   ASSERT_EQ(24, weights->at<loco::DataType::FLOAT32>(12));
115   ASSERT_EQ(26, weights->at<loco::DataType::FLOAT32>(13));
116   ASSERT_EQ(28, weights->at<loco::DataType::FLOAT32>(14));
117   ASSERT_EQ(30, weights->at<loco::DataType::FLOAT32>(15));
118 }