[mir_onnx] Constant operation versioning (#6694)
authorПавел Ильютченко/AI Tools Lab /SRR/Engineer/삼성전자 <p.iliutchenk@samsung.com>
Tue, 20 Aug 2019 14:28:27 +0000 (17:28 +0300)
committerAlexander Efimov/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Tue, 20 Aug 2019 14:28:27 +0000 (17:28 +0300)
* Supported V1, V9, v11(partially) versions

Signed-off-by: Pavel Iliutchenko <p.iliutchenk@samsung.com>
compiler/mir-onnx-importer/Op/Constant.cpp
compiler/mir-onnx-importer/Op/Constant.h

index dbf8e95..7c73cb6 100644 (file)
@@ -27,20 +27,55 @@ namespace mir_onnx
 void ConstantNodeConverter::convert(const onnx::NodeProto &onnx_node,
                                     ConverterContext *context) const
 {
-  std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
-  mir::Graph *graph = context->getGraph();
-  assert((onnx_node.attribute_size() == 1) &&
-         (onnx_node.attribute(0).type() == ::onnx::AttributeProto_AttributeType_TENSOR) &&
-         (onnx_node.attribute(0).tensors_size() == 0));
-  assert(onnx_node.attribute(0).name() == "value");
-  auto name = onnx_node.output(0);
-  auto &onnx_tensor = onnx_node.attribute(0).t();
+  const auto opset_version = context->getOpsetVersion(onnx_node.domain());
+  if (opset_version >= 11)
+    convertV11(onnx_node, context);
+  else if (opset_version >= 9)
+    convertV9(onnx_node, context);
+  else if (opset_version >= 1)
+    convertV1(onnx_node, context);
+  else
+    throw std::runtime_error("Not supported opset version on Constant operation!");
+}
+
+void ConstantNodeConverter::convertV1(const onnx::NodeProto &onnx_node,
+                                      ConverterContext *context) const
+{
+  const auto *value_attr = findAttribute(onnx_node, "value");
+  if (value_attr == nullptr)
+    throw std::runtime_error("Not enough value attribute in Constant operation!");
+  assert(value_attr->type() == onnx::AttributeProto_AttributeType_TENSOR);
+
+  const auto &name = onnx_node.output(0);
+  const auto &onnx_tensor = value_attr->t();
   auto mir_tensor = createTensor(&onnx_tensor);
-  // TODO check right removing input_tensors
-  // input_tensors.insert(std::make_pair(name, mir_tensor));
+
+  mir::Graph *graph = context->getGraph();
   auto result = graph->create<mir::ops::ConstantOp>(name, mir_tensor)->getOutput(0);
 
   context->setNodeOutputs(onnx_node, {result});
 }
 
+void ConstantNodeConverter::convertV9(const onnx::NodeProto &onnx_node,
+                                      ConverterContext *context) const
+{
+  // Since version 9 Constant operation support other types contained in tensor
+  convertV1(onnx_node, context);
+}
+
+void ConstantNodeConverter::convertV11(const onnx::NodeProto &onnx_node,
+                                       ConverterContext *context) const
+{
+  const auto *value_attr = findAttribute(onnx_node, "value");
+  const auto *sparse_value_attr = findAttribute(onnx_node, "sparse_value");
+  if (value_attr == nullptr && sparse_value_attr == nullptr)
+    throw std::runtime_error("Not enough attributes in Constant operation!");
+
+  if (value_attr != nullptr)
+    return convertV9(onnx_node, context);
+
+  if (sparse_value_attr != nullptr)
+    throw std::runtime_error("Not supported sparse_tensor in Constant operation!");
+}
+
 } // namespace mir_onnx
index 6aff4f2..9b757bf 100644 (file)
@@ -23,6 +23,11 @@ class ConstantNodeConverter : public NodeConverter
 {
 public:
   void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
+
+private:
+  void convertV1(const onnx::NodeProto &onnx_node, ConverterContext *context) const;
+  void convertV9(const onnx::NodeProto &onnx_node, ConverterContext *context) const;
+  void convertV11(const onnx::NodeProto &onnx_node, ConverterContext *context) const;
 };
 
 } // namespace mir_onnx