Imported Upstream version 1.4.0
[platform/core/ml/nnfw.git] / compiler / mir / src / mir_onnx_importer / Op / Gemm.cpp
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "Gemm.h"
18
19 #include "ONNXHelpers.h"
20 #include "AttributeHelpers.h"
21
22 #include "mir/TensorUtil.h"
23
24 #include "mir/ops/AddOp.h"
25 #include "mir/ops/ConstantOp.h"
26 #include "mir/ops/FullyConnectedOp.h"
27 #include "mir/ops/MulOp.h"
28 #include "mir/ops/ReshapeOp.h"
29 #include "mir/ops/TransposeOp.h"
30
31 namespace mir_onnx
32 {
33
34 static void convertGemm(const onnx::NodeProto &onnx_node, ConverterContext *context)
35 {
36   std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
37   mir::Graph *graph = context->getGraph();
38
39   assert(inputs.size() == 2 || inputs.size() == 3);
40
41   auto a = inputs[0];
42   auto b = inputs[1];
43   auto c = inputs.size() > 2 ? inputs[2] : nullptr;
44
45   // 1.0f is the default factor.
46   const auto alpha_val = getAttributeValue<float>(onnx_node, "alpha", 1.0f);
47   const auto beta_val = getAttributeValue<float>(onnx_node, "beta", 1.0f);
48
49   // 0 means that no transpose is needed. It is the default value.
50   const auto trans_a = getAttributeValue<std::int64_t>(onnx_node, "transA", 0);
51   const auto trans_b = getAttributeValue<std::int64_t>(onnx_node, "transB", 0);
52
53   // Transpose the A and B matrices as needed.
54   if (trans_a)
55     a = createOp<mir::ops::TransposeOp>(graph, a, std::vector<std::size_t>{1, 0})->getOutput(0);
56   if (trans_b)
57     b = createOp<mir::ops::TransposeOp>(graph, b, std::vector<std::size_t>{1, 0})->getOutput(0);
58
59   // Calculate A * B.
60   auto ab = createOp<mir::ops::FullyConnectedOp>(graph, a, b)->getOutput(0);
61
62   // Multiply A * B by the constant factor.
63   if (alpha_val != 1.0f)
64   {
65     mir::TensorVariant alpha_tensor({mir::DataType::FLOAT32, {}}, &alpha_val);
66     auto alpha = createOp<mir::ops::ConstantOp>(graph, alpha_tensor)->getOutput(0);
67     ab = createOp<mir::ops::MulOp>(graph, alpha, ab)->getOutput(0);
68   }
69
70   // If there are no third input, node is simple A*B multiplication
71   if (!c)
72   {
73     context->setNodeOutputs(onnx_node, {ab});
74     return;
75   }
76
77   // Multiply C by the constant factor.
78   if (beta_val != 1.0f)
79   {
80     mir::TensorVariant beta_tensor({mir::DataType::FLOAT32, {}}, &beta_val);
81     auto beta = createOp<mir::ops::ConstantOp>(graph, beta_tensor)->getOutput(0);
82     c = createOp<mir::ops::MulOp>(graph, beta, c)->getOutput(0);
83   }
84
85   // Calculate the result: alpha * A * B + beta * C.
86   auto result = createOp<mir::ops::AddOp>(graph, ab, c)->getOutput(0);
87
88   context->setNodeOutputs(onnx_node, {result});
89 }
90
91 void convertGemmV1(const onnx::NodeProto &onnx_node, ConverterContext *context)
92 {
93   return convertGemm(onnx_node, context);
94 }
95
96 void convertGemmV6(const onnx::NodeProto &onnx_node, ConverterContext *context)
97 {
98   // This version differs from V1: in description of C input (redundant text "can be inplace.")
99   return convertGemm(onnx_node, context);
100 }
101
102 void convertGemmV7(const onnx::NodeProto &onnx_node, ConverterContext *context)
103 {
104   // This version differs from V6: removed "broadcast" atribute
105   return convertGemm(onnx_node, context);
106 }
107
108 void convertGemmV9(const onnx::NodeProto &onnx_node, ConverterContext *context)
109 {
110   // This version differs from V7: added more supported types
111   return convertGemm(onnx_node, context);
112 }
113
114 void convertGemmV11(const onnx::NodeProto &onnx_node, ConverterContext *context)
115 {
116   // This operation differs from V11: input C is optional
117   return convertGemm(onnx_node, context);
118 }
119
120 } // namespace mir_onnx