Imported Upstream version 1.18.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / ResolveCustomOpMatMulPass.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/ResolveCustomOpMatMulPass.h"
18
19 #include <loco/IR/DataTypeTraits.h>
20
21 #include <luci/IR/CircleNodes.h>
22 #include <luci/Profile/CircleNodeOrigin.h>
23
24 #include <loco.h>
25 #include <oops/InternalExn.h>
26
27 #include <flatbuffers/flexbuffers.h>
28
29 namespace
30 {
31
32 template <typename T>
33 luci::CircleConst *create_const_node(loco::Graph *g, const loco::DataType dtype,
34                                      const std::vector<uint32_t> &shape,
35                                      const std::vector<T> &values)
36 {
37   auto node = g->nodes()->create<luci::CircleConst>();
38   node->dtype(dtype);
39   node->rank(shape.size());
40
41   uint32_t size = 1;
42   for (uint32_t i = 0; i < shape.size(); ++i)
43   {
44     node->dim(i) = shape.at(i);
45     size *= shape.at(i);
46   }
47   node->shape_status(luci::ShapeStatus::VALID);
48
49 #define INIT_VALUES(DT)                          \
50   {                                              \
51     node->size<DT>(size);                        \
52     for (uint32_t i = 0; i < values.size(); ++i) \
53       node->at<DT>(i) = values[i];               \
54   }
55
56   switch (dtype)
57   {
58     case loco::DataType::U8:
59       INIT_VALUES(loco::DataType::U8);
60       break;
61     case loco::DataType::S16:
62       INIT_VALUES(loco::DataType::S16);
63       break;
64     case loco::DataType::S32:
65       INIT_VALUES(loco::DataType::S32);
66       break;
67     case loco::DataType::FLOAT32:
68       INIT_VALUES(loco::DataType::FLOAT32)
69       break;
70     default:
71       INTERNAL_EXN("create_const_node called with unsupported type");
72       break;
73   }
74   return node;
75 }
76
77 bool resolve_matmul(luci::CircleCustom *cop)
78 {
79 #define CHECK_OR_FALSE(condition) \
80   if (not(condition))             \
81     return false;
82 #define CHECK_OR_THROW(condition, message) \
83   if (not(condition))                      \
84     INTERNAL_EXN(message);
85
86   auto graph = cop->graph();
87   const std::vector<uint8_t> custom_options = cop->custom_options();
88   auto map = flexbuffers::GetRoot(custom_options).AsMap();
89   const auto U8 = loco::DataType::U8;
90   const auto S16 = loco::DataType::S16;
91   const auto S32 = loco::DataType::S32;
92   const auto FLOAT32 = loco::DataType::FLOAT32;
93
94   auto name = cop->name();
95   assert(name.length() > 0);
96
97   bool transpose_a = map["transpose_a"].AsBool();
98   bool transpose_b = map["transpose_b"].AsBool();
99
100   loco::Node *lhs = cop->inputs(0);
101   loco::Node *rhs = cop->inputs(1);
102
103   // Check that the type of the first input is known
104   auto lhs_dtype = loco::must_cast<luci::CircleNode *>(cop->inputs(0))->dtype();
105   CHECK_OR_FALSE(lhs_dtype != loco::DataType::Unknown);
106
107   // If transpose of first input is requested, its shape must be known
108   auto circle_lhs = loco::must_cast<luci::CircleNode *>(lhs);
109   CHECK_OR_FALSE(!transpose_a || circle_lhs->shape_status() == luci::ShapeStatus::VALID);
110   // and its rank should be at least 2
111   CHECK_OR_FALSE(!transpose_a || circle_lhs->rank() >= 2);
112   // Check that the shape of the 2nd input is known
113   auto circle_rhs = loco::must_cast<luci::CircleNode *>(rhs);
114   CHECK_OR_FALSE(circle_rhs->shape_status() == luci::ShapeStatus::VALID);
115   // TODO as of 06/23/20 TFLite only supports rank 2 for 2nd input. Fix this once that changes!
116   CHECK_OR_FALSE(circle_rhs->rank() == 2);
117   // Check that input data type is supported
118   CHECK_OR_THROW(lhs_dtype == U8 || lhs_dtype == S16 || lhs_dtype == FLOAT32,
119                  "Only UInt8, Int16 and Float32 data types are supported by MatMul");
120
121   if (transpose_a)
122   {
123     // Create a permutation constant node
124     std::vector<uint32_t> perm;
125     for (uint32_t i = 0; i < circle_lhs->rank(); ++i)
126       perm.push_back(i);
127     std::swap(perm[circle_lhs->rank() - 1], perm[circle_lhs->rank() - 2]);
128     auto perm_node = create_const_node(graph, S32, {circle_lhs->rank()}, perm);
129     perm_node->name(name + "/lhs/Transpose/perm");
130     // Now make a transpose node
131     auto transpose_node = graph->nodes()->create<luci::CircleTranspose>();
132     transpose_node->a(lhs);
133     transpose_node->perm(perm_node);
134     transpose_node->name(name + "/lhs/Transpose");
135     luci::add_origin(transpose_node, luci::get_origin(cop));
136     lhs = transpose_node;
137   }
138
139   // Transpose the second input if needed. TFLite FullyConnected operator
140   // assumes the second input is in column-major order, but the input is
141   // in row-major order, thus we need to convert between them.
142   if (!transpose_b)
143   {
144     const std::vector<uint32_t> perm{1, 0};
145     auto perm_node = create_const_node(graph, S32, {2}, perm);
146     perm_node->name(name + "/rhs/Transpose/perm");
147     auto transpose_node = graph->nodes()->create<luci::CircleTranspose>();
148     transpose_node->a(rhs);
149     transpose_node->perm(perm_node);
150     transpose_node->name(name + "/rhs/Transpose");
151     luci::add_origin(transpose_node, luci::get_origin(cop));
152     rhs = transpose_node;
153   }
154
155   auto empty_bias = graph->nodes()->create<luci::CircleOutputExclude>();
156   empty_bias->dtype(loco::DataType::FLOAT32); // Needed for type inference
157
158   auto fc_node = graph->nodes()->create<luci::CircleFullyConnected>();
159   fc_node->input(lhs);
160   fc_node->weights(rhs);
161   fc_node->bias(empty_bias);
162   fc_node->fusedActivationFunction(luci::FusedActFunc::NONE);
163   fc_node->name(name + "/FullyConnected");
164   luci::add_origin(fc_node, luci::get_origin(cop));
165
166   auto customOut = loco::succs(cop);
167   assert(customOut.size() == 1);
168   replace(*customOut.begin()).with(fc_node);
169   return true;
170 }
171
172 } // namespace
173
174 namespace luci
175 {
176
177 bool ResolveCustomOpMatMulPass::run(loco::Graph *g)
178 {
179   bool changed = false;
180   for (auto node : loco::active_nodes(loco::output_nodes(g)))
181   {
182     auto cop = dynamic_cast<luci::CircleCustom *>(node);
183     if (not cop)
184       continue;
185
186     if (cop->custom_code() != "MatMul")
187       continue;
188
189     if (!resolve_matmul(cop))
190       continue;
191
192     changed = true;
193   }
194
195   return changed;
196 }
197
198 } // namespace luci