[mir_onnx] Support versioning in Add operation (#6653)
authorПавел Ильютченко/AI Tools Lab /SRR/Engineer/삼성전자 <p.iliutchenk@samsung.com>
Tue, 20 Aug 2019 14:30:24 +0000 (17:30 +0300)
committerAlexander Efimov/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Tue, 20 Aug 2019 14:30:24 +0000 (17:30 +0300)
* Implemented converter versioning for Add operation, known versions: 1, 6, 7

Signed-off-by: Pavel Iliutchenko <p.iliutchenk@samsung.com>
compiler/mir-onnx-importer/Op/Add.cpp
compiler/mir-onnx-importer/Op/Add.h

index eef2d9b..a9542de 100644 (file)
@@ -25,6 +25,35 @@ namespace mir_onnx
 
 void AddNodeConverter::convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const
 {
+  const auto opset_version = context->getOpsetVersion(onnx_node.domain());
+  if (opset_version >= 7)
+    convertV7(onnx_node, context);
+  else if (opset_version >= 6)
+    convertV6(onnx_node, context);
+  else if (opset_version >= 1)
+    convertV1(onnx_node, context);
+  else
+    throw std::runtime_error("Not supported opset version on Add operation!");
+}
+
+void AddNodeConverter::convertV1(const onnx::NodeProto &onnx_node, ConverterContext *context) const
+{
+  // consumed_inputs attribute not used
+  convertV6(onnx_node, context);
+}
+
+void AddNodeConverter::convertV6(const onnx::NodeProto &onnx_node, ConverterContext *context) const
+{
+  // broadcast attribute not used
+  const auto *axis = findAttribute(onnx_node, "axis");
+  if (axis != nullptr)
+    throw std::runtime_error("Not supported axis attribute in Add operation!");
+
+  convertV7(onnx_node, context);
+}
+
+void AddNodeConverter::convertV7(const onnx::NodeProto &onnx_node, ConverterContext *context) const
+{
   std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
   mir::Graph *graph = context->getGraph();
 
index 49217e2..0edc028 100644 (file)
@@ -23,6 +23,11 @@ class AddNodeConverter : public NodeConverter
 {
 public:
   void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
+
+private:
+  void convertV1(const onnx::NodeProto &onnx_node, ConverterContext *context) const;
+  void convertV6(const onnx::NodeProto &onnx_node, ConverterContext *context) const;
+  void convertV7(const onnx::NodeProto &onnx_node, ConverterContext *context) const;
 };
 
 } // namespace mir_onnx