#include "Dropout.h"
-#include "ONNXHelpers.h"
+#include "AttributeHelpers.h"
namespace mir_onnx
{
void DropoutNodeConverter::convert(const onnx::NodeProto &onnx_node,
ConverterContext *context) const
{
+ const auto opset_version = context->getOpsetVersion(onnx_node.domain());
+ if (opset_version >= 10)
+ convertV10(onnx_node, context);
+ else if (opset_version >= 7)
+ convertV7(onnx_node, context);
+ else if (opset_version >= 6)
+ convertV6(onnx_node, context);
+ else if (opset_version >= 1)
+ convertV1(onnx_node, context);
+ else
+ throw std::runtime_error("Not supported opset version on Dropout operation!");
+}
+
+void DropoutNodeConverter::convertV1(const onnx::NodeProto &onnx_node,
+ ConverterContext *context) const
+{
+ // consumed_inputs attribute not used
+ convertV6(onnx_node, context);
+}
+
+void DropoutNodeConverter::convertV6(const onnx::NodeProto &onnx_node,
+ ConverterContext *context) const
+{
+ const auto is_test = getAttributeValue<std::int64_t>(onnx_node, "is_test", 0);
+ if (is_test == 0)
+ throw std::runtime_error("Not supported is_test attribute!");
+
+ convertV10(onnx_node, context);
+}
+
+void DropoutNodeConverter::convertV7(const onnx::NodeProto &onnx_node,
+ ConverterContext *context) const
+{
+ convertV10(onnx_node, context);
+}
+
+void DropoutNodeConverter::convertV10(const onnx::NodeProto &onnx_node,
+ ConverterContext *context) const
+{
std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
- // This is a no-op in inference mode.
- auto result = inputs[0];
+ // ratio attribute not used
- context->setNodeOutputs(onnx_node, {result});
+ // This is a no-op in inference mode.
+ context->setNodeOutputs(onnx_node, {inputs[0]});
}
} // namespace mir_onnx
{
public:
void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
+
+private:
+ void convertV1(const onnx::NodeProto &onnx_node, ConverterContext *context) const;
+ void convertV6(const onnx::NodeProto &onnx_node, ConverterContext *context) const;
+ void convertV7(const onnx::NodeProto &onnx_node, ConverterContext *context) const;
+ void convertV10(const onnx::NodeProto &onnx_node, ConverterContext *context) const;
};
} // namespace mir_onnx