void ConstantNodeConverter::convert(const onnx::NodeProto &onnx_node,
ConverterContext *context) const
{
- std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
- mir::Graph *graph = context->getGraph();
- assert((onnx_node.attribute_size() == 1) &&
- (onnx_node.attribute(0).type() == ::onnx::AttributeProto_AttributeType_TENSOR) &&
- (onnx_node.attribute(0).tensors_size() == 0));
- assert(onnx_node.attribute(0).name() == "value");
- auto name = onnx_node.output(0);
- auto &onnx_tensor = onnx_node.attribute(0).t();
+ const auto opset_version = context->getOpsetVersion(onnx_node.domain());
+ if (opset_version >= 11)
+ convertV11(onnx_node, context);
+ else if (opset_version >= 9)
+ convertV9(onnx_node, context);
+ else if (opset_version >= 1)
+ convertV1(onnx_node, context);
+ else
+ throw std::runtime_error("Not supported opset version on Constant operation!");
+}
+
+void ConstantNodeConverter::convertV1(const onnx::NodeProto &onnx_node,
+ ConverterContext *context) const
+{
+ const auto *value_attr = findAttribute(onnx_node, "value");
+ if (value_attr == nullptr)
+ throw std::runtime_error("Not enough value attribute in Constant operation!");
+ assert(value_attr->type() == onnx::AttributeProto_AttributeType_TENSOR);
+
+ const auto &name = onnx_node.output(0);
+ const auto &onnx_tensor = value_attr->t();
auto mir_tensor = createTensor(&onnx_tensor);
- // TODO check right removing input_tensors
- // input_tensors.insert(std::make_pair(name, mir_tensor));
+
+ mir::Graph *graph = context->getGraph();
auto result = graph->create<mir::ops::ConstantOp>(name, mir_tensor)->getOutput(0);
context->setNodeOutputs(onnx_node, {result});
}
+void ConstantNodeConverter::convertV9(const onnx::NodeProto &onnx_node,
+ ConverterContext *context) const
+{
+ // Since version 9 Constant operation support other types contained in tensor
+ convertV1(onnx_node, context);
+}
+
+void ConstantNodeConverter::convertV11(const onnx::NodeProto &onnx_node,
+ ConverterContext *context) const
+{
+ const auto *value_attr = findAttribute(onnx_node, "value");
+ const auto *sparse_value_attr = findAttribute(onnx_node, "sparse_value");
+ if (value_attr == nullptr && sparse_value_attr == nullptr)
+ throw std::runtime_error("Not enough attributes in Constant operation!");
+
+ if (value_attr != nullptr)
+ return convertV9(onnx_node, context);
+
+ if (sparse_value_attr != nullptr)
+ throw std::runtime_error("Not supported sparse_tensor in Constant operation!");
+}
+
} // namespace mir_onnx