Imported Upstream version 1.12.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / ShuffleWeightTo16x1Float32Pass.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/ShuffleWeightTo16x1Float32Pass.h"
18
19 #include <luci/IR/CircleNodes.h>
20
21 #include <cassert>
22 #include <vector>
23
24 namespace
25 {
26
27 bool satisfy_precondition(luci::CircleFullyConnected *fc)
28 {
29   // check if it's already been shuffled
30   if (fc->weights_format() != luci::CircleFullyConnected::WeightsFormat::DEFAULT)
31     return false;
32
33   // check if its data type is FLOAT32
34   if (fc->dtype() != loco::DataType::FLOAT32)
35     return false;
36
37   auto weights = loco::must_cast<luci::CircleConst *>(fc->weights());
38   // rank must be 2
39   if (weights->rank() != 2)
40     return false;
41
42   // check if it has sparsity parameter
43   if (weights->sparsityparam())
44     return false;
45
46   // check if the number of row of FullyConnected's weight is a multiple of 16
47   const uint32_t MULTIPLE = 16;
48   uint32_t rows = weights->dim(0).value();
49   if (rows % MULTIPLE)
50     return false;
51
52   return true;
53 }
54
55 // get FullyConnected op vector that has same tensor
56 void get_FCs_having_same_tensor(std::vector<luci::CircleFullyConnected *> &fc_vec, loco::Graph *g,
57                                 luci::CircleFullyConnected *fc)
58 {
59   auto the_tensor = fc->weights();
60   for (auto node : loco::active_nodes(loco::output_nodes(g)))
61   {
62     auto fc = dynamic_cast<luci::CircleFullyConnected *>(node);
63     if (not fc)
64       continue;
65
66     if (fc->weights() == the_tensor)
67       fc_vec.push_back(fc);
68   }
69 }
70
71 luci::CircleConst *shuffle_weight(luci::CircleFullyConnected *fc)
72 {
73   auto the_weights = loco::must_cast<luci::CircleConst *>(fc->weights());
74
75   // create CircleConst where shuffled data will be stored
76   luci::CircleConst *new_weights = fc->graph()->nodes()->create<luci::CircleConst>();
77   new_weights->dtype(loco::DataType::FLOAT32);
78   new_weights->size<loco::DataType::FLOAT32>(the_weights->size<loco::DataType::FLOAT32>());
79   new_weights->rank(the_weights->rank());
80   new_weights->shape_status(the_weights->shape_status());
81   for (uint32_t r = 0; r < new_weights->rank(); r++)
82   {
83     new_weights->dim(r).set(the_weights->dim(r).value());
84   }
85
86   // suffle weight
87   const uint32_t MULTIPLE = 16;
88   const uint32_t rows = the_weights->dim(0).value();
89   const uint32_t cols = the_weights->dim(1).value();
90   const uint32_t r_step = rows / MULTIPLE;
91   uint32_t index = 0;
92   for (uint32_t r = 0; r < r_step; r++)
93   {
94     for (uint32_t c = 0; c < cols; c++)
95     {
96       for (uint32_t i = 0; i < MULTIPLE; i++)
97       {
98         new_weights->at<loco::DataType::FLOAT32>(index++) =
99             the_weights->at<loco::DataType::FLOAT32>((r * MULTIPLE + i) * cols + c);
100       }
101     }
102   }
103
104   return new_weights;
105 }
106
107 } // namespace
108
109 namespace luci
110 {
111
112 bool ShuffleWeightTo16x1Float32Pass::run(loco::Graph *g)
113 {
114   bool changed = false;
115   for (auto node : loco::active_nodes(loco::output_nodes(g)))
116   {
117     auto fc = dynamic_cast<luci::CircleFullyConnected *>(node);
118     if (not fc)
119       continue;
120
121     if (not satisfy_precondition(fc))
122       continue;
123
124     std::vector<luci::CircleFullyConnected *> fc_vec;
125     get_FCs_having_same_tensor(fc_vec, g, fc);
126     auto new_weights = shuffle_weight(fc);
127
128     // replace to new weights
129     for (const auto fc : fc_vec)
130     {
131       fc->weights(new_weights);
132       fc->weights_format(luci::CircleFullyConnected::WeightsFormat::SHUFFLED16x1FLOAT32);
133     }
134   }
135
136   return changed;
137 }
138
139 } // namespace luci