[mir_onnx] Introduce ConverterContext for using in NodeConverters (#6457)
authorПавел Ильютченко/AI Tools Lab /SRR/Engineer/삼성전자 <p.iliutchenk@samsung.com>
Fri, 16 Aug 2019 09:07:57 +0000 (12:07 +0300)
committerAlexander Efimov/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Fri, 16 Aug 2019 09:07:57 +0000 (12:07 +0300)
* Implemented ConverterContext
* Fix NodeConverter interface to accept ConverterContext
* Fix all NodeConverters according to new interface
* Fix ONNX IR version checking.

Signed-off-by: Pavel Iliutchenko <p.iliutchenk@samsung.com>
49 files changed:
compiler/mir-onnx-importer/ONNXImporterImpl.cpp
compiler/mir-onnx-importer/ONNXImporterImpl.h
compiler/mir-onnx-importer/ONNXNodeConverterRegistry.h
compiler/mir-onnx-importer/Op/Add.cpp
compiler/mir-onnx-importer/Op/Add.h
compiler/mir-onnx-importer/Op/AveragePool.cpp
compiler/mir-onnx-importer/Op/AveragePool.h
compiler/mir-onnx-importer/Op/BatchNormalization.cpp
compiler/mir-onnx-importer/Op/BatchNormalization.h
compiler/mir-onnx-importer/Op/Concat.cpp
compiler/mir-onnx-importer/Op/Concat.h
compiler/mir-onnx-importer/Op/Constant.cpp
compiler/mir-onnx-importer/Op/Constant.h
compiler/mir-onnx-importer/Op/Conv.cpp
compiler/mir-onnx-importer/Op/Conv.h
compiler/mir-onnx-importer/Op/Dropout.cpp
compiler/mir-onnx-importer/Op/Dropout.h
compiler/mir-onnx-importer/Op/Gather.cpp
compiler/mir-onnx-importer/Op/Gather.h
compiler/mir-onnx-importer/Op/Gemm.cpp
compiler/mir-onnx-importer/Op/Gemm.h
compiler/mir-onnx-importer/Op/GivenTensorFill.cpp
compiler/mir-onnx-importer/Op/GivenTensorFill.h
compiler/mir-onnx-importer/Op/GlobalAveragePool.cpp
compiler/mir-onnx-importer/Op/GlobalAveragePool.h
compiler/mir-onnx-importer/Op/Max.cpp
compiler/mir-onnx-importer/Op/Max.h
compiler/mir-onnx-importer/Op/MaxPool.cpp
compiler/mir-onnx-importer/Op/MaxPool.h
compiler/mir-onnx-importer/Op/Mul.cpp
compiler/mir-onnx-importer/Op/Mul.h
compiler/mir-onnx-importer/Op/Pad.cpp
compiler/mir-onnx-importer/Op/Pad.h
compiler/mir-onnx-importer/Op/Relu.cpp
compiler/mir-onnx-importer/Op/Relu.h
compiler/mir-onnx-importer/Op/Reshape.cpp
compiler/mir-onnx-importer/Op/Reshape.h
compiler/mir-onnx-importer/Op/Shape.cpp
compiler/mir-onnx-importer/Op/Shape.h
compiler/mir-onnx-importer/Op/Sigmoid.cpp
compiler/mir-onnx-importer/Op/Sigmoid.h
compiler/mir-onnx-importer/Op/Softmax.cpp
compiler/mir-onnx-importer/Op/Softmax.h
compiler/mir-onnx-importer/Op/Sum.cpp
compiler/mir-onnx-importer/Op/Sum.h
compiler/mir-onnx-importer/Op/Unsqueeze.cpp
compiler/mir-onnx-importer/Op/Unsqueeze.h
compiler/mir-onnx-importer/Op/Upsample.cpp
compiler/mir-onnx-importer/Op/Upsample.h

index f36ba23..d1efd07 100644 (file)
@@ -39,10 +39,11 @@ namespace mir_onnx
 
 ONNXImporterImpl::ONNXImporterImpl(std::string filename) : _modelFilename(std::move(filename))
 {
-  _graph = stdex::make_unique<mir::Graph>();
   registerSupportedOps();
 }
 
+ONNXImporterImpl::~ONNXImporterImpl() = default;
+
 static void loadModelFile(const std::string &filename, onnx::ModelProto *model)
 {
   GOOGLE_PROTOBUF_VERIFY_VERSION;
@@ -110,14 +111,14 @@ void ONNXImporterImpl::createGraphInputs()
     assert(tensor.has_name());
     const auto mir_tensor = createTensor(&tensor);
     auto *op = _graph->create<mir::ops::ConstantOp>(tensor.name(), mir_tensor);
-    _tensorNameToOutput.emplace(tensor.name(), op->getOutput(0));
+    _context->setOutput(tensor.name(), op->getOutput(0));
   }
 
-  for (auto &input : graph.input())
+  for (const auto &input : graph.input())
   {
     assert(input.has_name());
 
-    if (_tensorNameToOutput.find(input.name()) == _tensorNameToOutput.end())
+    if (_context->getOutput(input.name()) == nullptr)
     {
       const auto &onnx_input_shape = input.type().tensor_type().shape();
       mir::Shape shape(onnx_input_shape.dim_size());
@@ -128,59 +129,51 @@ void ONNXImporterImpl::createGraphInputs()
       }
 
       auto *op = _graph->create<mir::ops::InputOp>(input.name(), shape);
-      _tensorNameToOutput.emplace(input.name(), op->getOutput(0));
+      _context->setOutput(input.name(), op->getOutput(0));
     }
   }
 }
 
 std::unique_ptr<mir::Graph> ONNXImporterImpl::createIR()
 {
+  _graph = stdex::make_unique<mir::Graph>();
+  _context = stdex::make_unique<ConverterContext>(_graph.get());
+
+  if (_model->ir_version() > onnx::IR_VERSION)
+  {
+    throw std::runtime_error("IR version " + std::to_string(_model->ir_version()) +
+                             " is not supported yet.");
+  }
+
+  // Set Opset Version for each domain
+  for (const auto &op_set : _model->opset_import())
+  {
+    _context->setOpsetVersion(op_set.domain(), op_set.version());
+  }
+
   createGraphInputs();
 
   // Forming partially ordered computation graph
-  for (auto &onnx_node : _model->graph().node())
+  for (const auto &onnx_node : _model->graph().node())
   {
     assert(onnx_node.has_op_type());
     auto &op_type = onnx_node.op_type();
-    auto &inputs = onnx_node.input();
-
-    std::vector<mir::Operation::Output *> mir_inputs;
-    std::vector<mir::Operation::Output *> mir_outputs;
-
-    for (const auto &input_name : inputs)
-    {
-      if (!input_name.empty())
-      {
-        const auto mir_op_iter = _tensorNameToOutput.find(input_name);
-        assert(mir_op_iter != _tensorNameToOutput.end());
-        mir_inputs.emplace_back(mir_op_iter->second);
-      }
-    }
     // Get converter
-    const auto *node_converter = NodeConverterRegistry::getInstance().lookup(op_type);
+    auto *node_converter = NodeConverterRegistry::getInstance().lookup(op_type);
     assert(node_converter);
-    mir_outputs = node_converter->convert(onnx_node, mir_inputs, _graph.get());
-    assert(!mir_outputs.empty());
-    // Set outputs' names
-    for (int i = 0; i < mir_outputs.size(); i++)
-    {
-      mir_outputs[i]->getNode()->setName(onnx_node.output(i));
-      auto result = _tensorNameToOutput.emplace(onnx_node.output(i), mir_outputs[i]);
-      if (!result.second)
-        throw std::runtime_error("Name duplication: " + mir_outputs[i]->getNode()->getName());
-    }
+    node_converter->convert(onnx_node, _context.get());
   }
   // Set graph outputs
   const auto &outputs = _model->graph().output();
   for (const auto &output : outputs)
   {
     assert(output.has_name());
-    auto output_iter = _tensorNameToOutput.find(output.name());
-    if (output_iter == _tensorNameToOutput.end())
+    auto mir_output = _context->getOutput(output.name());
+    if (mir_output == nullptr)
       throw std::runtime_error("Bad output name!");
 
-    _graph->create<mir::ops::OutputOp>(output.name(), output_iter->second);
-    output_iter->second->getNode()->setName("");
+    _graph->create<mir::ops::OutputOp>(output.name(), mir_output);
+    mir_output->getNode()->setName("");
   }
 
   return std::move(_graph);
index ba07405..3914961 100644 (file)
 
 namespace mir_onnx
 {
+class ConverterContext;
 
-class ONNXImporterImpl
+class ONNXImporterImpl final
 {
 public:
   explicit ONNXImporterImpl(std::string filename);
-
+  ~ONNXImporterImpl();
   /// @brief Load the model and convert it into a MIR Graph.
   std::unique_ptr<mir::Graph> importModel();
 
@@ -41,9 +42,9 @@ private:
   void createGraphInputs();
   void collectUnsupportedOps();
   // Maps ONNX tensor names to corresponding MIR operation outputs.
-  std::map<std::string, mir::Operation::Output *> _tensorNameToOutput;
   std::string _modelFilename;
   std::unique_ptr<onnx::ModelProto> _model;
+  std::unique_ptr<ConverterContext> _context;
   std::unique_ptr<mir::Graph> _graph;
 };
 } // namespace mir_onnx
index eb8d472..e0f999c 100644 (file)
 namespace mir_onnx
 {
 
+class ConverterContext
+{
+public:
+  explicit ConverterContext(mir::Graph *graph) : _graph(graph) {}
+  ~ConverterContext() = default;
+
+  void setOpsetVersion(const std::string &domain, const int64_t opset_version)
+  {
+    _domainToOpsetVersion.emplace(domain, opset_version);
+  }
+
+  int64_t getOpsetVersion(const std::string &domain) const
+  {
+    auto iter = _domainToOpsetVersion.find(domain);
+    if (iter == _domainToOpsetVersion.end())
+      throw std::runtime_error("Didn't have domain " + domain + "!");
+    return iter->second;
+  }
+
+  void setOutput(const std::string &name, mir::Operation::Output *output)
+  {
+    auto result = _tensorNameToOutput.emplace(name, output);
+    if (!result.second)
+      throw std::runtime_error("Name duplication: " + output->getNode()->getName());
+  }
+
+  mir::Operation::Output *getOutput(const std::string &name) const
+  {
+    auto iter = _tensorNameToOutput.find(name);
+    if (iter == _tensorNameToOutput.end())
+      return nullptr;
+    else
+      return iter->second;
+  }
+
+  std::vector<mir::Operation::Output *> getNodeInputs(const onnx::NodeProto &onnx_node) const
+  {
+    const auto &input_names = onnx_node.input();
+    std::vector<mir::Operation::Output *> outputs;
+
+    for (const auto &input_name : input_names)
+    {
+      if (!input_name.empty())
+      {
+        auto *mir_output = getOutput(input_name);
+        assert(mir_output != nullptr);
+        outputs.emplace_back(mir_output);
+      }
+    }
+    return outputs;
+  }
+
+  void setNodeOutputs(const onnx::NodeProto &onnx_node,
+                      const std::vector<mir::Operation::Output *> &outputs)
+  {
+    assert(!outputs.empty());
+    for (std::size_t i = 0; i < outputs.size(); ++i)
+    {
+      outputs[i]->getNode()->setName(onnx_node.output(i));
+      setOutput(onnx_node.output(i), outputs[i]);
+    }
+  }
+
+  mir::Graph *getGraph() const { return _graph; }
+
+private:
+  std::map<std::string, mir::Operation::Output *> _tensorNameToOutput;
+  std::map<std::string, int64_t> _domainToOpsetVersion;
+  mir::Graph *_graph;
+};
+
 class NodeConverter
 {
 public:
-  // TODO Change input arguments for converters
-  // Maybe create graph context
-  virtual std::vector<mir::Operation::Output *>
-  convert(const onnx::NodeProto &onnx_node, const std::vector<mir::Operation::Output *> &inputs,
-          mir::Graph *graph) const = 0;
+  virtual void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const = 0;
   virtual ~NodeConverter() = default;
 };
 
@@ -57,8 +124,8 @@ public:
 
   static NodeConverterRegistry &getInstance()
   {
-    static NodeConverterRegistry me;
-    return me;
+    static NodeConverterRegistry instance;
+    return instance;
   }
 
   void registerConverter(const std::string &op_type, std::unique_ptr<NodeConverter> &&converter)
index a500ee6..eef2d9b 100644 (file)
 namespace mir_onnx
 {
 
-std::vector<mir::Operation::Output *>
-AddNodeConverter::convert(const onnx::NodeProto &onnx_node,
-                          const std::vector<mir::Operation::Output *> &inputs,
-                          mir::Graph *graph) const
+void AddNodeConverter::convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const
 {
+  std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+  mir::Graph *graph = context->getGraph();
+
   auto result = createOp<mir::ops::AddOp>(graph, inputs[0], inputs[1])->getOutput(0);
-  return {result};
+
+  context->setNodeOutputs(onnx_node, {result});
 }
 
 } // namespace mir_onnx
index 67fb125..49217e2 100644 (file)
@@ -22,9 +22,7 @@ namespace mir_onnx
 class AddNodeConverter : public NodeConverter
 {
 public:
-  std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
-                                                const std::vector<mir::Operation::Output *> &inputs,
-                                                mir::Graph *graph) const override;
+  void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
 };
 
 } // namespace mir_onnx
index 78df004..6bd20fd 100644 (file)
 namespace mir_onnx
 {
 
-std::vector<mir::Operation::Output *>
-AveragePoolNodeConverter::convert(const onnx::NodeProto &onnx_node,
-                                  const std::vector<mir::Operation::Output *> &inputs,
-                                  mir::Graph *graph) const
+void AveragePoolNodeConverter::convert(const onnx::NodeProto &onnx_node,
+                                       ConverterContext *context) const
 {
+  std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+  mir::Graph *graph = context->getGraph();
   // TODO Set some asserts
   mir::ops::PoolOp::BorderType border_type = mir::ops::PoolOp::BorderType::EMPTY;
   mir::ops::PoolOp::PoolingType pool_type = mir::ops::PoolOp::PoolingType::AVG;
@@ -40,8 +40,11 @@ AveragePoolNodeConverter::convert(const onnx::NodeProto &onnx_node,
 
   auto result =
       createOp<mir::ops::PoolOp>(graph, t_input, pool_type, cdata.kernel_shape, cdata.strides_shape,
-                                 cdata.padding_before, cdata.padding_after, border_type);
-  return {convertMIRToONNX(graph, result->getOutput(0))};
+                                 cdata.padding_before, cdata.padding_after, border_type)
+          ->getOutput(0);
+  result = convertMIRToONNX(graph, result);
+
+  context->setNodeOutputs(onnx_node, {result});
 }
 
 } // namespace mir_onnx
index f282281..d97677a 100644 (file)
@@ -22,9 +22,7 @@ namespace mir_onnx
 class AveragePoolNodeConverter : public NodeConverter
 {
 public:
-  std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
-                                                const std::vector<mir::Operation::Output *> &inputs,
-                                                mir::Graph *graph) const override;
+  void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
 };
 
 } // namespace mir_onnx
index 97aa420..6871c6c 100644 (file)
 namespace mir_onnx
 {
 
-std::vector<mir::Operation::Output *>
-BatchNormalizationNodeConverter::convert(const onnx::NodeProto &onnx_node,
-                                         const std::vector<mir::Operation::Output *> &inputs,
-                                         mir::Graph *graph) const
+void BatchNormalizationNodeConverter::convert(const onnx::NodeProto &onnx_node,
+                                              ConverterContext *context) const
 {
+  std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+  mir::Graph *graph = context->getGraph();
+
   assert(inputs.size() == 5);
   auto input = inputs[0];
   auto scale = inputs[1];
@@ -74,7 +75,9 @@ BatchNormalizationNodeConverter::convert(const onnx::NodeProto &onnx_node,
   auto result = createOp<mir::ops::AddOp>(graph, input, mean)->getOutput(0);
   result = createOp<mir::ops::MulOp>(graph, result, scale)->getOutput(0);
   result = createOp<mir::ops::AddOp>(graph, result, bias)->getOutput(0);
-  return {convertMIRToONNX(graph, result)};
+  result = convertMIRToONNX(graph, result);
+
+  context->setNodeOutputs(onnx_node, {result});
 }
 
 } // namespace mir_onnx
index 79eb8e3..4d6e3fd 100644 (file)
@@ -22,9 +22,7 @@ namespace mir_onnx
 class BatchNormalizationNodeConverter : public NodeConverter
 {
 public:
-  std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
-                                                const std::vector<mir::Operation::Output *> &inputs,
-                                                mir::Graph *graph) const override;
+  void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
 };
 
 } // namespace mir_onnx
index 08ac1ca..6218542 100644 (file)
 namespace mir_onnx
 {
 
-std::vector<mir::Operation::Output *>
-ConcatNodeConverter::convert(const onnx::NodeProto &onnx_node,
-                             const std::vector<mir::Operation::Output *> &inputs,
-                             mir::Graph *graph) const
+void ConcatNodeConverter::convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const
 {
+  std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+  mir::Graph *graph = context->getGraph();
+
   auto attr = findAttribute(onnx_node, "axis");
   if (!attr)
     throw std::runtime_error("Attribute axis is required!");
   int32_t axis = attr->i();
-  auto result = createOp<mir::ops::ConcatOp>(graph, inputs, axis);
-  return {result->getOutput(0)};
+
+  auto result = createOp<mir::ops::ConcatOp>(graph, inputs, axis)->getOutput(0);
+
+  context->setNodeOutputs(onnx_node, {result});
 }
 
 } // namespace mir_onnx
index 06af0b2..5ca8822 100644 (file)
@@ -22,9 +22,7 @@ namespace mir_onnx
 class ConcatNodeConverter : public NodeConverter
 {
 public:
-  std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
-                                                const std::vector<mir::Operation::Output *> &inputs,
-                                                mir::Graph *graph) const override;
+  void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
 };
 
 } // namespace mir_onnx
index 0a1484b..dbf8e95 100644 (file)
 namespace mir_onnx
 {
 
-std::vector<mir::Operation::Output *>
-ConstantNodeConverter::convert(const onnx::NodeProto &onnx_node,
-                               const std::vector<mir::Operation::Output *> &inputs,
-                               mir::Graph *graph) const
+void ConstantNodeConverter::convert(const onnx::NodeProto &onnx_node,
+                                    ConverterContext *context) const
 {
+  std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+  mir::Graph *graph = context->getGraph();
   assert((onnx_node.attribute_size() == 1) &&
          (onnx_node.attribute(0).type() == ::onnx::AttributeProto_AttributeType_TENSOR) &&
          (onnx_node.attribute(0).tensors_size() == 0));
@@ -38,8 +38,9 @@ ConstantNodeConverter::convert(const onnx::NodeProto &onnx_node,
   auto mir_tensor = createTensor(&onnx_tensor);
   // TODO check right removing input_tensors
   // input_tensors.insert(std::make_pair(name, mir_tensor));
-  auto op = graph->create<mir::ops::ConstantOp>(name, mir_tensor);
-  return {op->getOutput(0)};
+  auto result = graph->create<mir::ops::ConstantOp>(name, mir_tensor)->getOutput(0);
+
+  context->setNodeOutputs(onnx_node, {result});
 }
 
 } // namespace mir_onnx
index 92aa6e6..6aff4f2 100644 (file)
@@ -22,9 +22,7 @@ namespace mir_onnx
 class ConstantNodeConverter : public NodeConverter
 {
 public:
-  std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
-                                                const std::vector<mir::Operation::Output *> &inputs,
-                                                mir::Graph *graph) const override;
+  void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
 };
 
 } // namespace mir_onnx
index a41d445..cf21031 100644 (file)
 namespace mir_onnx
 {
 
-std::vector<mir::Operation::Output *>
-ConvNodeConverter::convert(const onnx::NodeProto &onnx_node,
-                           const std::vector<mir::Operation::Output *> &inputs,
-                           mir::Graph *graph) const
+void ConvNodeConverter::convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const
 {
+  std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+  mir::Graph *graph = context->getGraph();
   assert(inputs.size() >= 2);
 
   KernelStridesPadding cdata;
@@ -46,7 +45,7 @@ ConvNodeConverter::convert(const onnx::NodeProto &onnx_node,
   auto in_group_size = kernel_tensor.getShape().dim(2);
   auto out_channels = kernel_tensor.getShape().dim(3);
 
-  // 1 is the default number of groups in convolution
+  // 1 is the default number of groups.
   int num_groups = getIntAttribute(onnx_node, "group", 1);
   bool is_depthwise = (num_groups != 1) && (in_group_size == 1) && (out_channels == num_groups);
 
@@ -79,7 +78,9 @@ ConvNodeConverter::convert(const onnx::NodeProto &onnx_node,
     result = createOp<mir::ops::AddOp>(graph, result, inputs[2])->getOutput(0);
   }
 
-  return {convertMIRToONNX(graph, result)};
+  result = convertMIRToONNX(graph, result);
+
+  context->setNodeOutputs(onnx_node, {result});
 }
 
 } // namespace mir_onnx
index 5849e65..e86655c 100644 (file)
@@ -22,9 +22,7 @@ namespace mir_onnx
 class ConvNodeConverter : public NodeConverter
 {
 public:
-  std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
-                                                const std::vector<mir::Operation::Output *> &inputs,
-                                                mir::Graph *graph) const override;
+  void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
 };
 
 } // namespace mir_onnx
index 9395a99..9726b38 100644 (file)
 namespace mir_onnx
 {
 
-std::vector<mir::Operation::Output *>
-DropoutNodeConverter::convert(const onnx::NodeProto &,
-                              const std::vector<mir::Operation::Output *> &inputs,
-                              mir::Graph *) const
+void DropoutNodeConverter::convert(const onnx::NodeProto &onnx_node,
+                                   ConverterContext *context) const
 {
+  std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+
   // This is a no-op in inference mode.
-  return {inputs[0]};
+  auto result = inputs[0];
+
+  context->setNodeOutputs(onnx_node, {result});
 }
 
 } // namespace mir_onnx
index f57fa49..ad43bf6 100644 (file)
@@ -22,9 +22,7 @@ namespace mir_onnx
 class DropoutNodeConverter : public NodeConverter
 {
 public:
-  std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
-                                                const std::vector<mir::Operation::Output *> &inputs,
-                                                mir::Graph *graph) const override;
+  void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
 };
 
 } // namespace mir_onnx
index b6d04a0..425e478 100644 (file)
 namespace mir_onnx
 {
 
-std::vector<mir::Operation::Output *>
-GatherNodeConverter::convert(const onnx::NodeProto &onnx_node,
-                             const std::vector<mir::Operation::Output *> &inputs,
-                             mir::Graph *graph) const
+void GatherNodeConverter::convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const
 {
-  // 0 is the default axis number
+  std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+  mir::Graph *graph = context->getGraph();
+
+  // 0 is the default axis number.
   int axis = getIntAttribute(onnx_node, "axis", 0);
-  auto result = createOp<mir::ops::GatherOp>(graph, inputs[0], inputs[1], axis);
-  return {result->getOutput(0)};
+
+  auto result = createOp<mir::ops::GatherOp>(graph, inputs[0], inputs[1], axis)->getOutput(0);
+
+  context->setNodeOutputs(onnx_node, {result});
 }
 
 } // namespace mir_onnx
index 64c770a..e81c2af 100644 (file)
@@ -22,9 +22,7 @@ namespace mir_onnx
 class GatherNodeConverter : public NodeConverter
 {
 public:
-  std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
-                                                const std::vector<mir::Operation::Output *> &inputs,
-                                                mir::Graph *graph) const override;
+  void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
 };
 
 } // namespace mir_onnx
index a2d48df..7fd78cf 100644 (file)
 namespace mir_onnx
 {
 
-std::vector<mir::Operation::Output *>
-GemmNodeConverter::convert(const onnx::NodeProto &onnx_node,
-                           const std::vector<mir::Operation::Output *> &inputs,
-                           mir::Graph *graph) const
+void GemmNodeConverter::convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const
 {
+  std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+  mir::Graph *graph = context->getGraph();
+
   assert(inputs.size() == 3);
   auto a = inputs[0];
   auto b = inputs[1];
@@ -76,7 +76,7 @@ GemmNodeConverter::convert(const onnx::NodeProto &onnx_node,
   // Calculate the result: alpha * A * B + beta * C.
   auto result = createOp<mir::ops::AddOp>(graph, ab, c)->getOutput(0);
 
-  return {result};
+  context->setNodeOutputs(onnx_node, {result});
 }
 
 } // namespace mir_onnx
index 461ebfd..c1388f2 100644 (file)
@@ -22,9 +22,7 @@ namespace mir_onnx
 class GemmNodeConverter : public NodeConverter
 {
 public:
-  std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
-                                                const std::vector<mir::Operation::Output *> &inputs,
-                                                mir::Graph *graph) const override;
+  void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
 };
 
 } // namespace mir_onnx
index c3febc1..7608b9a 100644 (file)
 namespace mir_onnx
 {
 
-std::vector<mir::Operation::Output *>
-GivenTensorFillNodeConverter::convert(const onnx::NodeProto &onnx_node,
-                                      const std::vector<mir::Operation::Output *> &inputs,
-                                      mir::Graph *graph) const
+void GivenTensorFillNodeConverter::convert(const onnx::NodeProto &onnx_node,
+                                           ConverterContext *context) const
 {
+  std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+  mir::Graph *graph = context->getGraph();
   auto values_att = findAttribute(onnx_node, "values");
   auto shape_att = findAttribute(onnx_node, "shape");
   assert(values_att && shape_att);
@@ -40,8 +40,9 @@ GivenTensorFillNodeConverter::convert(const onnx::NodeProto &onnx_node,
   mir::TensorVariant tensor(mir::DTYPE::FLOAT32, shape, values_att->floats().data());
   // TODO Check right removing input_tensors
   // input_tensors.insert(std::make_pair(onnx_node.output(0), tensor));
-  auto result = createOp<mir::ops::ConstantOp>(graph, tensor);
-  return {result->getOutput(0)};
+  auto result = createOp<mir::ops::ConstantOp>(graph, tensor)->getOutput(0);
+
+  context->setNodeOutputs(onnx_node, {result});
 }
 
 } // namespace mir_onnx
index 806a4b6..0aafa87 100644 (file)
@@ -22,9 +22,7 @@ namespace mir_onnx
 class GivenTensorFillNodeConverter : public NodeConverter
 {
 public:
-  std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
-                                                const std::vector<mir::Operation::Output *> &inputs,
-                                                mir::Graph *graph) const override;
+  void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
 };
 
 } // namespace mir_onnx
index d4f9736..ff5e1f8 100644 (file)
 namespace mir_onnx
 {
 
-std::vector<mir::Operation::Output *>
-GlobalAveragePoolNodeConverter::convert(const onnx::NodeProto &onnx_node,
-                                        const std::vector<mir::Operation::Output *> &inputs,
-                                        mir::Graph *graph) const
+void GlobalAveragePoolNodeConverter::convert(const onnx::NodeProto &onnx_node,
+                                             ConverterContext *context) const
 {
+  std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+  mir::Graph *graph = context->getGraph();
   mir::ops::PoolOp::BorderType border_type = mir::ops::PoolOp::BorderType::ZEROFILLED;
   mir::ops::PoolOp::PoolingType pool_type = mir::ops::PoolOp::PoolingType::AVG;
 
@@ -42,8 +42,11 @@ GlobalAveragePoolNodeConverter::convert(const onnx::NodeProto &onnx_node,
 
   auto result =
       createOp<mir::ops::PoolOp>(graph, t_input, pool_type, cdata.kernel_shape, cdata.strides_shape,
-                                 cdata.padding_before, cdata.padding_after, border_type);
-  return {convertMIRToONNX(graph, result->getOutput(0))};
+                                 cdata.padding_before, cdata.padding_after, border_type)
+          ->getOutput(0);
+  result = convertMIRToONNX(graph, result);
+
+  context->setNodeOutputs(onnx_node, {result});
 }
 
 } // namespace mir_onnx
index 48cfa8e..6cde973 100644 (file)
@@ -22,9 +22,7 @@ namespace mir_onnx
 class GlobalAveragePoolNodeConverter : public NodeConverter
 {
 public:
-  std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
-                                                const std::vector<mir::Operation::Output *> &inputs,
-                                                mir::Graph *graph) const override;
+  void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
 };
 
 } // namespace mir_onnx
index 3a6ba69..158820c 100644 (file)
 namespace mir_onnx
 {
 
-std::vector<mir::Operation::Output *>
-MaxNodeConverter::convert(const onnx::NodeProto &onnx_node,
-                          const std::vector<mir::Operation::Output *> &inputs,
-                          mir::Graph *graph) const
+void MaxNodeConverter::convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const
 {
+  std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+  mir::Graph *graph = context->getGraph();
   auto result = createOp<mir::ops::MaxOp>(graph, inputs[0], inputs[1])->getOutput(0);
-  return {result};
+
+  context->setNodeOutputs(onnx_node, {result});
 }
 
 } // namespace mir_onnx
index 80797b6..e8e958b 100644 (file)
@@ -22,9 +22,7 @@ namespace mir_onnx
 class MaxNodeConverter : public NodeConverter
 {
 public:
-  std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
-                                                const std::vector<mir::Operation::Output *> &inputs,
-                                                mir::Graph *graph) const override;
+  void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
 };
 
 } // namespace mir_onnx
index 4d415e8..c2b5509 100644 (file)
 namespace mir_onnx
 {
 
-std::vector<mir::Operation::Output *>
-MaxPoolNodeConverter::convert(const onnx::NodeProto &onnx_node,
-                              const std::vector<mir::Operation::Output *> &inputs,
-                              mir::Graph *graph) const
+void MaxPoolNodeConverter::convert(const onnx::NodeProto &onnx_node,
+                                   ConverterContext *context) const
 {
+  std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+  mir::Graph *graph = context->getGraph();
   // TODO Set some asserts
   mir::ops::PoolOp::BorderType border_type;
   mir::ops::PoolOp::PoolingType pool_type;
@@ -42,8 +42,11 @@ MaxPoolNodeConverter::convert(const onnx::NodeProto &onnx_node,
 
   auto result =
       createOp<mir::ops::PoolOp>(graph, t_input, pool_type, cdata.kernel_shape, cdata.strides_shape,
-                                 cdata.padding_before, cdata.padding_after, border_type);
-  return {convertMIRToONNX(graph, result->getOutput(0))};
+                                 cdata.padding_before, cdata.padding_after, border_type)
+          ->getOutput(0);
+  result = convertMIRToONNX(graph, result);
+
+  context->setNodeOutputs(onnx_node, {result});
 }
 
 } // namespace mir_onnx
index cf7058b..daf45f7 100644 (file)
@@ -22,9 +22,7 @@ namespace mir_onnx
 class MaxPoolNodeConverter : public NodeConverter
 {
 public:
-  std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
-                                                const std::vector<mir::Operation::Output *> &inputs,
-                                                mir::Graph *graph) const override;
+  void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
 };
 
 } // namespace mir_onnx
index f27af87..68358c7 100644 (file)
 namespace mir_onnx
 {
 
-std::vector<mir::Operation::Output *>
-MulNodeConverter::convert(const onnx::NodeProto &onnx_node,
-                          const std::vector<mir::Operation::Output *> &inputs,
-                          mir::Graph *graph) const
+void MulNodeConverter::convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const
 {
+  std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+  mir::Graph *graph = context->getGraph();
   auto result = createOp<mir::ops::MulOp>(graph, inputs[0], inputs[1])->getOutput(0);
-  return {result};
+
+  context->setNodeOutputs(onnx_node, {result});
 }
 
 } // namespace mir_onnx
index a25cf23..cbf17e5 100644 (file)
@@ -22,9 +22,7 @@ namespace mir_onnx
 class MulNodeConverter : public NodeConverter
 {
 public:
-  std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
-                                                const std::vector<mir::Operation::Output *> &inputs,
-                                                mir::Graph *graph) const override;
+  void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
 };
 
 } // namespace mir_onnx
index 48212a2..18ca95b 100644 (file)
 namespace mir_onnx
 {
 
-std::vector<mir::Operation::Output *>
-PadNodeConverter::convert(const onnx::NodeProto &onnx_node,
-                          const std::vector<mir::Operation::Output *> &inputs,
-                          mir::Graph *graph) const
+void PadNodeConverter::convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const
 {
+  std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+  mir::Graph *graph = context->getGraph();
+
   // 0.0f is the default value to be filled into padded cells.
   float value = getFloatAttribute(onnx_node, "value", 0.0f);
   auto pads_attr = findAttribute(onnx_node, "pads");
   assert(pads_attr);
-  // "constant" is the default mode
+  // "constant" is the default mode.
   auto mode = getStringAttribute(onnx_node, "mode", "constant");
   if (mode != "constant")
     throw std::runtime_error("Not supported Pad mode attribue!");
@@ -49,8 +49,10 @@ PadNodeConverter::convert(const onnx::NodeProto &onnx_node,
     vec[i] = pair;
   }
   auto result =
-      createOp<mir::ops::PadOp>(graph, inputs[0], inputs[0]->getShape().rank(), vec, scalar);
-  return {result->getOutput(0)};
+      createOp<mir::ops::PadOp>(graph, inputs[0], inputs[0]->getShape().rank(), vec, scalar)
+          ->getOutput(0);
+
+  context->setNodeOutputs(onnx_node, {result});
 }
 
 } // namespace mir_onnx
index a2801af..2242e90 100644 (file)
@@ -22,9 +22,7 @@ namespace mir_onnx
 class PadNodeConverter : public NodeConverter
 {
 public:
-  std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
-                                                const std::vector<mir::Operation::Output *> &inputs,
-                                                mir::Graph *graph) const override;
+  void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
 };
 
 } // namespace mir_onnx
index 500e449..8037c7a 100644 (file)
 namespace mir_onnx
 {
 
-std::vector<mir::Operation::Output *>
-ReluNodeConverter::convert(const onnx::NodeProto &onnx_node,
-                           const std::vector<mir::Operation::Output *> &inputs,
-                           mir::Graph *graph) const
+void ReluNodeConverter::convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const
 {
+  std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+  mir::Graph *graph = context->getGraph();
   assert(inputs.size() == 1);
-  auto result = createOp<mir::ops::ReluOp>(graph, inputs[0]);
-  return {result->getOutput(0)};
+  auto result = createOp<mir::ops::ReluOp>(graph, inputs[0])->getOutput(0);
+
+  context->setNodeOutputs(onnx_node, {result});
 }
 
 } // namespace mir_onnx
index 4b2ee9e..f76499f 100644 (file)
@@ -22,9 +22,7 @@ namespace mir_onnx
 class ReluNodeConverter : public NodeConverter
 {
 public:
-  std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
-                                                const std::vector<mir::Operation::Output *> &inputs,
-                                                mir::Graph *graph) const override;
+  void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
 };
 
 } // namespace mir_onnx
index f764b3e..b2d68a5 100644 (file)
 namespace mir_onnx
 {
 
-std::vector<mir::Operation::Output *>
-ReshapeNodeConverter::convert(const onnx::NodeProto &onnx_node,
-                              const std::vector<mir::Operation::Output *> &inputs,
-                              mir::Graph *graph) const
+void ReshapeNodeConverter::convert(const onnx::NodeProto &onnx_node,
+                                   ConverterContext *context) const
 {
+  std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+  mir::Graph *graph = context->getGraph();
   // The original shape
   const auto &in_shape = inputs[0]->getShape();
 
@@ -60,8 +60,9 @@ ReshapeNodeConverter::convert(const onnx::NodeProto &onnx_node,
     i++;
   }
   auto out_shape = mir::Shape(shape_vector);
-  auto result = createOp<mir::ops::ReshapeOp>(graph, inputs[0], out_shape);
-  return {result->getOutput(0)};
+  auto result = createOp<mir::ops::ReshapeOp>(graph, inputs[0], out_shape)->getOutput(0);
+
+  context->setNodeOutputs(onnx_node, {result});
 }
 
 } // namespace mir_onnx
index c8558d8..6b4c326 100644 (file)
@@ -22,9 +22,7 @@ namespace mir_onnx
 class ReshapeNodeConverter : public NodeConverter
 {
 public:
-  std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
-                                                const std::vector<mir::Operation::Output *> &inputs,
-                                                mir::Graph *graph) const override;
+  void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
 };
 
 } // namespace mir_onnx
index 7344d45..9e563ef 100644 (file)
 namespace mir_onnx
 {
 
-std::vector<mir::Operation::Output *>
-ShapeNodeConverter::convert(const onnx::NodeProto &onnx_node,
-                            const std::vector<mir::Operation::Output *> &inputs,
-                            mir::Graph *graph) const
+void ShapeNodeConverter::convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const
 {
+  std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+  mir::Graph *graph = context->getGraph();
   const auto &input_shape = inputs[0]->getShape();
   int size = input_shape.rank();
   mir::Shape output_shape{size};
@@ -39,8 +38,9 @@ ShapeNodeConverter::convert(const onnx::NodeProto &onnx_node,
     data[i] = input_shape.dim(i);
   }
   mir::TensorVariant tensor(mir::DTYPE::FLOAT32, output_shape, data.data());
-  auto result = createOp<mir::ops::ConstantOp>(graph, tensor);
-  return {result->getOutput(0)};
+  auto result = createOp<mir::ops::ConstantOp>(graph, tensor)->getOutput(0);
+
+  context->setNodeOutputs(onnx_node, {result});
 }
 
 } // namespace mir_onnx
index 52ab97f..249130e 100644 (file)
@@ -22,9 +22,7 @@ namespace mir_onnx
 class ShapeNodeConverter : public NodeConverter
 {
 public:
-  std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
-                                                const std::vector<mir::Operation::Output *> &inputs,
-                                                mir::Graph *graph) const override;
+  void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
 };
 
 } // namespace mir_onnx
index e537b07..aa80592 100644 (file)
 namespace mir_onnx
 {
 
-std::vector<mir::Operation::Output *>
-SigmoidNodeConverter::convert(const onnx::NodeProto &onnx_node,
-                              const std::vector<mir::Operation::Output *> &inputs,
-                              mir::Graph *graph) const
+void SigmoidNodeConverter::convert(const onnx::NodeProto &onnx_node,
+                                   ConverterContext *context) const
 {
+  std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+  mir::Graph *graph = context->getGraph();
   assert(inputs.size() == 1);
-  auto result = createOp<mir::ops::SigmoidOp>(graph, inputs[0]);
-  return {result->getOutput(0)};
+  auto result = createOp<mir::ops::SigmoidOp>(graph, inputs[0])->getOutput(0);
+
+  context->setNodeOutputs(onnx_node, {result});
 }
 
 } // namespace mir_onnx
index b738c23..c470ee0 100644 (file)
@@ -22,9 +22,7 @@ namespace mir_onnx
 class SigmoidNodeConverter : public NodeConverter
 {
 public:
-  std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
-                                                const std::vector<mir::Operation::Output *> &inputs,
-                                                mir::Graph *graph) const override;
+  void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
 };
 
 } // namespace mir_onnx
index 88581cb..21e4c93 100644 (file)
 namespace mir_onnx
 {
 
-std::vector<mir::Operation::Output *>
-SoftmaxNodeConverter::convert(const onnx::NodeProto &onnx_node,
-                              const std::vector<mir::Operation::Output *> &inputs,
-                              mir::Graph *graph) const
+void SoftmaxNodeConverter::convert(const onnx::NodeProto &onnx_node,
+                                   ConverterContext *context) const
 {
-  // 1 is the default axis number
+  std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+  mir::Graph *graph = context->getGraph();
+
+  // 1 is the default axis number.
   int axis = getIntAttribute(onnx_node, "axis", 1);
-  auto result = createOp<mir::ops::SoftmaxOp>(graph, inputs[0], axis);
-  return {result->getOutput(0)};
+
+  auto result = createOp<mir::ops::SoftmaxOp>(graph, inputs[0], axis)->getOutput(0);
+
+  context->setNodeOutputs(onnx_node, {result});
 }
 
 } // namespace mir_onnx
index 4600ee7..be43102 100644 (file)
@@ -22,9 +22,7 @@ namespace mir_onnx
 class SoftmaxNodeConverter : public NodeConverter
 {
 public:
-  std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
-                                                const std::vector<mir::Operation::Output *> &inputs,
-                                                mir::Graph *graph) const override;
+  void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
 };
 
 } // namespace mir_onnx
index c25f786..ddade1b 100644 (file)
 namespace mir_onnx
 {
 
-std::vector<mir::Operation::Output *>
-SumNodeConverter::convert(const onnx::NodeProto &onnx_node,
-                          const std::vector<mir::Operation::Output *> &inputs,
-                          mir::Graph *graph) const
+void SumNodeConverter::convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const
 {
+  std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+  mir::Graph *graph = context->getGraph();
   assert(inputs.size() >= 1);
 
   auto result = inputs[0];
@@ -36,7 +35,7 @@ SumNodeConverter::convert(const onnx::NodeProto &onnx_node,
     result = createOp<mir::ops::AddOp>(graph, result, inputs[i])->getOutput(0);
   }
 
-  return {result};
+  context->setNodeOutputs(onnx_node, {result});
 }
 
 } // namespace mir_onnx
index a9c64b8..a074d03 100644 (file)
@@ -22,9 +22,7 @@ namespace mir_onnx
 class SumNodeConverter : public NodeConverter
 {
 public:
-  std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
-                                                const std::vector<mir::Operation::Output *> &inputs,
-                                                mir::Graph *graph) const override;
+  void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
 };
 
 } // namespace mir_onnx
index 9487b90..0fcb307 100644 (file)
 namespace mir_onnx
 {
 
-std::vector<mir::Operation::Output *>
-UnsqueezeNodeConverter::convert(const onnx::NodeProto &onnx_node,
-                                const std::vector<mir::Operation::Output *> &inputs,
-                                mir::Graph *graph) const
+void UnsqueezeNodeConverter::convert(const onnx::NodeProto &onnx_node,
+                                     ConverterContext *context) const
 {
+  std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+  mir::Graph *graph = context->getGraph();
   auto *axes = findAttribute(onnx_node, "axes");
   assert(axes && axes->ints_size());
   const mir::Shape &input_shape = inputs[0]->getShape();
@@ -48,8 +48,9 @@ UnsqueezeNodeConverter::convert(const onnx::NodeProto &onnx_node,
       j++;
     }
   }
-  auto result = createOp<mir::ops::ReshapeOp>(graph, inputs[0], out_shape);
-  return {result->getOutput(0)};
+  auto result = createOp<mir::ops::ReshapeOp>(graph, inputs[0], out_shape)->getOutput(0);
+
+  context->setNodeOutputs(onnx_node, {result});
 }
 
 } // namespace mir_onnx
index c9eae0b..3643527 100644 (file)
@@ -22,9 +22,7 @@ namespace mir_onnx
 class UnsqueezeNodeConverter : public NodeConverter
 {
 public:
-  std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
-                                                const std::vector<mir::Operation::Output *> &inputs,
-                                                mir::Graph *graph) const override;
+  void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
 };
 
 } // namespace mir_onnx
index cf2d7fe..74831c6 100644 (file)
 namespace mir_onnx
 {
 
-std::vector<mir::Operation::Output *>
-UpsampleNodeConverter::convert(const onnx::NodeProto &onnx_node,
-                               const std::vector<mir::Operation::Output *> &inputs,
-                               mir::Graph *graph) const
+void UpsampleNodeConverter::convert(const onnx::NodeProto &onnx_node,
+                                    ConverterContext *context) const
 {
-  // "nearest" is the default mode
+  std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+  mir::Graph *graph = context->getGraph();
+
+  // "nearest" is the default mode.
   std::string mode = getStringAttribute(onnx_node, "mode", "nearest");
   assert(mode == "nearest" && "Unsupported upscale mode!");
 
@@ -49,11 +50,14 @@ UpsampleNodeConverter::convert(const onnx::NodeProto &onnx_node,
   assert(scales_tensor.getShape().rank() == 1 && "Scales are a 1d tensor");
   for (int i = 0; i < scales_tensor.getShape().numElements(); i++)
     scales_vector[onnx2mir[i]] = scales_tensor.atOffset(i);
-  return {convertMIRToONNX(
-      graph,
+
+  auto result =
       createOp<mir::ops::ResizeOp>(graph, convertONNXToMIR(graph, inputs[0]),
                                    mir::ops::ResizeOp::ResizeMethod::nearestNeighbor, scales_vector)
-          ->getOutput(0))};
+          ->getOutput(0);
+  result = convertMIRToONNX(graph, result);
+
+  context->setNodeOutputs(onnx_node, {result});
 }
 
 } // namespace mir_onnx
index 9c2d2a5..ca258ee 100644 (file)
@@ -22,9 +22,7 @@ namespace mir_onnx
 class UpsampleNodeConverter : public NodeConverter
 {
 public:
-  std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
-                                                const std::vector<mir::Operation::Output *> &inputs,
-                                                mir::Graph *graph) const override;
+  void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
 };
 
 } // namespace mir_onnx