[mir_onnx] MaxPool operation versioning (#6696)
authorПавел Ильютченко/AI Tools Lab /SRR/Engineer/삼성전자 <p.iliutchenk@samsung.com>
Tue, 20 Aug 2019 15:14:07 +0000 (18:14 +0300)
committerAlexander Efimov/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Tue, 20 Aug 2019 15:14:07 +0000 (18:14 +0300)
* Supported V1, V8, V10 maxpool versions

Signed-off-by: Pavel Iliutchenko <p.iliutchenk@samsung.com>
compiler/mir-onnx-importer/Op/MaxPool.cpp
compiler/mir-onnx-importer/Op/MaxPool.h

index c2b5509..0da6e35 100644 (file)
@@ -26,18 +26,35 @@ namespace mir_onnx
 void MaxPoolNodeConverter::convert(const onnx::NodeProto &onnx_node,
                                    ConverterContext *context) const
 {
+  const auto opset_version = context->getOpsetVersion(onnx_node.domain());
+  if (opset_version >= 10)
+    convertV10(onnx_node, context);
+  else if (opset_version >= 8)
+    convertV8(onnx_node, context);
+  else if (opset_version >= 1)
+    convertV1(onnx_node, context);
+  else
+    throw std::runtime_error("Not supported opset version on MaxPool operation!");
+}
+
+void MaxPoolNodeConverter::convertV1(const onnx::NodeProto &onnx_node,
+                                     ConverterContext *context) const
+{
+  const auto auto_pad = getStringAttribute(onnx_node, "auto_pad", "NOTSET");
+  // auto_pad must be either NOTSET, SAME_UPPER, SAME_LOWER or VALID.
+  if (auto_pad != "NOTSET")
+    throw std::runtime_error("Supported only explicit padding!");
+
   std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
   mir::Graph *graph = context->getGraph();
-  // TODO Set some asserts
-  mir::ops::PoolOp::BorderType border_type;
-  mir::ops::PoolOp::PoolingType pool_type;
+
+  mir::ops::PoolOp::BorderType border_type = mir::ops::PoolOp::BorderType::EMPTY;
+  mir::ops::PoolOp::PoolingType pool_type = mir::ops::PoolOp::PoolingType::MAX;
 
   KernelStridesPadding cdata;
   // Transpose ONNX NCHW to MIR NHWC
   auto t_input = convertONNXToMIR(graph, inputs[0]);
 
-  border_type = mir::ops::PoolOp::BorderType::EMPTY;
-  pool_type = mir::ops::PoolOp::PoolingType::MAX;
   getKernelStridesPadding(onnx_node, cdata);
 
   auto result =
@@ -49,4 +66,28 @@ void MaxPoolNodeConverter::convert(const onnx::NodeProto &onnx_node,
   context->setNodeOutputs(onnx_node, {result});
 }
 
+void MaxPoolNodeConverter::convertV8(const onnx::NodeProto &onnx_node,
+                                     ConverterContext *context) const
+{
+  const auto storage_order = getIntAttribute(onnx_node, "storage_order", 0);
+  if (storage_order != 0)
+    throw std::runtime_error("Not supported storage order attribute!");
+
+  convertV1(onnx_node, context);
+}
+
+void MaxPoolNodeConverter::convertV10(const onnx::NodeProto &onnx_node,
+                                      ConverterContext *context) const
+{
+  const auto ceil_mode = getIntAttribute(onnx_node, "ceil_mode", 0);
+  if (ceil_mode != 0)
+    throw std::runtime_error("Not supported ceil_mode attribute!");
+
+  const auto *dilations = findAttribute(onnx_node, "dilations");
+  if (dilations != nullptr)
+    throw std::runtime_error("Not supported dilations in Conv operation!");
+
+  convertV8(onnx_node, context);
+}
+
 } // namespace mir_onnx
index daf45f7..393a4e0 100644 (file)
@@ -23,6 +23,11 @@ class MaxPoolNodeConverter : public NodeConverter
 {
 public:
   void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
+
+private:
+  void convertV1(const onnx::NodeProto &onnx_node, ConverterContext *context) const;
+  void convertV8(const onnx::NodeProto &onnx_node, ConverterContext *context) const;
+  void convertV10(const onnx::NodeProto &onnx_node, ConverterContext *context) const;
 };
 
 } // namespace mir_onnx