From 5d595890e655499f8a995ee8af5360c335c008ad Mon Sep 17 00:00:00 2001
From: =?utf8?q?=D0=9F=D0=B0=D0=B2=D0=B5=D0=BB=20=D0=98=D0=BB=D1=8C=D1=8E?=
=?utf8?q?=D1=82=D1=87=D0=B5=D0=BD=D0=BA=D0=BE/AI=20Tools=20Lab=20/SRR/Eng?=
=?utf8?q?ineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?=
Date: Tue, 20 Aug 2019 17:01:22 +0300
Subject: [PATCH] [mir_onnx] Concat operation versioning (#6693)
* Supported V1, V4 concatenation versions
Signed-off-by: Pavel Iliutchenko
---
compiler/mir-onnx-importer/Op/Concat.cpp | 25 +++++++++++++++++++++++++
compiler/mir-onnx-importer/Op/Concat.h | 4 ++++
2 files changed, 29 insertions(+)
diff --git a/compiler/mir-onnx-importer/Op/Concat.cpp b/compiler/mir-onnx-importer/Op/Concat.cpp
index 6218542..59ae590 100644
--- a/compiler/mir-onnx-importer/Op/Concat.cpp
+++ b/compiler/mir-onnx-importer/Op/Concat.cpp
@@ -25,9 +25,34 @@ namespace mir_onnx
void ConcatNodeConverter::convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const
{
+ const auto opset_version = context->getOpsetVersion(onnx_node.domain());
+ if (opset_version >= 4)
+ convertV4(onnx_node, context);
+ else if (opset_version >= 1)
+ convertV1(onnx_node, context);
+ else
+ throw std::runtime_error("Not supported opset version on Add operation!");
+}
+
+void ConcatNodeConverter::convertV1(const onnx::NodeProto &onnx_node,
+ ConverterContext *context) const
+{
std::vector inputs = context->getNodeInputs(onnx_node);
mir::Graph *graph = context->getGraph();
+ const auto axis = getIntAttribute(onnx_node, "axis", 1);
+
+ auto result = createOp(graph, inputs, axis)->getOutput(0);
+
+ context->setNodeOutputs(onnx_node, {result});
+}
+
+void ConcatNodeConverter::convertV4(const onnx::NodeProto &onnx_node,
+ ConverterContext *context) const
+{
+ std::vector inputs = context->getNodeInputs(onnx_node);
+ mir::Graph *graph = context->getGraph();
+ // From version 4 axis attribute is required
auto attr = findAttribute(onnx_node, "axis");
if (!attr)
throw std::runtime_error("Attribute axis is required!");
diff --git a/compiler/mir-onnx-importer/Op/Concat.h b/compiler/mir-onnx-importer/Op/Concat.h
index 5ca8822..b246aea 100644
--- a/compiler/mir-onnx-importer/Op/Concat.h
+++ b/compiler/mir-onnx-importer/Op/Concat.h
@@ -23,6 +23,10 @@ class ConcatNodeConverter : public NodeConverter
{
public:
void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
+
+private:
+ void convertV1(const onnx::NodeProto &onnx_node, ConverterContext *context) const;
+ void convertV4(const onnx::NodeProto &onnx_node, ConverterContext *context) const;
};
} // namespace mir_onnx
--
2.7.4