2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
#include "luci/Pass/ResolveCustomOpMatMulPass.h"

#include <loco.h>
#include <loco/IR/DataTypeTraits.h>

#include <luci/IR/CircleNodes.h>
#include <luci/Profile/CircleNodeOrigin.h>

#include <oops/InternalExn.h>

#include <flatbuffers/flexbuffers.h>

#include <cassert>
#include <cstdint>
#include <numeric>
#include <string>
#include <utility>
#include <vector>
33 luci::CircleConst *create_const_node(loco::Graph *g, const loco::DataType dtype,
34 const std::vector<uint32_t> &shape,
35 const std::vector<T> &values)
37 auto node = g->nodes()->create<luci::CircleConst>();
39 node->rank(shape.size());
42 for (uint32_t i = 0; i < shape.size(); ++i)
44 node->dim(i) = shape.at(i);
47 node->shape_status(luci::ShapeStatus::VALID);
49 #define INIT_VALUES(DT) \
51 node->size<DT>(size); \
52 for (uint32_t i = 0; i < values.size(); ++i) \
53 node->at<DT>(i) = values[i]; \
58 case loco::DataType::U8:
59 INIT_VALUES(loco::DataType::U8);
61 case loco::DataType::S16:
62 INIT_VALUES(loco::DataType::S16);
64 case loco::DataType::S32:
65 INIT_VALUES(loco::DataType::S32);
67 case loco::DataType::FLOAT32:
68 INIT_VALUES(loco::DataType::FLOAT32)
71 INTERNAL_EXN("create_const_node called with unsupported type");
77 bool resolve_matmul(luci::CircleCustom *cop)
79 #define CHECK_OR_FALSE(condition) \
82 #define CHECK_OR_THROW(condition, message) \
84 INTERNAL_EXN(message);
86 auto graph = cop->graph();
87 const std::vector<uint8_t> custom_options = cop->custom_options();
88 auto map = flexbuffers::GetRoot(custom_options).AsMap();
89 const auto U8 = loco::DataType::U8;
90 const auto S16 = loco::DataType::S16;
91 const auto S32 = loco::DataType::S32;
92 const auto FLOAT32 = loco::DataType::FLOAT32;
94 auto name = cop->name();
95 assert(name.length() > 0);
97 bool transpose_a = map["transpose_a"].AsBool();
98 bool transpose_b = map["transpose_b"].AsBool();
100 loco::Node *lhs = cop->inputs(0);
101 loco::Node *rhs = cop->inputs(1);
103 // Check that the type of the first input is known
104 auto lhs_dtype = loco::must_cast<luci::CircleNode *>(cop->inputs(0))->dtype();
105 CHECK_OR_FALSE(lhs_dtype != loco::DataType::Unknown);
107 // If transpose of first input is requested, its shape must be known
108 auto circle_lhs = loco::must_cast<luci::CircleNode *>(lhs);
109 CHECK_OR_FALSE(!transpose_a || circle_lhs->shape_status() == luci::ShapeStatus::VALID);
110 // and its rank should be at least 2
111 CHECK_OR_FALSE(!transpose_a || circle_lhs->rank() >= 2);
112 // Check that the shape of the 2nd input is known
113 auto circle_rhs = loco::must_cast<luci::CircleNode *>(rhs);
114 CHECK_OR_FALSE(circle_rhs->shape_status() == luci::ShapeStatus::VALID);
115 // TODO as of 06/23/20 TFLite only supports rank 2 for 2nd input. Fix this once that changes!
116 CHECK_OR_FALSE(circle_rhs->rank() == 2);
117 // Check that input data type is supported
118 CHECK_OR_THROW(lhs_dtype == U8 || lhs_dtype == S16 || lhs_dtype == FLOAT32,
119 "Only UInt8, Int16 and Float32 data types are supported by MatMul");
123 // Create a permutation constant node
124 std::vector<uint32_t> perm;
125 for (uint32_t i = 0; i < circle_lhs->rank(); ++i)
127 std::swap(perm[circle_lhs->rank() - 1], perm[circle_lhs->rank() - 2]);
128 auto perm_node = create_const_node(graph, S32, {circle_lhs->rank()}, perm);
129 perm_node->name(name + "/lhs/Transpose/perm");
130 // Now make a transpose node
131 auto transpose_node = graph->nodes()->create<luci::CircleTranspose>();
132 transpose_node->a(lhs);
133 transpose_node->perm(perm_node);
134 transpose_node->name(name + "/lhs/Transpose");
135 luci::add_origin(transpose_node, luci::get_origin(cop));
136 lhs = transpose_node;
139 // Transpose the second input if needed. TFLite FullyConnected operator
140 // assumes the second input is in column-major order, but the input is
141 // in row-major order, thus we need to convert between them.
144 const std::vector<uint32_t> perm{1, 0};
145 auto perm_node = create_const_node(graph, S32, {2}, perm);
146 perm_node->name(name + "/rhs/Transpose/perm");
147 auto transpose_node = graph->nodes()->create<luci::CircleTranspose>();
148 transpose_node->a(rhs);
149 transpose_node->perm(perm_node);
150 transpose_node->name(name + "/rhs/Transpose");
151 luci::add_origin(transpose_node, luci::get_origin(cop));
152 rhs = transpose_node;
155 auto empty_bias = graph->nodes()->create<luci::CircleOutputExclude>();
157 auto fc_node = graph->nodes()->create<luci::CircleFullyConnected>();
159 fc_node->weights(rhs);
160 fc_node->bias(empty_bias);
161 fc_node->fusedActivationFunction(luci::FusedActFunc::NONE);
162 fc_node->name(name + "/FullyConnected");
163 luci::add_origin(fc_node, luci::get_origin(cop));
165 auto customOut = loco::succs(cop);
166 assert(customOut.size() == 1);
167 replace(*customOut.begin()).with(fc_node);
176 bool ResolveCustomOpMatMulPass::run(loco::Graph *g)
178 bool changed = false;
179 for (auto node : loco::active_nodes(loco::output_nodes(g)))
181 auto cop = dynamic_cast<luci::CircleCustom *>(node);
185 if (cop->custom_code() != "MatMul")
188 if (!resolve_matmul(cop))