2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/Pass/ShuffleWeightTo16x1Float32Pass.h"
19 #include <luci/IR/CircleNodes.h>
27 bool satisfy_precondition(luci::CircleFullyConnected *fc)
29 // check if it's already been shuffled
30 if (fc->weights_format() != luci::CircleFullyConnected::WeightsFormat::DEFAULT)
33 // check if its data type is FLOAT32
34 if (fc->dtype() != loco::DataType::FLOAT32)
37 auto weights = loco::must_cast<luci::CircleConst *>(fc->weights());
39 if (weights->rank() != 2)
42 // check if it has sparsity parameter
43 if (weights->sparsityparam())
46 // check if the number of row of FullyConnected's weight is a multiple of 16
47 const uint32_t MULTIPLE = 16;
48 uint32_t rows = weights->dim(0).value();
55 // get FullyConnected op vector that has same tensor
56 void get_FCs_having_same_tensor(std::vector<luci::CircleFullyConnected *> &fc_vec, loco::Graph *g,
57 luci::CircleFullyConnected *fc)
59 auto the_tensor = fc->weights();
60 for (auto node : loco::active_nodes(loco::output_nodes(g)))
62 auto fc = dynamic_cast<luci::CircleFullyConnected *>(node);
66 if (fc->weights() == the_tensor)
71 luci::CircleConst *shuffle_weight(luci::CircleFullyConnected *fc)
73 auto the_weights = loco::must_cast<luci::CircleConst *>(fc->weights());
75 // create CircleConst where shuffled data will be stored
76 luci::CircleConst *new_weights = fc->graph()->nodes()->create<luci::CircleConst>();
77 new_weights->dtype(loco::DataType::FLOAT32);
78 new_weights->size<loco::DataType::FLOAT32>(the_weights->size<loco::DataType::FLOAT32>());
79 new_weights->rank(the_weights->rank());
80 new_weights->shape_status(the_weights->shape_status());
81 for (uint32_t r = 0; r < new_weights->rank(); r++)
83 new_weights->dim(r).set(the_weights->dim(r).value());
87 const uint32_t MULTIPLE = 16;
88 const uint32_t rows = the_weights->dim(0).value();
89 const uint32_t cols = the_weights->dim(1).value();
90 const uint32_t r_step = rows / MULTIPLE;
92 for (uint32_t r = 0; r < r_step; r++)
94 for (uint32_t c = 0; c < cols; c++)
96 for (uint32_t i = 0; i < MULTIPLE; i++)
98 new_weights->at<loco::DataType::FLOAT32>(index++) =
99 the_weights->at<loco::DataType::FLOAT32>((r * MULTIPLE + i) * cols + c);
112 bool ShuffleWeightTo16x1Float32Pass::run(loco::Graph *g)
114 bool changed = false;
115 for (auto node : loco::active_nodes(loco::output_nodes(g)))
117 auto fc = dynamic_cast<luci::CircleFullyConnected *>(node);
121 if (not satisfy_precondition(fc))
124 std::vector<luci::CircleFullyConnected *> fc_vec;
125 get_FCs_having_same_tensor(fc_vec, g, fc);
126 auto new_weights = shuffle_weight(fc);
128 // replace to new weights
129 for (const auto fc : fc_vec)
131 fc->weights(new_weights);
132 fc->weights_format(luci::CircleFullyConnected::WeightsFormat::SHUFFLED16x1FLOAT32);