2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/Pass/ResolveCustomOpMatMulPass.h"
19 #include "helpers/CreateCircleConst.h"
21 #include <loco/IR/DataTypeTraits.h>
23 #include <luci/IR/CircleNodes.h>
24 #include <luci/Profile/CircleNodeOrigin.h>
27 #include <oops/InternalExn.h>
29 #include <flatbuffers/flexbuffers.h>
34 bool resolve_matmul(luci::CircleCustom *cop)
36 #define CHECK_OR_FALSE(condition) \
39 #define CHECK_OR_THROW(condition, message) \
41 INTERNAL_EXN(message);
43 auto graph = cop->graph();
44 const std::vector<uint8_t> custom_options = cop->custom_options();
45 auto map = flexbuffers::GetRoot(custom_options).AsMap();
46 const auto U8 = loco::DataType::U8;
47 const auto S16 = loco::DataType::S16;
48 const auto S32 = loco::DataType::S32;
49 const auto FLOAT32 = loco::DataType::FLOAT32;
51 auto name = cop->name();
52 assert(name.length() > 0);
54 bool transpose_a = map["transpose_a"].AsBool();
55 bool transpose_b = map["transpose_b"].AsBool();
57 loco::Node *lhs = cop->inputs(0);
58 loco::Node *rhs = cop->inputs(1);
60 // Check that the type of the first input is known
61 auto lhs_dtype = loco::must_cast<luci::CircleNode *>(cop->inputs(0))->dtype();
62 CHECK_OR_FALSE(lhs_dtype != loco::DataType::Unknown);
64 // If transpose of first input is requested, its shape must be known
65 auto circle_lhs = loco::must_cast<luci::CircleNode *>(lhs);
66 CHECK_OR_FALSE(!transpose_a || circle_lhs->shape_status() == luci::ShapeStatus::VALID);
67 // and its rank should be at least 2
68 CHECK_OR_FALSE(!transpose_a || circle_lhs->rank() >= 2);
69 // Check that the shape of the 2nd input is known
70 auto circle_rhs = loco::must_cast<luci::CircleNode *>(rhs);
71 CHECK_OR_FALSE(circle_rhs->shape_status() == luci::ShapeStatus::VALID);
72 // TODO as of 06/23/20 TFLite only supports rank 2 for 2nd input. Fix this once that changes!
73 CHECK_OR_FALSE(circle_rhs->rank() == 2);
74 // Check that input data type is supported
75 CHECK_OR_THROW(lhs_dtype == U8 || lhs_dtype == S16 || lhs_dtype == FLOAT32,
76 "Only UInt8, Int16 and Float32 data types are supported by MatMul");
80 // Create a permutation constant node
81 std::vector<int32_t> perm;
82 const auto lhs_rank = static_cast<int32_t>(circle_lhs->rank());
83 for (int32_t i = 0; i < lhs_rank; ++i)
85 std::swap(perm[circle_lhs->rank() - 1], perm[circle_lhs->rank() - 2]);
86 auto perm_node = luci::create_const_node(graph, S32, {circle_lhs->rank()}, perm);
87 perm_node->name(name + "/lhs/Transpose/perm");
88 // Now make a transpose node
89 auto transpose_node = graph->nodes()->create<luci::CircleTranspose>();
90 transpose_node->a(lhs);
91 transpose_node->perm(perm_node);
92 transpose_node->name(name + "/lhs/Transpose");
93 luci::add_origin(transpose_node, luci::get_origin(cop));
97 // Transpose the second input if needed. TFLite FullyConnected operator
98 // assumes the second input is in column-major order, but the input is
99 // in row-major order, thus we need to convert between them.
102 const std::vector<int32_t> perm{1, 0};
103 auto perm_node = luci::create_const_node(graph, S32, {2}, perm);
104 perm_node->name(name + "/rhs/Transpose/perm");
105 auto transpose_node = graph->nodes()->create<luci::CircleTranspose>();
106 transpose_node->a(rhs);
107 transpose_node->perm(perm_node);
108 transpose_node->name(name + "/rhs/Transpose");
109 luci::add_origin(transpose_node, luci::get_origin(cop));
110 rhs = transpose_node;
113 auto empty_bias = graph->nodes()->create<luci::CircleOutputExclude>();
115 auto fc_node = graph->nodes()->create<luci::CircleFullyConnected>();
117 fc_node->weights(rhs);
118 fc_node->bias(empty_bias);
119 fc_node->fusedActivationFunction(luci::FusedActFunc::NONE);
120 fc_node->name(name + "/FullyConnected");
121 luci::add_origin(fc_node, luci::get_origin(cop));
123 auto customOut = loco::succs(cop);
124 assert(customOut.size() == 1);
125 replace(*customOut.begin()).with(fc_node);
134 bool ResolveCustomOpMatMulPass::run(loco::Graph *g)
136 bool changed = false;
137 for (auto node : loco::active_nodes(loco::output_nodes(g)))
139 auto cop = dynamic_cast<luci::CircleCustom *>(node);
143 if (cop->custom_code() != "MatMul")
146 if (!resolve_matmul(cop))