[mir_onnx] Dropout operation versioning (#6898)
authorПавел Ильютченко/AI Tools Lab /SRR/Engineer/삼성전자 <p.iliutchenk@samsung.com>
Mon, 26 Aug 2019 18:47:27 +0000 (21:47 +0300)
committerAlexander Efimov/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Mon, 26 Aug 2019 18:47:27 +0000 (21:47 +0300)
* Supported V1, V6, V7, V10 dropout versions

Signed-off-by: Pavel Iliutchenko <p.iliutchenk@samsung.com>
compiler/mir-onnx-importer/Op/Dropout.cpp
compiler/mir-onnx-importer/Op/Dropout.h

index 7793317..c81c5ff 100644 (file)
@@ -16,7 +16,7 @@
 
 #include "Dropout.h"
 
-#include "ONNXHelpers.h"
+#include "AttributeHelpers.h"
 
 namespace mir_onnx
 {
@@ -24,12 +24,51 @@ namespace mir_onnx
 void DropoutNodeConverter::convert(const onnx::NodeProto &onnx_node,
                                    ConverterContext *context) const
 {
+  const auto opset_version = context->getOpsetVersion(onnx_node.domain());
+  if (opset_version >= 10)
+    convertV10(onnx_node, context);
+  else if (opset_version >= 7)
+    convertV7(onnx_node, context);
+  else if (opset_version >= 6)
+    convertV6(onnx_node, context);
+  else if (opset_version >= 1)
+    convertV1(onnx_node, context);
+  else
+    throw std::runtime_error("Not supported opset version on Dropout operation!");
+}
+
+void DropoutNodeConverter::convertV1(const onnx::NodeProto &onnx_node,
+                                     ConverterContext *context) const
+{
+  // consumed_inputs attribute not used
+  convertV6(onnx_node, context);
+}
+
+void DropoutNodeConverter::convertV6(const onnx::NodeProto &onnx_node,
+                                     ConverterContext *context) const
+{
+  const auto is_test = getAttributeValue<std::int64_t>(onnx_node, "is_test", 0);
+  if (is_test == 0)
+    throw std::runtime_error("Not supported is_test attribute!");
+
+  convertV10(onnx_node, context);
+}
+
+void DropoutNodeConverter::convertV7(const onnx::NodeProto &onnx_node,
+                                     ConverterContext *context) const
+{
+  convertV10(onnx_node, context);
+}
+
+void DropoutNodeConverter::convertV10(const onnx::NodeProto &onnx_node,
+                                      ConverterContext *context) const
+{
   std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
 
-  // This is a no-op in inference mode.
-  auto result = inputs[0];
+  // ratio attribute not used
 
-  context->setNodeOutputs(onnx_node, {result});
+  // This is a no-op in inference mode.
+  context->setNodeOutputs(onnx_node, {inputs[0]});
 }
 
 } // namespace mir_onnx
index 4fe691f..56291c9 100644 (file)
@@ -26,6 +26,12 @@ class DropoutNodeConverter : public NodeConverter
 {
 public:
   void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
+
+private:
+  void convertV1(const onnx::NodeProto &onnx_node, ConverterContext *context) const;
+  void convertV6(const onnx::NodeProto &onnx_node, ConverterContext *context) const;
+  void convertV7(const onnx::NodeProto &onnx_node, ConverterContext *context) const;
+  void convertV10(const onnx::NodeProto &onnx_node, ConverterContext *context) const;
 };
 
 } // namespace mir_onnx