Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / ResolveCustomOpMatMulPass.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/ResolveCustomOpMatMulPass.h"
18
19 #include "helpers/CreateCircleConst.h"
20
21 #include <loco/IR/DataTypeTraits.h>
22
23 #include <luci/IR/CircleNodes.h>
24 #include <luci/Profile/CircleNodeOrigin.h>
25
26 #include <loco.h>
27 #include <oops/InternalExn.h>
28
29 #include <flatbuffers/flexbuffers.h>
30
31 namespace
32 {
33
34 bool resolve_matmul(luci::CircleCustom *cop)
35 {
36 #define CHECK_OR_FALSE(condition) \
37   if (not(condition))             \
38     return false;
39 #define CHECK_OR_THROW(condition, message) \
40   if (not(condition))                      \
41     INTERNAL_EXN(message);
42
43   auto graph = cop->graph();
44   const std::vector<uint8_t> custom_options = cop->custom_options();
45   auto map = flexbuffers::GetRoot(custom_options).AsMap();
46   const auto U8 = loco::DataType::U8;
47   const auto S16 = loco::DataType::S16;
48   const auto S32 = loco::DataType::S32;
49   const auto FLOAT32 = loco::DataType::FLOAT32;
50
51   auto name = cop->name();
52   assert(name.length() > 0);
53
54   bool transpose_a = map["transpose_a"].AsBool();
55   bool transpose_b = map["transpose_b"].AsBool();
56
57   loco::Node *lhs = cop->inputs(0);
58   loco::Node *rhs = cop->inputs(1);
59
60   // Check that the type of the first input is known
61   auto lhs_dtype = loco::must_cast<luci::CircleNode *>(cop->inputs(0))->dtype();
62   CHECK_OR_FALSE(lhs_dtype != loco::DataType::Unknown);
63
64   // If transpose of first input is requested, its shape must be known
65   auto circle_lhs = loco::must_cast<luci::CircleNode *>(lhs);
66   CHECK_OR_FALSE(!transpose_a || circle_lhs->shape_status() == luci::ShapeStatus::VALID);
67   // and its rank should be at least 2
68   CHECK_OR_FALSE(!transpose_a || circle_lhs->rank() >= 2);
69   // Check that the shape of the 2nd input is known
70   auto circle_rhs = loco::must_cast<luci::CircleNode *>(rhs);
71   CHECK_OR_FALSE(circle_rhs->shape_status() == luci::ShapeStatus::VALID);
72   // TODO as of 06/23/20 TFLite only supports rank 2 for 2nd input. Fix this once that changes!
73   CHECK_OR_FALSE(circle_rhs->rank() == 2);
74   // Check that input data type is supported
75   CHECK_OR_THROW(lhs_dtype == U8 || lhs_dtype == S16 || lhs_dtype == FLOAT32,
76                  "Only UInt8, Int16 and Float32 data types are supported by MatMul");
77
78   if (transpose_a)
79   {
80     // Create a permutation constant node
81     std::vector<int32_t> perm;
82     const auto lhs_rank = static_cast<int32_t>(circle_lhs->rank());
83     for (int32_t i = 0; i < lhs_rank; ++i)
84       perm.push_back(i);
85     std::swap(perm[circle_lhs->rank() - 1], perm[circle_lhs->rank() - 2]);
86     auto perm_node = luci::create_const_node(graph, S32, {circle_lhs->rank()}, perm);
87     perm_node->name(name + "/lhs/Transpose/perm");
88     // Now make a transpose node
89     auto transpose_node = graph->nodes()->create<luci::CircleTranspose>();
90     transpose_node->a(lhs);
91     transpose_node->perm(perm_node);
92     transpose_node->name(name + "/lhs/Transpose");
93     luci::add_origin(transpose_node, luci::get_origin(cop));
94     lhs = transpose_node;
95   }
96
97   // Transpose the second input if needed. TFLite FullyConnected operator
98   // assumes the second input is in column-major order, but the input is
99   // in row-major order, thus we need to convert between them.
100   if (!transpose_b)
101   {
102     const std::vector<int32_t> perm{1, 0};
103     auto perm_node = luci::create_const_node(graph, S32, {2}, perm);
104     perm_node->name(name + "/rhs/Transpose/perm");
105     auto transpose_node = graph->nodes()->create<luci::CircleTranspose>();
106     transpose_node->a(rhs);
107     transpose_node->perm(perm_node);
108     transpose_node->name(name + "/rhs/Transpose");
109     luci::add_origin(transpose_node, luci::get_origin(cop));
110     rhs = transpose_node;
111   }
112
113   auto empty_bias = graph->nodes()->create<luci::CircleOutputExclude>();
114
115   auto fc_node = graph->nodes()->create<luci::CircleFullyConnected>();
116   fc_node->input(lhs);
117   fc_node->weights(rhs);
118   fc_node->bias(empty_bias);
119   fc_node->fusedActivationFunction(luci::FusedActFunc::NONE);
120   fc_node->name(name + "/FullyConnected");
121   luci::add_origin(fc_node, luci::get_origin(cop));
122
123   auto customOut = loco::succs(cop);
124   assert(customOut.size() == 1);
125   replace(*customOut.begin()).with(fc_node);
126   return true;
127 }
128
129 } // namespace
130
131 namespace luci
132 {
133
134 bool ResolveCustomOpMatMulPass::run(loco::Graph *g)
135 {
136   bool changed = false;
137   for (auto node : loco::active_nodes(loco::output_nodes(g)))
138   {
139     auto cop = dynamic_cast<luci::CircleCustom *>(node);
140     if (not cop)
141       continue;
142
143     if (cop->custom_code() != "MatMul")
144       continue;
145
146     if (!resolve_matmul(cop))
147       continue;
148
149     changed = true;
150   }
151
152   return changed;
153 }
154
155 } // namespace luci