From 5d595890e655499f8a995ee8af5360c335c008ad Mon Sep 17 00:00:00 2001 From: =?utf8?q?=D0=9F=D0=B0=D0=B2=D0=B5=D0=BB=20=D0=98=D0=BB=D1=8C=D1=8E?= =?utf8?q?=D1=82=D1=87=D0=B5=D0=BD=D0=BA=D0=BE/AI=20Tools=20Lab=20/SRR/Eng?= =?utf8?q?ineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Tue, 20 Aug 2019 17:01:22 +0300 Subject: [PATCH] [mir_onnx] Concat operation versioning (#6693) * Supported V1, V4 concatenation versions Signed-off-by: Pavel Iliutchenko --- compiler/mir-onnx-importer/Op/Concat.cpp | 25 +++++++++++++++++++++++++ compiler/mir-onnx-importer/Op/Concat.h | 4 ++++ 2 files changed, 29 insertions(+) diff --git a/compiler/mir-onnx-importer/Op/Concat.cpp b/compiler/mir-onnx-importer/Op/Concat.cpp index 6218542..59ae590 100644 --- a/compiler/mir-onnx-importer/Op/Concat.cpp +++ b/compiler/mir-onnx-importer/Op/Concat.cpp @@ -25,9 +25,34 @@ namespace mir_onnx void ConcatNodeConverter::convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const { + const auto opset_version = context->getOpsetVersion(onnx_node.domain()); + if (opset_version >= 4) + convertV4(onnx_node, context); + else if (opset_version >= 1) + convertV1(onnx_node, context); + else + throw std::runtime_error("Not supported opset version on Add operation!"); +} + +void ConcatNodeConverter::convertV1(const onnx::NodeProto &onnx_node, + ConverterContext *context) const +{ std::vector inputs = context->getNodeInputs(onnx_node); mir::Graph *graph = context->getGraph(); + const auto axis = getIntAttribute(onnx_node, "axis", 1); + + auto result = createOp(graph, inputs, axis)->getOutput(0); + + context->setNodeOutputs(onnx_node, {result}); +} + +void ConcatNodeConverter::convertV4(const onnx::NodeProto &onnx_node, + ConverterContext *context) const +{ + std::vector inputs = context->getNodeInputs(onnx_node); + mir::Graph *graph = context->getGraph(); + // From version 4 axis attribute is required auto attr = findAttribute(onnx_node, "axis"); if (!attr) throw std::runtime_error("Attribute axis is required!"); diff --git a/compiler/mir-onnx-importer/Op/Concat.h b/compiler/mir-onnx-importer/Op/Concat.h index 5ca8822..b246aea 100644 --- a/compiler/mir-onnx-importer/Op/Concat.h +++ b/compiler/mir-onnx-importer/Op/Concat.h @@ -23,6 +23,10 @@ class ConcatNodeConverter : public NodeConverter { public: void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override; + +private: + void convertV1(const onnx::NodeProto &onnx_node, ConverterContext *context) const; + void convertV4(const onnx::NodeProto &onnx_node, ConverterContext *context) const; }; } // namespace mir_onnx -- 2.7.4