void MaxPoolNodeConverter::convert(const onnx::NodeProto &onnx_node,
ConverterContext *context) const
{
+ const auto opset_version = context->getOpsetVersion(onnx_node.domain());
+ if (opset_version >= 10)
+ convertV10(onnx_node, context);
+ else if (opset_version >= 8)
+ convertV8(onnx_node, context);
+ else if (opset_version >= 1)
+ convertV1(onnx_node, context);
+ else
+ throw std::runtime_error("Not supported opset version on MaxPool operation!");
+}
+
+void MaxPoolNodeConverter::convertV1(const onnx::NodeProto &onnx_node,
+ ConverterContext *context) const
+{
+ const auto auto_pad = getStringAttribute(onnx_node, "auto_pad", "NOTSET");
+ // auto_pad must be either NOTSET, SAME_UPPER, SAME_LOWER or VALID.
+ if (auto_pad != "NOTSET")
+ throw std::runtime_error("Supported only explicit padding!");
+
std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
mir::Graph *graph = context->getGraph();
- // TODO Set some asserts
- mir::ops::PoolOp::BorderType border_type;
- mir::ops::PoolOp::PoolingType pool_type;
+
+ mir::ops::PoolOp::BorderType border_type = mir::ops::PoolOp::BorderType::EMPTY;
+ mir::ops::PoolOp::PoolingType pool_type = mir::ops::PoolOp::PoolingType::MAX;
KernelStridesPadding cdata;
// Transpose ONNX NCHW to MIR NHWC
auto t_input = convertONNXToMIR(graph, inputs[0]);
- border_type = mir::ops::PoolOp::BorderType::EMPTY;
- pool_type = mir::ops::PoolOp::PoolingType::MAX;
getKernelStridesPadding(onnx_node, cdata);
auto result =
context->setNodeOutputs(onnx_node, {result});
}
+void MaxPoolNodeConverter::convertV8(const onnx::NodeProto &onnx_node,
+ ConverterContext *context) const
+{
+ const auto storage_order = getIntAttribute(onnx_node, "storage_order", 0);
+ if (storage_order != 0)
+ throw std::runtime_error("Not supported storage order attribute!");
+
+ convertV1(onnx_node, context);
+}
+
+void MaxPoolNodeConverter::convertV10(const onnx::NodeProto &onnx_node,
+ ConverterContext *context) const
+{
+ const auto ceil_mode = getIntAttribute(onnx_node, "ceil_mode", 0);
+ if (ceil_mode != 0)
+ throw std::runtime_error("Not supported ceil_mode attribute!");
+
+ const auto *dilations = findAttribute(onnx_node, "dilations");
+ if (dilations != nullptr)
+ throw std::runtime_error("Not supported dilations in Conv operation!");
+
+ convertV8(onnx_node, context);
+}
+
} // namespace mir_onnx