[mir_onnx] Fix checking some attributes (#6796)
authorПавел Ильютченко/AI Tools Lab /SRR/Engineer/삼성전자 <p.iliutchenk@samsung.com>
Wed, 21 Aug 2019 19:13:06 +0000 (22:13 +0300)
committerAlexander Efimov/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Wed, 21 Aug 2019 19:13:06 +0000 (22:13 +0300)
* Added checking that dilations is equal 1 in Conv and MaxPool
* Removed spatial attribute from BatchNormalization

Signed-off-by: Pavel Iliutchenko <p.iliutchenk@samsung.com>
compiler/mir-onnx-importer/Op/BatchNormalization.cpp
compiler/mir-onnx-importer/Op/Conv.cpp
compiler/mir-onnx-importer/Op/MaxPool.cpp

index 9cc02e3..c47b733 100644 (file)
@@ -67,9 +67,7 @@ void BatchNormalizationNodeConverter::convertV6(const onnx::NodeProto &onnx_node
 void BatchNormalizationNodeConverter::convertV7(const onnx::NodeProto &onnx_node,
                                                 ConverterContext *context) const
 {
-  const auto spatial = getAttributeValue<int64_t>(onnx_node, "spatial", 1);
-  if (spatial != 1)
-    throw std::runtime_error("Not supported spatial attribute!");
+  // spatial attribute used only for learning
 
   convertV9(onnx_node, context);
 }
index e25f6b1..af942fd 100644 (file)
@@ -47,7 +47,12 @@ void ConvNodeConverter::convertV1(const onnx::NodeProto &onnx_node, ConverterCon
 
   const auto *dilations = findAttribute(onnx_node, "dilations");
   if (dilations != nullptr)
-    throw std::runtime_error("Not supported dilations in Conv operation!");
+  {
+    // check default (=1) dilations on each spatial axis
+    for (auto index = 0; index < dilations->ints_size(); index++)
+      if (dilations->ints(index) != 1)
+        throw std::runtime_error("Not supported dilations in Conv operation!");
+  }
 
   std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
   mir::Graph *graph = context->getGraph();
index 48e35eb..53d6b85 100644 (file)
@@ -86,7 +86,12 @@ void MaxPoolNodeConverter::convertV10(const onnx::NodeProto &onnx_node,
 
   const auto *dilations = findAttribute(onnx_node, "dilations");
   if (dilations != nullptr)
-    throw std::runtime_error("Not supported dilations in Conv operation!");
+  {
+    // check default (=1) dilations on each spatial axis
+    for (auto index = 0; index < dilations->ints_size(); index++)
+      if (dilations->ints(index) != 1)
+        throw std::runtime_error("Not supported dilations in MaxPool operation!");
+  }
 
   convertV8(onnx_node, context);
 }