From 5918f563e24c0ee23e66bf3f68a711cfddca25b2 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=D0=9F=D0=B0=D0=B2=D0=B5=D0=BB=20=D0=98=D0=BB=D1=8C=D1=8E?= =?utf8?q?=D1=82=D1=87=D0=B5=D0=BD=D0=BA=D0=BE/AI=20Tools=20Lab=20/SRR/Eng?= =?utf8?q?ineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Fri, 23 Aug 2019 15:04:25 +0300 Subject: [PATCH] [mir_onnx] MatMul operation support (#6797) * Added MatMulNodeConverter * Fix registration and cmake Signed-off-by: Pavel Iliutchenko --- compiler/mir-onnx-importer/CMakeLists.txt | 2 + compiler/mir-onnx-importer/ONNXOpRegistration.h | 3 +- compiler/mir-onnx-importer/Op/MatMul.cpp | 63 +++++++++++++++++++++++++ compiler/mir-onnx-importer/Op/MatMul.h | 37 +++++++++++++++ 4 files changed, 104 insertions(+), 1 deletion(-) create mode 100644 compiler/mir-onnx-importer/Op/MatMul.cpp create mode 100644 compiler/mir-onnx-importer/Op/MatMul.h diff --git a/compiler/mir-onnx-importer/CMakeLists.txt b/compiler/mir-onnx-importer/CMakeLists.txt index cce8e00..e319621 100644 --- a/compiler/mir-onnx-importer/CMakeLists.txt +++ b/compiler/mir-onnx-importer/CMakeLists.txt @@ -49,6 +49,8 @@ set(MIR_ONNX_IMPORTER_SOURCES Op/Gemm.h Op/Identity.cpp Op/Identity.h + Op/MatMul.cpp + Op/MatMul.h Op/GlobalAveragePool.cpp Op/GlobalAveragePool.h Op/Max.cpp diff --git a/compiler/mir-onnx-importer/ONNXOpRegistration.h b/compiler/mir-onnx-importer/ONNXOpRegistration.h index 04da1b2..2d458a0 100644 --- a/compiler/mir-onnx-importer/ONNXOpRegistration.h +++ b/compiler/mir-onnx-importer/ONNXOpRegistration.h @@ -31,6 +31,7 @@ #include "Op/Gemm.h" #include "Op/GlobalAveragePool.h" #include "Op/Identity.h" +#include "Op/MatMul.h" #include "Op/Max.h" #include "Op/MaxPool.h" #include "Op/Mul.h" @@ -66,6 +67,7 @@ inline void registerSupportedOps() registry.registerConverter("GlobalAveragePool", stdex::make_unique()); registry.registerConverter("Identity", stdex::make_unique()); + registry.registerConverter("MatMul", stdex::make_unique()); registry.registerConverter("Max", stdex::make_unique()); registry.registerConverter("MaxPool", stdex::make_unique()); registry.registerConverter("Mul", stdex::make_unique()); @@ -77,7 +79,6 @@ inline void registerSupportedOps() registry.registerConverter("Sigmoid", stdex::make_unique()); registry.registerConverter("Softmax", stdex::make_unique()); registry.registerConverter("Sum", stdex::make_unique()); - registry.registerConverter("Transpose", stdex::make_unique()); registry.registerConverter("Unsqueeze", stdex::make_unique()); registry.registerConverter("Upsample", stdex::make_unique()); } diff --git a/compiler/mir-onnx-importer/Op/MatMul.cpp b/compiler/mir-onnx-importer/Op/MatMul.cpp new file mode 100644 index 0000000..b2107bc --- /dev/null +++ b/compiler/mir-onnx-importer/Op/MatMul.cpp @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "MatMul.h" + +#include "ONNXHelpers.h" + +#include "mir/ops/FullyConnectedOp.h" + +namespace mir_onnx +{ + +void MatMulNodeConverter::convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const +{ + const auto opset_version = context->getOpsetVersion(onnx_node.domain()); + if (opset_version >= 9) + convertV9(onnx_node, context); + else if (opset_version >= 1) + convertV1(onnx_node, context); + else + throw std::runtime_error("Not supported opset version on MatMul operation!"); +} + +void MatMulNodeConverter::convertV1(const onnx::NodeProto &onnx_node, + ConverterContext *context) const +{ + std::vector inputs = context->getNodeInputs(onnx_node); + mir::Graph *graph = context->getGraph(); + + assert(inputs.size() == 2); + auto A = inputs[0]; + auto B = inputs[1]; + // MatMul multiply N-dimentional matrix + // FullyConnected layer multiply only 2-dimentional matrix + if (A->getShape().rank() != 2 || B->getShape().rank() != 2) + throw std::runtime_error("Supported only 2D matrix multiplying!"); + // Calculate A * B. + auto result = createOp(graph, A, B)->getOutput(0); + + context->setNodeOutputs(onnx_node, {result}); +} + +void MatMulNodeConverter::convertV9(const onnx::NodeProto &onnx_node, + ConverterContext *context) const +{ + // Other type constraints + convertV1(onnx_node, context); +} + +} // namespace mir_onnx diff --git a/compiler/mir-onnx-importer/Op/MatMul.h b/compiler/mir-onnx-importer/Op/MatMul.h new file mode 100644 index 0000000..f418da8 --- /dev/null +++ b/compiler/mir-onnx-importer/Op/MatMul.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MIR_ONNX_OP_MATMUL_H +#define MIR_ONNX_OP_MATMUL_H + +#include "ONNXNodeConverterRegistry.h" + +namespace mir_onnx +{ + +class MatMulNodeConverter : public NodeConverter +{ +public: + void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override; + +private: + void convertV1(const onnx::NodeProto &onnx_node, ConverterContext *context) const; + void convertV9(const onnx::NodeProto &onnx_node, ConverterContext *context) const; +}; + +} // namespace mir_onnx + +#endif // MIR_ONNX_OP_MATMUL_H -- 2.7.4