[mir_onnx] BatchNormalization operation versioning (#6692)
authorПавел Ильютченко/AI Tools Lab /SRR/Engineer/삼성전자 <p.iliutchenk@samsung.com>
Tue, 20 Aug 2019 12:54:06 +0000 (15:54 +0300)
committerAlexander Efimov/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Tue, 20 Aug 2019 12:54:06 +0000 (15:54 +0300)
* Supported V1, V6, V7, V9 batch normalization versions

Signed-off-by: Pavel Iliutchenko <p.iliutchenk@samsung.com>
compiler/mir-onnx-importer/Op/BatchNormalization.cpp
compiler/mir-onnx-importer/Op/BatchNormalization.h

index 6871c6c..a6baf5e 100644 (file)
@@ -33,6 +33,51 @@ namespace mir_onnx
 void BatchNormalizationNodeConverter::convert(const onnx::NodeProto &onnx_node,
                                               ConverterContext *context) const
 {
+  const auto opset_version = context->getOpsetVersion(onnx_node.domain());
+  if (opset_version >= 9)
+    convertV9(onnx_node, context);
+  else if (opset_version >= 7)
+    convertV7(onnx_node, context);
+  else if (opset_version >= 6)
+    convertV6(onnx_node, context);
+  else if (opset_version >= 1)
+    convertV1(onnx_node, context);
+  else
+    throw std::runtime_error("Not supported opset version on BatchNormalization operation!");
+}
+
+void BatchNormalizationNodeConverter::convertV1(const onnx::NodeProto &onnx_node,
+                                                ConverterContext *context) const
+{
+  // consumed_inputs attribute not used
+  convertV6(onnx_node, context);
+}
+
+void BatchNormalizationNodeConverter::convertV6(const onnx::NodeProto &onnx_node,
+                                                ConverterContext *context) const
+{
+  const auto is_test = getIntAttribute(onnx_node, "is_test", 0);
+  if (is_test != 0)
+    throw std::runtime_error("Not supported is_test attribute!");
+
+  convertV7(onnx_node, context);
+}
+
+void BatchNormalizationNodeConverter::convertV7(const onnx::NodeProto &onnx_node,
+                                                ConverterContext *context) const
+{
+  const auto spatial = getIntAttribute(onnx_node, "spatial", 1);
+  if (spatial != 1)
+    throw std::runtime_error("Not supported spatial attribute!");
+
+  convertV9(onnx_node, context);
+}
+
+void BatchNormalizationNodeConverter::convertV9(const onnx::NodeProto &onnx_node,
+                                                ConverterContext *context) const
+{
+  // momentum attrribute used only for learning
+
   std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
   mir::Graph *graph = context->getGraph();
 
index 4d6e3fd..364493a 100644 (file)
@@ -23,6 +23,12 @@ class BatchNormalizationNodeConverter : public NodeConverter
 {
 public:
   void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
+
+private:
+  void convertV1(const onnx::NodeProto &onnx_node, ConverterContext *context) const;
+  void convertV6(const onnx::NodeProto &onnx_node, ConverterContext *context) const;
+  void convertV7(const onnx::NodeProto &onnx_node, ConverterContext *context) const;
+  void convertV9(const onnx::NodeProto &onnx_node, ConverterContext *context) const;
 };
 
 } // namespace mir_onnx