2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
19 #include "ONNXHelpers.h"
20 #include "AttributeHelpers.h"
22 #include "mir/TensorUtil.h"
24 #include "mir/ops/AddOp.h"
25 #include "mir/ops/ConstantOp.h"
26 #include "mir/ops/FullyConnectedOp.h"
27 #include "mir/ops/MulOp.h"
28 #include "mir/ops/ReshapeOp.h"
29 #include "mir/ops/TransposeOp.h"
34 static void convertGemm(const onnx::NodeProto &onnx_node, ConverterContext *context)
36 std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
37 mir::Graph *graph = context->getGraph();
39 assert(inputs.size() == 2 || inputs.size() == 3);
43 auto c = inputs.size() > 2 ? inputs[2] : nullptr;
45 // 1.0f is the default factor.
46 const auto alpha_val = getAttributeValue<float>(onnx_node, "alpha", 1.0f);
47 const auto beta_val = getAttributeValue<float>(onnx_node, "beta", 1.0f);
49 // 0 means that no transpose is needed. It is the default value.
50 const auto trans_a = getAttributeValue<std::int64_t>(onnx_node, "transA", 0);
51 const auto trans_b = getAttributeValue<std::int64_t>(onnx_node, "transB", 0);
53 // Transpose the A and B matrices as needed.
55 a = createOp<mir::ops::TransposeOp>(graph, a, std::vector<std::size_t>{1, 0})->getOutput(0);
57 b = createOp<mir::ops::TransposeOp>(graph, b, std::vector<std::size_t>{1, 0})->getOutput(0);
60 auto ab = createOp<mir::ops::FullyConnectedOp>(graph, a, b)->getOutput(0);
62 // Multiply A * B by the constant factor.
63 if (alpha_val != 1.0f)
65 mir::TensorVariant alpha_tensor({mir::DataType::FLOAT32, {}}, &alpha_val);
66 auto alpha = createOp<mir::ops::ConstantOp>(graph, alpha_tensor)->getOutput(0);
67 ab = createOp<mir::ops::MulOp>(graph, alpha, ab)->getOutput(0);
70 // If there are no third input, node is simple A*B multiplication
73 context->setNodeOutputs(onnx_node, {ab});
77 // Multiply C by the constant factor.
80 mir::TensorVariant beta_tensor({mir::DataType::FLOAT32, {}}, &beta_val);
81 auto beta = createOp<mir::ops::ConstantOp>(graph, beta_tensor)->getOutput(0);
82 c = createOp<mir::ops::MulOp>(graph, beta, c)->getOutput(0);
85 // Calculate the result: alpha * A * B + beta * C.
86 auto result = createOp<mir::ops::AddOp>(graph, ab, c)->getOutput(0);
88 context->setNodeOutputs(onnx_node, {result});
91 void convertGemmV1(const onnx::NodeProto &onnx_node, ConverterContext *context)
93 return convertGemm(onnx_node, context);
96 void convertGemmV6(const onnx::NodeProto &onnx_node, ConverterContext *context)
98 // This version differs from V1: in description of C input (redundant text "can be inplace.")
99 return convertGemm(onnx_node, context);
102 void convertGemmV7(const onnx::NodeProto &onnx_node, ConverterContext *context)
104 // This version differs from V6: removed "broadcast" atribute
105 return convertGemm(onnx_node, context);
108 void convertGemmV9(const onnx::NodeProto &onnx_node, ConverterContext *context)
110 // This version differs from V7: added more supported types
111 return convertGemm(onnx_node, context);
114 void convertGemmV11(const onnx::NodeProto &onnx_node, ConverterContext *context)
116 // This operation differs from V11: input C is optional
117 return convertGemm(onnx_node, context);
120 } // namespace mir_onnx