void AddNodeConverter::convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const
{
+ const auto opset_version = context->getOpsetVersion(onnx_node.domain());
+ if (opset_version >= 7)
+ convertV7(onnx_node, context);
+ else if (opset_version >= 6)
+ convertV6(onnx_node, context);
+ else if (opset_version >= 1)
+ convertV1(onnx_node, context);
+ else
+ throw std::runtime_error("Not supported opset version on Add operation!");
+}
+
+void AddNodeConverter::convertV1(const onnx::NodeProto &onnx_node, ConverterContext *context) const
+{
+ // consumed_inputs attribute not used
+ convertV6(onnx_node, context);
+}
+
+void AddNodeConverter::convertV6(const onnx::NodeProto &onnx_node, ConverterContext *context) const
+{
+ // broadcast attribute not used
+ const auto *axis = findAttribute(onnx_node, "axis");
+ if (axis != nullptr)
+ throw std::runtime_error("Not supported axis attribute in Add operation!");
+
+ convertV7(onnx_node, context);
+}
+
+void AddNodeConverter::convertV7(const onnx::NodeProto &onnx_node, ConverterContext *context) const
+{
std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
mir::Graph *graph = context->getGraph();
{
public:
void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
+
+private:
+ void convertV1(const onnx::NodeProto &onnx_node, ConverterContext *context) const;
+ void convertV6(const onnx::NodeProto &onnx_node, ConverterContext *context) const;
+ void convertV7(const onnx::NodeProto &onnx_node, ConverterContext *context) const;
};
} // namespace mir_onnx