ONNXImporterImpl::ONNXImporterImpl(std::string filename) : _modelFilename(std::move(filename))
{
- _graph = stdex::make_unique<mir::Graph>();
registerSupportedOps();
}
+ONNXImporterImpl::~ONNXImporterImpl() = default;
+
static void loadModelFile(const std::string &filename, onnx::ModelProto *model)
{
GOOGLE_PROTOBUF_VERIFY_VERSION;
assert(tensor.has_name());
const auto mir_tensor = createTensor(&tensor);
auto *op = _graph->create<mir::ops::ConstantOp>(tensor.name(), mir_tensor);
- _tensorNameToOutput.emplace(tensor.name(), op->getOutput(0));
+ _context->setOutput(tensor.name(), op->getOutput(0));
}
- for (auto &input : graph.input())
+ for (const auto &input : graph.input())
{
assert(input.has_name());
- if (_tensorNameToOutput.find(input.name()) == _tensorNameToOutput.end())
+ if (_context->getOutput(input.name()) == nullptr)
{
const auto &onnx_input_shape = input.type().tensor_type().shape();
mir::Shape shape(onnx_input_shape.dim_size());
}
auto *op = _graph->create<mir::ops::InputOp>(input.name(), shape);
- _tensorNameToOutput.emplace(input.name(), op->getOutput(0));
+ _context->setOutput(input.name(), op->getOutput(0));
}
}
}
std::unique_ptr<mir::Graph> ONNXImporterImpl::createIR()
{
+ _graph = stdex::make_unique<mir::Graph>();
+ _context = stdex::make_unique<ConverterContext>(_graph.get());
+
+ if (_model->ir_version() > onnx::IR_VERSION)
+ {
+ throw std::runtime_error("IR version " + std::to_string(_model->ir_version()) +
+ " is not supported yet.");
+ }
+
+ // Set Opset Version for each domain
+ for (const auto &op_set : _model->opset_import())
+ {
+ _context->setOpsetVersion(op_set.domain(), op_set.version());
+ }
+
createGraphInputs();
// Forming partially ordered computation graph
- for (auto &onnx_node : _model->graph().node())
+ for (const auto &onnx_node : _model->graph().node())
{
assert(onnx_node.has_op_type());
auto &op_type = onnx_node.op_type();
- auto &inputs = onnx_node.input();
-
- std::vector<mir::Operation::Output *> mir_inputs;
- std::vector<mir::Operation::Output *> mir_outputs;
-
- for (const auto &input_name : inputs)
- {
- if (!input_name.empty())
- {
- const auto mir_op_iter = _tensorNameToOutput.find(input_name);
- assert(mir_op_iter != _tensorNameToOutput.end());
- mir_inputs.emplace_back(mir_op_iter->second);
- }
- }
// Get converter
- const auto *node_converter = NodeConverterRegistry::getInstance().lookup(op_type);
+ auto *node_converter = NodeConverterRegistry::getInstance().lookup(op_type);
assert(node_converter);
- mir_outputs = node_converter->convert(onnx_node, mir_inputs, _graph.get());
- assert(!mir_outputs.empty());
- // Set outputs' names
- for (int i = 0; i < mir_outputs.size(); i++)
- {
- mir_outputs[i]->getNode()->setName(onnx_node.output(i));
- auto result = _tensorNameToOutput.emplace(onnx_node.output(i), mir_outputs[i]);
- if (!result.second)
- throw std::runtime_error("Name duplication: " + mir_outputs[i]->getNode()->getName());
- }
+ node_converter->convert(onnx_node, _context.get());
}
// Set graph outputs
const auto &outputs = _model->graph().output();
for (const auto &output : outputs)
{
assert(output.has_name());
- auto output_iter = _tensorNameToOutput.find(output.name());
- if (output_iter == _tensorNameToOutput.end())
+ auto mir_output = _context->getOutput(output.name());
+ if (mir_output == nullptr)
throw std::runtime_error("Bad output name!");
- _graph->create<mir::ops::OutputOp>(output.name(), output_iter->second);
- output_iter->second->getNode()->setName("");
+ _graph->create<mir::ops::OutputOp>(output.name(), mir_output);
+ mir_output->getNode()->setName("");
}
return std::move(_graph);
namespace mir_onnx
{
+class ConverterContext;
-class ONNXImporterImpl
+class ONNXImporterImpl final
{
public:
explicit ONNXImporterImpl(std::string filename);
-
+ ~ONNXImporterImpl();
/// @brief Load the model and convert it into a MIR Graph.
std::unique_ptr<mir::Graph> importModel();
void createGraphInputs();
void collectUnsupportedOps();
// Maps ONNX tensor names to corresponding MIR operation outputs.
- std::map<std::string, mir::Operation::Output *> _tensorNameToOutput;
std::string _modelFilename;
std::unique_ptr<onnx::ModelProto> _model;
+ std::unique_ptr<ConverterContext> _context;
std::unique_ptr<mir::Graph> _graph;
};
} // namespace mir_onnx
namespace mir_onnx
{
+class ConverterContext
+{
+public:
+ explicit ConverterContext(mir::Graph *graph) : _graph(graph) {}
+ ~ConverterContext() = default;
+
+ void setOpsetVersion(const std::string &domain, const int64_t opset_version)
+ {
+ _domainToOpsetVersion.emplace(domain, opset_version);
+ }
+
+ int64_t getOpsetVersion(const std::string &domain) const
+ {
+ auto iter = _domainToOpsetVersion.find(domain);
+ if (iter == _domainToOpsetVersion.end())
+ throw std::runtime_error("Didn't have domain " + domain + "!");
+ return iter->second;
+ }
+
+ void setOutput(const std::string &name, mir::Operation::Output *output)
+ {
+ auto result = _tensorNameToOutput.emplace(name, output);
+ if (!result.second)
+ throw std::runtime_error("Name duplication: " + output->getNode()->getName());
+ }
+
+ mir::Operation::Output *getOutput(const std::string &name) const
+ {
+ auto iter = _tensorNameToOutput.find(name);
+ if (iter == _tensorNameToOutput.end())
+ return nullptr;
+ else
+ return iter->second;
+ }
+
+ std::vector<mir::Operation::Output *> getNodeInputs(const onnx::NodeProto &onnx_node) const
+ {
+ const auto &input_names = onnx_node.input();
+ std::vector<mir::Operation::Output *> outputs;
+
+ for (const auto &input_name : input_names)
+ {
+ if (!input_name.empty())
+ {
+ auto *mir_output = getOutput(input_name);
+ assert(mir_output != nullptr);
+ outputs.emplace_back(mir_output);
+ }
+ }
+ return outputs;
+ }
+
+ void setNodeOutputs(const onnx::NodeProto &onnx_node,
+ const std::vector<mir::Operation::Output *> &outputs)
+ {
+ assert(!outputs.empty());
+ for (std::size_t i = 0; i < outputs.size(); ++i)
+ {
+ outputs[i]->getNode()->setName(onnx_node.output(i));
+ setOutput(onnx_node.output(i), outputs[i]);
+ }
+ }
+
+ mir::Graph *getGraph() const { return _graph; }
+
+private:
+ std::map<std::string, mir::Operation::Output *> _tensorNameToOutput;
+ std::map<std::string, int64_t> _domainToOpsetVersion;
+ mir::Graph *_graph;
+};
+
class NodeConverter
{
public:
- // TODO Change input arguments for converters
- // Maybe create graph context
- virtual std::vector<mir::Operation::Output *>
- convert(const onnx::NodeProto &onnx_node, const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const = 0;
+ virtual void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const = 0;
virtual ~NodeConverter() = default;
};
static NodeConverterRegistry &getInstance()
{
- static NodeConverterRegistry me;
- return me;
+ static NodeConverterRegistry instance;
+ return instance;
}
void registerConverter(const std::string &op_type, std::unique_ptr<NodeConverter> &&converter)
namespace mir_onnx
{
-std::vector<mir::Operation::Output *>
-AddNodeConverter::convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const
+void AddNodeConverter::convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const
{
+ std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+ mir::Graph *graph = context->getGraph();
+
auto result = createOp<mir::ops::AddOp>(graph, inputs[0], inputs[1])->getOutput(0);
- return {result};
+
+ context->setNodeOutputs(onnx_node, {result});
}
} // namespace mir_onnx
class AddNodeConverter : public NodeConverter
{
public:
- std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const override;
+ void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
};
} // namespace mir_onnx
namespace mir_onnx
{
-std::vector<mir::Operation::Output *>
-AveragePoolNodeConverter::convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const
+void AveragePoolNodeConverter::convert(const onnx::NodeProto &onnx_node,
+ ConverterContext *context) const
{
+ std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+ mir::Graph *graph = context->getGraph();
// TODO Set some asserts
mir::ops::PoolOp::BorderType border_type = mir::ops::PoolOp::BorderType::EMPTY;
mir::ops::PoolOp::PoolingType pool_type = mir::ops::PoolOp::PoolingType::AVG;
auto result =
createOp<mir::ops::PoolOp>(graph, t_input, pool_type, cdata.kernel_shape, cdata.strides_shape,
- cdata.padding_before, cdata.padding_after, border_type);
- return {convertMIRToONNX(graph, result->getOutput(0))};
+ cdata.padding_before, cdata.padding_after, border_type)
+ ->getOutput(0);
+ result = convertMIRToONNX(graph, result);
+
+ context->setNodeOutputs(onnx_node, {result});
}
} // namespace mir_onnx
class AveragePoolNodeConverter : public NodeConverter
{
public:
- std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const override;
+ void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
};
} // namespace mir_onnx
namespace mir_onnx
{
-std::vector<mir::Operation::Output *>
-BatchNormalizationNodeConverter::convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const
+void BatchNormalizationNodeConverter::convert(const onnx::NodeProto &onnx_node,
+ ConverterContext *context) const
{
+ std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+ mir::Graph *graph = context->getGraph();
+
assert(inputs.size() == 5);
auto input = inputs[0];
auto scale = inputs[1];
auto result = createOp<mir::ops::AddOp>(graph, input, mean)->getOutput(0);
result = createOp<mir::ops::MulOp>(graph, result, scale)->getOutput(0);
result = createOp<mir::ops::AddOp>(graph, result, bias)->getOutput(0);
- return {convertMIRToONNX(graph, result)};
+ result = convertMIRToONNX(graph, result);
+
+ context->setNodeOutputs(onnx_node, {result});
}
} // namespace mir_onnx
class BatchNormalizationNodeConverter : public NodeConverter
{
public:
- std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const override;
+ void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
};
} // namespace mir_onnx
namespace mir_onnx
{
-std::vector<mir::Operation::Output *>
-ConcatNodeConverter::convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const
+void ConcatNodeConverter::convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const
{
+ std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+ mir::Graph *graph = context->getGraph();
+
auto attr = findAttribute(onnx_node, "axis");
if (!attr)
throw std::runtime_error("Attribute axis is required!");
int32_t axis = attr->i();
- auto result = createOp<mir::ops::ConcatOp>(graph, inputs, axis);
- return {result->getOutput(0)};
+
+ auto result = createOp<mir::ops::ConcatOp>(graph, inputs, axis)->getOutput(0);
+
+ context->setNodeOutputs(onnx_node, {result});
}
} // namespace mir_onnx
class ConcatNodeConverter : public NodeConverter
{
public:
- std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const override;
+ void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
};
} // namespace mir_onnx
namespace mir_onnx
{
-std::vector<mir::Operation::Output *>
-ConstantNodeConverter::convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const
+void ConstantNodeConverter::convert(const onnx::NodeProto &onnx_node,
+ ConverterContext *context) const
{
+ std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+ mir::Graph *graph = context->getGraph();
assert((onnx_node.attribute_size() == 1) &&
(onnx_node.attribute(0).type() == ::onnx::AttributeProto_AttributeType_TENSOR) &&
(onnx_node.attribute(0).tensors_size() == 0));
auto mir_tensor = createTensor(&onnx_tensor);
// TODO check right removing input_tensors
// input_tensors.insert(std::make_pair(name, mir_tensor));
- auto op = graph->create<mir::ops::ConstantOp>(name, mir_tensor);
- return {op->getOutput(0)};
+ auto result = graph->create<mir::ops::ConstantOp>(name, mir_tensor)->getOutput(0);
+
+ context->setNodeOutputs(onnx_node, {result});
}
} // namespace mir_onnx
class ConstantNodeConverter : public NodeConverter
{
public:
- std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const override;
+ void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
};
} // namespace mir_onnx
namespace mir_onnx
{
-std::vector<mir::Operation::Output *>
-ConvNodeConverter::convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const
+void ConvNodeConverter::convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const
{
+ std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+ mir::Graph *graph = context->getGraph();
assert(inputs.size() >= 2);
KernelStridesPadding cdata;
auto in_group_size = kernel_tensor.getShape().dim(2);
auto out_channels = kernel_tensor.getShape().dim(3);
- // 1 is the default number of groups in convolution
+ // 1 is the default number of groups.
int num_groups = getIntAttribute(onnx_node, "group", 1);
bool is_depthwise = (num_groups != 1) && (in_group_size == 1) && (out_channels == num_groups);
result = createOp<mir::ops::AddOp>(graph, result, inputs[2])->getOutput(0);
}
- return {convertMIRToONNX(graph, result)};
+ result = convertMIRToONNX(graph, result);
+
+ context->setNodeOutputs(onnx_node, {result});
}
} // namespace mir_onnx
class ConvNodeConverter : public NodeConverter
{
public:
- std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const override;
+ void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
};
} // namespace mir_onnx
namespace mir_onnx
{
-std::vector<mir::Operation::Output *>
-DropoutNodeConverter::convert(const onnx::NodeProto &,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *) const
+void DropoutNodeConverter::convert(const onnx::NodeProto &onnx_node,
+ ConverterContext *context) const
{
+ std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+
// This is a no-op in inference mode.
- return {inputs[0]};
+ auto result = inputs[0];
+
+ context->setNodeOutputs(onnx_node, {result});
}
} // namespace mir_onnx
class DropoutNodeConverter : public NodeConverter
{
public:
- std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const override;
+ void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
};
} // namespace mir_onnx
namespace mir_onnx
{
-std::vector<mir::Operation::Output *>
-GatherNodeConverter::convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const
+void GatherNodeConverter::convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const
{
- // 0 is the default axis number
+ std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+ mir::Graph *graph = context->getGraph();
+
+ // 0 is the default axis number.
int axis = getIntAttribute(onnx_node, "axis", 0);
- auto result = createOp<mir::ops::GatherOp>(graph, inputs[0], inputs[1], axis);
- return {result->getOutput(0)};
+
+ auto result = createOp<mir::ops::GatherOp>(graph, inputs[0], inputs[1], axis)->getOutput(0);
+
+ context->setNodeOutputs(onnx_node, {result});
}
} // namespace mir_onnx
class GatherNodeConverter : public NodeConverter
{
public:
- std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const override;
+ void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
};
} // namespace mir_onnx
namespace mir_onnx
{
-std::vector<mir::Operation::Output *>
-GemmNodeConverter::convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const
+void GemmNodeConverter::convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const
{
+ std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+ mir::Graph *graph = context->getGraph();
+
assert(inputs.size() == 3);
auto a = inputs[0];
auto b = inputs[1];
// Calculate the result: alpha * A * B + beta * C.
auto result = createOp<mir::ops::AddOp>(graph, ab, c)->getOutput(0);
- return {result};
+ context->setNodeOutputs(onnx_node, {result});
}
} // namespace mir_onnx
class GemmNodeConverter : public NodeConverter
{
public:
- std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const override;
+ void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
};
} // namespace mir_onnx
namespace mir_onnx
{
-std::vector<mir::Operation::Output *>
-GivenTensorFillNodeConverter::convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const
+void GivenTensorFillNodeConverter::convert(const onnx::NodeProto &onnx_node,
+ ConverterContext *context) const
{
+ std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+ mir::Graph *graph = context->getGraph();
auto values_att = findAttribute(onnx_node, "values");
auto shape_att = findAttribute(onnx_node, "shape");
assert(values_att && shape_att);
mir::TensorVariant tensor(mir::DTYPE::FLOAT32, shape, values_att->floats().data());
// TODO Check right removing input_tensors
// input_tensors.insert(std::make_pair(onnx_node.output(0), tensor));
- auto result = createOp<mir::ops::ConstantOp>(graph, tensor);
- return {result->getOutput(0)};
+ auto result = createOp<mir::ops::ConstantOp>(graph, tensor)->getOutput(0);
+
+ context->setNodeOutputs(onnx_node, {result});
}
} // namespace mir_onnx
class GivenTensorFillNodeConverter : public NodeConverter
{
public:
- std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const override;
+ void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
};
} // namespace mir_onnx
namespace mir_onnx
{
-std::vector<mir::Operation::Output *>
-GlobalAveragePoolNodeConverter::convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const
+void GlobalAveragePoolNodeConverter::convert(const onnx::NodeProto &onnx_node,
+ ConverterContext *context) const
{
+ std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+ mir::Graph *graph = context->getGraph();
mir::ops::PoolOp::BorderType border_type = mir::ops::PoolOp::BorderType::ZEROFILLED;
mir::ops::PoolOp::PoolingType pool_type = mir::ops::PoolOp::PoolingType::AVG;
auto result =
createOp<mir::ops::PoolOp>(graph, t_input, pool_type, cdata.kernel_shape, cdata.strides_shape,
- cdata.padding_before, cdata.padding_after, border_type);
- return {convertMIRToONNX(graph, result->getOutput(0))};
+ cdata.padding_before, cdata.padding_after, border_type)
+ ->getOutput(0);
+ result = convertMIRToONNX(graph, result);
+
+ context->setNodeOutputs(onnx_node, {result});
}
} // namespace mir_onnx
class GlobalAveragePoolNodeConverter : public NodeConverter
{
public:
- std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const override;
+ void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
};
} // namespace mir_onnx
namespace mir_onnx
{
-std::vector<mir::Operation::Output *>
-MaxNodeConverter::convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const
+void MaxNodeConverter::convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const
{
+ std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+ mir::Graph *graph = context->getGraph();
auto result = createOp<mir::ops::MaxOp>(graph, inputs[0], inputs[1])->getOutput(0);
- return {result};
+
+ context->setNodeOutputs(onnx_node, {result});
}
} // namespace mir_onnx
class MaxNodeConverter : public NodeConverter
{
public:
- std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const override;
+ void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
};
} // namespace mir_onnx
namespace mir_onnx
{
-std::vector<mir::Operation::Output *>
-MaxPoolNodeConverter::convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const
+void MaxPoolNodeConverter::convert(const onnx::NodeProto &onnx_node,
+ ConverterContext *context) const
{
+ std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+ mir::Graph *graph = context->getGraph();
// TODO Set some asserts
mir::ops::PoolOp::BorderType border_type;
mir::ops::PoolOp::PoolingType pool_type;
auto result =
createOp<mir::ops::PoolOp>(graph, t_input, pool_type, cdata.kernel_shape, cdata.strides_shape,
- cdata.padding_before, cdata.padding_after, border_type);
- return {convertMIRToONNX(graph, result->getOutput(0))};
+ cdata.padding_before, cdata.padding_after, border_type)
+ ->getOutput(0);
+ result = convertMIRToONNX(graph, result);
+
+ context->setNodeOutputs(onnx_node, {result});
}
} // namespace mir_onnx
class MaxPoolNodeConverter : public NodeConverter
{
public:
- std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const override;
+ void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
};
} // namespace mir_onnx
namespace mir_onnx
{
-std::vector<mir::Operation::Output *>
-MulNodeConverter::convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const
+void MulNodeConverter::convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const
{
+ std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+ mir::Graph *graph = context->getGraph();
auto result = createOp<mir::ops::MulOp>(graph, inputs[0], inputs[1])->getOutput(0);
- return {result};
+
+ context->setNodeOutputs(onnx_node, {result});
}
} // namespace mir_onnx
class MulNodeConverter : public NodeConverter
{
public:
- std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const override;
+ void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
};
} // namespace mir_onnx
namespace mir_onnx
{
-std::vector<mir::Operation::Output *>
-PadNodeConverter::convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const
+void PadNodeConverter::convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const
{
+ std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+ mir::Graph *graph = context->getGraph();
+
// 0.0f is the default value to be filled into padded cells.
float value = getFloatAttribute(onnx_node, "value", 0.0f);
auto pads_attr = findAttribute(onnx_node, "pads");
assert(pads_attr);
- // "constant" is the default mode
+ // "constant" is the default mode.
auto mode = getStringAttribute(onnx_node, "mode", "constant");
if (mode != "constant")
throw std::runtime_error("Not supported Pad mode attribue!");
vec[i] = pair;
}
auto result =
- createOp<mir::ops::PadOp>(graph, inputs[0], inputs[0]->getShape().rank(), vec, scalar);
- return {result->getOutput(0)};
+ createOp<mir::ops::PadOp>(graph, inputs[0], inputs[0]->getShape().rank(), vec, scalar)
+ ->getOutput(0);
+
+ context->setNodeOutputs(onnx_node, {result});
}
} // namespace mir_onnx
class PadNodeConverter : public NodeConverter
{
public:
- std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const override;
+ void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
};
} // namespace mir_onnx
namespace mir_onnx
{
-std::vector<mir::Operation::Output *>
-ReluNodeConverter::convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const
+void ReluNodeConverter::convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const
{
+ std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+ mir::Graph *graph = context->getGraph();
assert(inputs.size() == 1);
- auto result = createOp<mir::ops::ReluOp>(graph, inputs[0]);
- return {result->getOutput(0)};
+ auto result = createOp<mir::ops::ReluOp>(graph, inputs[0])->getOutput(0);
+
+ context->setNodeOutputs(onnx_node, {result});
}
} // namespace mir_onnx
class ReluNodeConverter : public NodeConverter
{
public:
- std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const override;
+ void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
};
} // namespace mir_onnx
namespace mir_onnx
{
-std::vector<mir::Operation::Output *>
-ReshapeNodeConverter::convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const
+void ReshapeNodeConverter::convert(const onnx::NodeProto &onnx_node,
+ ConverterContext *context) const
{
+ std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+ mir::Graph *graph = context->getGraph();
// The original shape
const auto &in_shape = inputs[0]->getShape();
i++;
}
auto out_shape = mir::Shape(shape_vector);
- auto result = createOp<mir::ops::ReshapeOp>(graph, inputs[0], out_shape);
- return {result->getOutput(0)};
+ auto result = createOp<mir::ops::ReshapeOp>(graph, inputs[0], out_shape)->getOutput(0);
+
+ context->setNodeOutputs(onnx_node, {result});
}
} // namespace mir_onnx
class ReshapeNodeConverter : public NodeConverter
{
public:
- std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const override;
+ void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
};
} // namespace mir_onnx
namespace mir_onnx
{
-std::vector<mir::Operation::Output *>
-ShapeNodeConverter::convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const
+void ShapeNodeConverter::convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const
{
+ std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+ mir::Graph *graph = context->getGraph();
const auto &input_shape = inputs[0]->getShape();
int size = input_shape.rank();
mir::Shape output_shape{size};
data[i] = input_shape.dim(i);
}
mir::TensorVariant tensor(mir::DTYPE::FLOAT32, output_shape, data.data());
- auto result = createOp<mir::ops::ConstantOp>(graph, tensor);
- return {result->getOutput(0)};
+ auto result = createOp<mir::ops::ConstantOp>(graph, tensor)->getOutput(0);
+
+ context->setNodeOutputs(onnx_node, {result});
}
} // namespace mir_onnx
class ShapeNodeConverter : public NodeConverter
{
public:
- std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const override;
+ void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
};
} // namespace mir_onnx
namespace mir_onnx
{
-std::vector<mir::Operation::Output *>
-SigmoidNodeConverter::convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const
+void SigmoidNodeConverter::convert(const onnx::NodeProto &onnx_node,
+ ConverterContext *context) const
{
+ std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+ mir::Graph *graph = context->getGraph();
assert(inputs.size() == 1);
- auto result = createOp<mir::ops::SigmoidOp>(graph, inputs[0]);
- return {result->getOutput(0)};
+ auto result = createOp<mir::ops::SigmoidOp>(graph, inputs[0])->getOutput(0);
+
+ context->setNodeOutputs(onnx_node, {result});
}
} // namespace mir_onnx
class SigmoidNodeConverter : public NodeConverter
{
public:
- std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const override;
+ void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
};
} // namespace mir_onnx
namespace mir_onnx
{
-std::vector<mir::Operation::Output *>
-SoftmaxNodeConverter::convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const
+void SoftmaxNodeConverter::convert(const onnx::NodeProto &onnx_node,
+ ConverterContext *context) const
{
- // 1 is the default axis number
+ std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+ mir::Graph *graph = context->getGraph();
+
+ // 1 is the default axis number.
int axis = getIntAttribute(onnx_node, "axis", 1);
- auto result = createOp<mir::ops::SoftmaxOp>(graph, inputs[0], axis);
- return {result->getOutput(0)};
+
+ auto result = createOp<mir::ops::SoftmaxOp>(graph, inputs[0], axis)->getOutput(0);
+
+ context->setNodeOutputs(onnx_node, {result});
}
} // namespace mir_onnx
class SoftmaxNodeConverter : public NodeConverter
{
public:
- std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const override;
+ void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
};
} // namespace mir_onnx
namespace mir_onnx
{
-std::vector<mir::Operation::Output *>
-SumNodeConverter::convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const
+void SumNodeConverter::convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const
{
+ std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+ mir::Graph *graph = context->getGraph();
assert(inputs.size() >= 1);
auto result = inputs[0];
result = createOp<mir::ops::AddOp>(graph, result, inputs[i])->getOutput(0);
}
- return {result};
+ context->setNodeOutputs(onnx_node, {result});
}
} // namespace mir_onnx
class SumNodeConverter : public NodeConverter
{
public:
- std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const override;
+ void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
};
} // namespace mir_onnx
namespace mir_onnx
{
-std::vector<mir::Operation::Output *>
-UnsqueezeNodeConverter::convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const
+void UnsqueezeNodeConverter::convert(const onnx::NodeProto &onnx_node,
+ ConverterContext *context) const
{
+ std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+ mir::Graph *graph = context->getGraph();
auto *axes = findAttribute(onnx_node, "axes");
assert(axes && axes->ints_size());
const mir::Shape &input_shape = inputs[0]->getShape();
j++;
}
}
- auto result = createOp<mir::ops::ReshapeOp>(graph, inputs[0], out_shape);
- return {result->getOutput(0)};
+ auto result = createOp<mir::ops::ReshapeOp>(graph, inputs[0], out_shape)->getOutput(0);
+
+ context->setNodeOutputs(onnx_node, {result});
}
} // namespace mir_onnx
class UnsqueezeNodeConverter : public NodeConverter
{
public:
- std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const override;
+ void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
};
} // namespace mir_onnx
namespace mir_onnx
{
-std::vector<mir::Operation::Output *>
-UpsampleNodeConverter::convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const
+void UpsampleNodeConverter::convert(const onnx::NodeProto &onnx_node,
+ ConverterContext *context) const
{
- // "nearest" is the default mode
+ std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+ mir::Graph *graph = context->getGraph();
+
+ // "nearest" is the default mode.
std::string mode = getStringAttribute(onnx_node, "mode", "nearest");
assert(mode == "nearest" && "Unsupported upscale mode!");
assert(scales_tensor.getShape().rank() == 1 && "Scales are a 1d tensor");
for (int i = 0; i < scales_tensor.getShape().numElements(); i++)
scales_vector[onnx2mir[i]] = scales_tensor.atOffset(i);
- return {convertMIRToONNX(
- graph,
+
+ auto result =
createOp<mir::ops::ResizeOp>(graph, convertONNXToMIR(graph, inputs[0]),
mir::ops::ResizeOp::ResizeMethod::nearestNeighbor, scales_vector)
- ->getOutput(0))};
+ ->getOutput(0);
+ result = convertMIRToONNX(graph, result);
+
+ context->setNodeOutputs(onnx_node, {result});
}
} // namespace mir_onnx
class UpsampleNodeConverter : public NodeConverter
{
public:
- std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
- const std::vector<mir::Operation::Output *> &inputs,
- mir::Graph *graph) const override;
+ void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
};
} // namespace mir_onnx