[nnc] Refactor ONNX importer to support operators with multiple outputs (#2831)
author Sergei Barannikov/AI Tools Lab/SRR/Engineer/Samsung Electronics <s.barannikov@samsung.com>
Mon, 14 Jan 2019 15:06:55 +0000 (18:06 +0300)
committer Efimov Alexander/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Mon, 14 Jan 2019 15:06:55 +0000 (18:06 +0300)
Make operator conversion methods accept and return vectors of IODescriptors.
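
For context: mir::IODescriptor identifies a single output of an operation
rather than the operation as a whole, which is what allows a conversion
method to return several results for one ONNX node. A minimal sketch of the
type as it is used in this diff (only the `op` and `index` members are
visible below; the authoritative definition lives in the MIR headers):

    // Sketch only: inferred from usages such as `out_descr.op` and
    // `input_descrs[0].index` in the changed code.
    struct IODescriptor {
      Operation* op;      // operation producing the value
      std::size_t index;  // which of the operation's outputs
    };

    // All conversion methods now share one calling convention:
    //   std::vector<mir::IODescriptor>
    //   convertX(const std::vector<mir::IODescriptor>& inputs, ...);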

Signed-off-by: Sergei Barannikov <s.barannikov@samsung.com>
contrib/nnc/passes/onnx_frontend/ONNXImporterImpl.cpp
contrib/nnc/passes/onnx_frontend/ONNXImporterImpl.h
contrib/nnc/passes/onnx_frontend/ONNXOpCreator.cpp
contrib/nnc/passes/onnx_frontend/ONNXOpCreator.h

diff --git a/contrib/nnc/passes/onnx_frontend/ONNXImporterImpl.cpp b/contrib/nnc/passes/onnx_frontend/ONNXImporterImpl.cpp
index e874295..d37d970 100644
--- a/contrib/nnc/passes/onnx_frontend/ONNXImporterImpl.cpp
+++ b/contrib/nnc/passes/onnx_frontend/ONNXImporterImpl.cpp
@@ -160,7 +160,7 @@ void ONNXImporterImpl::createGraphInputs() {
       const onnx::TensorProto* onnx_tensor = onnx_tensors[name];
       _inputTensors.insert(std::make_pair(name, createTensor(onnx_tensor)));
       auto constant = _graph->create<mir::ops::ConstantOp>(name, _inputTensors.at(name));
-      _tensorNameToPrevMirOp[name] = constant;
+      _tensorNameToIODescriptor[name] = constant->getOutput(0);
     } else {
       // We're dealing with graph input (assuming the picture only)
       auto onnx_input_shape = input.type().tensor_type().shape();
@@ -172,19 +172,20 @@ void ONNXImporterImpl::createGraphInputs() {
       }
       // TODO: Temporary solution!
       auto node = _graph->create<mir::ops::VariableOp>(name, shape);
-      _tensorNameToPrevMirOp[name] = node;
+      _tensorNameToIODescriptor[name] = node->getOutput(0);
     }
   }
 }
 
-void ONNXImporterImpl::dump(const std::vector<mir::Operation*>& inputs,
-                            const std::vector<mir::Operation*>& ops,
+void ONNXImporterImpl::dump(const std::vector<mir::IODescriptor>& input_descrs,
+                            const std::vector<mir::IODescriptor>& out_descrs,
                             const onnx::NodeProto& onnx_node) {
-  for (auto op : ops) {
+  for (auto out_descr : out_descrs) {
+    auto op = out_descr.op;
     std::cout << onnx_node.op_type() << " '" << op->getName() << "'";
-    if (inputs[0]->getNumInputs() > 0) {
+    if (input_descrs[0].op->getNumInputs() > 0) {
       std::cout << "Input Shape: ";
-      dumpShape(inputs[0]->getOutputShape(0));
+      dumpShape(input_descrs[0].op->getOutputShape(input_descrs[0].index));
     }
     std::cout << " Output Shape: ";
     dumpShape(op->getOutputShape(0));
@@ -249,91 +250,91 @@ mir::Graph *ONNXImporterImpl::createIR() {
     assert(onnx_node.has_op_type());
     auto op_type = onnx_node.op_type().c_str();
     // Fill inputs of the given node
-    std::vector<mir::Operation*> input_nodes(onnx_node.input_size());
+    std::vector<mir::IODescriptor> inputs(onnx_node.input_size());
     for (int i = 0; i < onnx_node.input_size(); i++) {
       auto& name = onnx_node.input(i);
-      assert(_tensorNameToPrevMirOp.find(name) != _tensorNameToPrevMirOp.end());
-      input_nodes[i] = _tensorNameToPrevMirOp[name];
+      assert(_tensorNameToIODescriptor.find(name) != _tensorNameToIODescriptor.end());
+      inputs[i] = _tensorNameToIODescriptor[name];
     }
 
-    std::vector<mir::Operation*> outputs;
+    std::vector<mir::IODescriptor> outputs;
     auto* onnx_op_type = ONNXPerfectHash::getONNXOpType(op_type, onnx_node.op_type().size());
 
     switch (onnx_op_type->opCode) {
       case ONNXOpCode::opConv:
-        outputs = _opCreator.convertConv2D(input_nodes, onnx_node);
+        outputs = _opCreator.convertConv2D(inputs, onnx_node);
         break;
       case ONNXOpCode::opAdd:
-        outputs = _opCreator.convertElementwise(input_nodes, mir::ops::ElementwiseOp::OpType::add);
+        outputs = _opCreator.convertElementwise(inputs, mir::ops::ElementwiseOp::OpType::add);
         break;
       case ONNXOpCode::opGather:
-        outputs = _opCreator.convertGather(input_nodes, onnx_node);
+        outputs = _opCreator.convertGather(inputs, onnx_node);
         break;
       case ONNXOpCode::opGemm:
-        outputs = _opCreator.convertGemm(input_nodes, onnx_node);
+        outputs = _opCreator.convertGemm(inputs, onnx_node);
         break;
       case ONNXOpCode::opSum:
-        outputs = _opCreator.convertElementwise(input_nodes, mir::ops::ElementwiseOp::OpType::add);
+        outputs = _opCreator.convertElementwise(inputs, mir::ops::ElementwiseOp::OpType::add);
         break;
       case ONNXOpCode::opMul:
-        outputs = _opCreator.convertElementwise(input_nodes, mir::ops::ElementwiseOp::OpType::mul);
+        outputs = _opCreator.convertElementwise(inputs, mir::ops::ElementwiseOp::OpType::mul);
         break;
       case ONNXOpCode::opMax:
-        outputs = _opCreator.convertElementwise(input_nodes, mir::ops::ElementwiseOp::OpType::max);
+        outputs = _opCreator.convertElementwise(inputs, mir::ops::ElementwiseOp::OpType::max);
         break;
       case ONNXOpCode::opGlobalAveragePool:
       case ONNXOpCode::opAveragePool:
       case ONNXOpCode::opMaxPool:
-        outputs = _opCreator.convertPool(input_nodes, onnx_op_type->opCode, onnx_node);
+        outputs = _opCreator.convertPool(inputs, onnx_op_type->opCode, onnx_node);
         break;
       case ONNXOpCode::opConcat:
-        outputs = _opCreator.convertConcat(input_nodes, onnx_node);
+        outputs = _opCreator.convertConcat(inputs, onnx_node);
         break;
       case ONNXOpCode::opReshape:
-        outputs = _opCreator.convertReshape(input_nodes);
+        outputs = _opCreator.convertReshape(inputs);
         break;
       case ONNXOpCode::opRelu:
-        outputs = _opCreator.convertRelu(input_nodes);
+        outputs = _opCreator.convertRelu(inputs);
         break;
       case ONNXOpCode::opUnsqueeze:
-        outputs = _opCreator.convertUnsqueeze(input_nodes[0], onnx_node);
+        outputs = _opCreator.convertUnsqueeze(inputs, onnx_node);
         break;
       case ONNXOpCode::opSigmoid:
-        outputs = _opCreator.convertSigmoid(input_nodes);
+        outputs = _opCreator.convertSigmoid(inputs);
         break;
       case ONNXOpCode::opSoftmax:
-        outputs = _opCreator.convertSoftmax(input_nodes, onnx_node);
+        outputs = _opCreator.convertSoftmax(inputs, onnx_node);
         break;
       case ONNXOpCode::opScale:
-        outputs = _opCreator.convertScale(input_nodes, onnx_node);
+        outputs = _opCreator.convertScale(inputs, onnx_node);
         break;
       case ONNXOpCode::opBatchNormalization:
-        outputs = _opCreator.convertBatchNorm(input_nodes, onnx_node, _inputTensors);
+        outputs = _opCreator.convertBatchNorm(inputs, onnx_node, _inputTensors);
         break;
       case ONNXOpCode::opDropout:
-        outputs = _opCreator.convertDropout(input_nodes, onnx_node);
+        outputs = _opCreator.convertDropout(inputs, onnx_node);
         break;
       default:
         throw PassException("Invalid ONNXOpCode" + std::to_string((int)onnx_op_type->opCode));
     }
     // Set outputs' names
     for (int i = 0; i < outputs.size(); i++) {
-      outputs[i]->setName(onnx_node.output(i));
-      auto result = _tensorNameToPrevMirOp.emplace(outputs[i]->getName(), outputs[i]);
+      outputs[i].op->setName(onnx_node.output(i));
+      auto result = _tensorNameToIODescriptor.emplace(outputs[i].op->getName(), outputs[i]);
       if(!result.second)
-        throw PassException("Name duplication: " + outputs[i]->getName());
+        throw PassException("Name duplication: " + outputs[i].op->getName());
     }
     assert (outputs.size());
     // FIXME: it should be done properly via the given graph outputs
     _graphOutputs.assign(outputs.begin(), outputs.end());
 #if 0
-    dump(input_nodes, outputs, onnx_node);
+    dump(inputs, outputs, onnx_node);
 #endif
   }
   // set graph outputs
   // TODO: it should be done with onnx graph outputs
   for (auto& output_idx : _graphOutputs)
-    _graph->markOutput(output_idx);
+    _graph->markOutput(output_idx.op);
 
   return _graph;
 }
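
The output-naming loop above is where multiple outputs actually become
addressable: each descriptor returned by a converter is registered under its
own ONNX tensor name. A hedged illustration (ONNX Dropout, for example,
optionally produces a mask as a second output, though the converter in this
commit still returns a single result):

    // Illustration only, assuming a converter that returns two descriptors:
    //   outputs = { op->getOutput(0), op->getOutput(1) };
    // The loop then creates two independent name entries:
    //   _tensorNameToIODescriptor[onnx_node.output(0)] = outputs[0];
    //   _tensorNameToIODescriptor[onnx_node.output(1)] = outputs[1];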
diff --git a/contrib/nnc/passes/onnx_frontend/ONNXImporterImpl.h b/contrib/nnc/passes/onnx_frontend/ONNXImporterImpl.h
index 7fd5683..be4ff8a 100644
--- a/contrib/nnc/passes/onnx_frontend/ONNXImporterImpl.h
+++ b/contrib/nnc/passes/onnx_frontend/ONNXImporterImpl.h
@@ -39,7 +39,8 @@ public:
 
   void import() {};
   mir::Graph *createIR() override;
-  void dump(const std::vector<mir::Operation*>& inputs, const std::vector<mir::Operation*>& op,
+  void dump(const std::vector<mir::IODescriptor>& input_descrs,
+            const std::vector<mir::IODescriptor>& out_descrs,
             const onnx::NodeProto& onnx_node);
 
   static void dumpShape(mir::Shape shape) {
@@ -51,10 +52,10 @@ public:
   private:
   void createGraphInputs();
   // This map maps onnx tensor names to MIR operations/nodes
-  std::map<std::string, mir::Operation*> _tensorNameToPrevMirOp;
+  std::map<std::string, mir::IODescriptor> _tensorNameToIODescriptor;
   // This map keeps named tensors used as graph input initializers.
   std::map<std::string, mir::TensorVariant> _inputTensors;
-  std::vector<mir::Operation*> _graphOutputs;
+  std::vector<mir::IODescriptor> _graphOutputs;
   std::string _modelFilename;
   std::unique_ptr<onnx::ModelProto> _model;
   mir::Graph* _graph;
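
In other words, the renamed map resolves a tensor name to a specific
(producer, output index) pair instead of a bare Operation*. A sketch with
hypothetical tensor names:

    // Hypothetical entries: two outputs of one operation stay distinct.
    //   _tensorNameToIODescriptor["split_out0"] = {split_op, 0};
    //   _tensorNameToIODescriptor["split_out1"] = {split_op, 1};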
diff --git a/contrib/nnc/passes/onnx_frontend/ONNXOpCreator.cpp b/contrib/nnc/passes/onnx_frontend/ONNXOpCreator.cpp
index 32b35f8..d3814cb 100644
--- a/contrib/nnc/passes/onnx_frontend/ONNXOpCreator.cpp
+++ b/contrib/nnc/passes/onnx_frontend/ONNXOpCreator.cpp
@@ -125,14 +125,15 @@ static void getKernelStridesPadding(const onnx::NodeProto &onnx_node, KernelStri
   }
 };
 
-std::vector<Operation*> ONNXOpCreator::convertConv2D(InputOps& inputs,
-                                                    const onnx::NodeProto& onnx_node) {
+std::vector<IODescriptor>
+ONNXOpCreator::convertConv2D(const std::vector<mir::IODescriptor>& inputs,
+                             const onnx::NodeProto& onnx_node) {
   assert(inputs.size() >= 2);
 
   KernelStridesPadding cdata;
   getKernelStridesPadding(onnx_node, cdata);
   // FIXME: It can be non-constant value.
-  auto* in_weights = dynamic_cast<mir::ops::ConstantOp*>(inputs[1]);
+  auto* in_weights = dynamic_cast<mir::ops::ConstantOp*>(inputs[1].op);
   assert(in_weights && "Weights could be a constant tensor only");
   const auto& in_weights_tensor = in_weights->getValue();
   // We should transpose ONNX MCHW to HWOI
@@ -140,105 +141,104 @@ std::vector<Operation*> ONNXOpCreator::convertConv2D(InputOps& inputs,
 
   mir::ops::ConstantOp* input_bias = nullptr;
   if (inputs.size() > 2) {
-    input_bias = dynamic_cast<mir::ops::ConstantOp*>(inputs[2]);
+    input_bias = dynamic_cast<mir::ops::ConstantOp*>(inputs[2].op);
     assert(input_bias && "1D optional bias could be a constant tensor only");
   }
 
-  inputs.resize(1);
-  std::vector<Operation*> outputs;
   // Transpose ONNX NCHW to MIR NHWC
-  auto t_input = convertONNXToMIR(inputs[0]->getOutput(0));
-  outputs = createOp<ops::Conv2DOp>(t_input[0]->getOutput(0), transposed, cdata.strides_shape,
-                                    cdata.padding_before, cdata.padding_after);
+  auto t_input = convertONNXToMIR(inputs[0]);
+  auto result = createOp<ops::Conv2DOp>(t_input, transposed, cdata.strides_shape,
+                                        cdata.padding_before, cdata.padding_after);
   if (input_bias)
-    outputs = createOp<ops::BiasAddOp>(outputs[0]->getOutput(0), input_bias->getValue());
+    result = createOp<ops::BiasAddOp>(result->getOutput(0), input_bias->getValue());
 
-  return convertMIRToONNX(outputs[0]->getOutput(0));
+  return {convertMIRToONNX(result->getOutput(0))};
 }
 
-std::vector<Operation*> ONNXOpCreator::convertConcat(InputOps& inputs,
-                                                    const onnx::NodeProto& onnx_node) {
+std::vector<IODescriptor>
+ONNXOpCreator::convertConcat(const std::vector<mir::IODescriptor>& inputs,
+                             const onnx::NodeProto& onnx_node) {
   bool found;
   int axis;
   std::tie (found, axis) = getIntAttribute(onnx_node);
   if (!found)
     throw PassException("Concat must have 'axis' attribute");
-  std::vector<IODescriptor> descriptors;
-  for (auto input : inputs)
-    descriptors.push_back(input->getOutput(0));
-  return createOp<ops::ConcatOp>(descriptors, axis);
+  auto result = createOp<ops::ConcatOp>(inputs, axis);
+  return {result->getOutput(0)};
 }
 
-std::vector<mir::Operation*>
-ONNXOpCreator::convertGather(ONNXOpCreator::InputOps& inputs, const onnx::NodeProto& onnx_node) {
+std::vector<IODescriptor>
+ONNXOpCreator::convertGather(const std::vector<mir::IODescriptor>& inputs,
+                             const onnx::NodeProto& onnx_node) {
   bool found;
   int value;
   std::tie(found, value) = getIntAttribute(onnx_node, "axis");
   int axis = found ? value : 0;
-  return createOp<ops::GatherOp>(inputs[0]->getOutput(0), inputs[1]->getOutput(0), axis);
+  auto result = createOp<ops::GatherOp>(inputs[0], inputs[1], axis);
+  return {result->getOutput(0)};
 }
 
-std::vector<Operation*> ONNXOpCreator::convertPool(InputOps& inputs, ONNXOpCode op_code,
-                                                   const onnx::NodeProto& onnx_node) {
+std::vector<IODescriptor>
+ONNXOpCreator::convertPool(const std::vector<mir::IODescriptor>& inputs,
+                           ONNXOpCode op_code,
+                           const onnx::NodeProto& onnx_node) {
   ops::PoolOp::BorderType border_type;
   ops::PoolOp::PoolingType pool_type;
 
-  std::vector<Operation*> result;
   KernelStridesPadding cdata;
   // Transpose ONNX NCHW to MIR NHWC
-  auto t_input = convertONNXToMIR(inputs[0]->getOutput(0));
+  auto t_input = convertONNXToMIR(inputs[0]);
 
   switch (op_code) {
     case ONNXOpCode::opGlobalAveragePool: {
+      border_type = ops::PoolOp::BorderType::ZEROFILLED;
+      pool_type = ops::PoolOp::PoolingType::AVG;
       // GlobalAveragePool is equivalent to AveragePool with kernel size equal
       // to the spatial dimension of input tensor
-      result = createOp<ops::PoolOp>(t_input[0]->getOutput(0),
-                                     ops::PoolOp::PoolingType::AVG,
-                                     t_input[0]->getOutputShape(0), // kernel_shape
-                                     Shape({1, 1}),                // strides_shape
-                                     cdata.padding_before, cdata.padding_after,
-                                     ops::PoolOp::BorderType::ZEROFILLED,
-                                     ops::PoolOp::RoundMode::floor);
-      return convertMIRToONNX(result[0]->getOutput(0));
+      cdata.kernel_shape = t_input.op->getOutputShape(0);
+      cdata.strides_shape = Shape{1, 1};
+      break;
     }
     case ONNXOpCode::opAveragePool:
       border_type = ops::PoolOp::BorderType::ZEROFILLED;
       pool_type = ops::PoolOp::PoolingType::AVG;
+      getKernelStridesPadding(onnx_node, cdata);
       break;
     case ONNXOpCode::opMaxPool:
       border_type = ops::PoolOp::BorderType::EMPTY;
       pool_type = ops::PoolOp::PoolingType::MAX;
+      getKernelStridesPadding(onnx_node, cdata);
       break;
     default:
       assert(false);
   }
-  // Proceed with Average or Max Pool
-  getKernelStridesPadding(onnx_node, cdata);
 
-  result = createOp<ops::PoolOp>(t_input[0]->getOutput(0), pool_type,
-                                 cdata.kernel_shape, cdata.strides_shape,
-                                 cdata.padding_before, cdata.padding_after,
-                                 border_type,
-                                 ops::PoolOp::RoundMode::floor);
-  return convertMIRToONNX(result[0]->getOutput(0));
+  auto result = createOp<ops::PoolOp>(t_input, pool_type,
+                                      cdata.kernel_shape, cdata.strides_shape,
+                                      cdata.padding_before, cdata.padding_after,
+                                      border_type, ops::PoolOp::RoundMode::floor);
+  return {convertMIRToONNX(result->getOutput(0))};
 }
 
-std::vector<Operation*> ONNXOpCreator::convertSoftmax(InputOps& inputs,
-                                                      const onnx::NodeProto& onnx_node) {
+std::vector<IODescriptor>
+ONNXOpCreator::convertSoftmax(const std::vector<mir::IODescriptor>& inputs,
+                              const onnx::NodeProto& onnx_node) {
   int axis;
   bool found;
   std::tie (found, axis) = getIntAttribute(onnx_node);
   axis = found ? axis : 1;
-  return createOp<ops::SoftmaxOp>(inputs[0]->getOutput(0), axis);
+  auto result = createOp<ops::SoftmaxOp>(inputs[0], axis);
+  return {result->getOutput(0)};
 }
 
-std::vector<Operation*> ONNXOpCreator::convertReshape(InputOps& inputs) {
+std::vector<IODescriptor>
+ONNXOpCreator::convertReshape(const std::vector<mir::IODescriptor>& inputs) {
   // The original shape
-  auto in_shape = inputs[0]->getInputShape(0);
+  auto in_shape = inputs[0].op->getOutputShape(inputs[0].index);
 
   // Input tensor describing the new shape
   // TODO: could it be not a constant?
-  auto* op = dynamic_cast<mir::ops::ConstantOp*>(inputs[1]);
+  auto* op = dynamic_cast<mir::ops::ConstantOp*>(inputs[1].op);
   assert(op && "We support constants only");
   auto shape_tensor = op->getValue();
   Shape shape_tensor_shape = (shape_tensor).getShape();
@@ -261,17 +261,18 @@ std::vector<Operation*> ONNXOpCreator::convertReshape(InputOps& inputs) {
     i++;
   }
   auto out_shape = Shape(shape_vector);
-  auto outputs = createOp<ops::ReshapeOp>(inputs[0]->getOutput(0), out_shape);
-  return outputs;
+  auto result = createOp<ops::ReshapeOp>(inputs[0], out_shape);
+  return {result->getOutput(0)};
 }
 
-std::vector<mir::Operation*>
-ONNXOpCreator::convertUnsqueeze(Operation* input_data, const onnx::NodeProto& onnx_node) {
+std::vector<IODescriptor>
+ONNXOpCreator::convertUnsqueeze(const std::vector<mir::IODescriptor>& inputs,
+                                const onnx::NodeProto& onnx_node) {
   auto* axes = findAttribute(onnx_node, "axes");
   assert(axes && axes->ints_size());
-  const int out_rank = input_data->getOutputShape(0).rank() + axes->ints_size();
+  const Shape& input_shape = inputs[0].op->getOutputShape(inputs[0].index);
+  const int out_rank = input_shape.rank() + axes->ints_size();
   Shape out_shape(out_rank);
-  const Shape& input_shape = input_data->getOutputShape(0);
   auto ints_iterator = axes->ints().begin();
   int j = 0;
   for (int i = 0; i < out_rank; i++) {
@@ -283,84 +284,93 @@ ONNXOpCreator::convertUnsqueeze(Operation* input_data, const onnx::NodeProto& on
       j++;
     }
   }
-  auto outputs = createOp<ops::ReshapeOp>(input_data->getOutput(0), out_shape);
-  return outputs;
+  auto result = createOp<ops::ReshapeOp>(inputs[0], out_shape);
+  return {result->getOutput(0)};
 }
 
-std::vector<Operation*> ONNXOpCreator::convertRelu(InputOps& inputs) {
+std::vector<IODescriptor>
+ONNXOpCreator::convertRelu(const std::vector<mir::IODescriptor>& inputs) {
   assert(inputs.size() == 1);
-  return createOp<ops::ReluOp>(inputs[0]->getOutput(0));
+  auto result = createOp<ops::ReluOp>(inputs[0]);
+  return {result->getOutput(0)};
 }
 
-std::vector<Operation*> ONNXOpCreator::convertSigmoid(InputOps& inputs) {
+std::vector<IODescriptor>
+ONNXOpCreator::convertSigmoid(const std::vector<mir::IODescriptor>& inputs) {
   assert(inputs.size() == 1);
-  return createOp<ops::SigmoidOp>(inputs[0]->getOutput(0));
+  auto result = createOp<ops::SigmoidOp>(inputs[0]);
+  return {result->getOutput(0)};
 }
 
-std::vector<Operation*> ONNXOpCreator::convertElementwise(InputOps& inputs,
-                                                         mir::ops::ElementwiseOp::OpType op_type) {
-  std::vector<IODescriptor> descriptors;
-  for (auto input : inputs)
-    descriptors.push_back(input->getOutput(0));
-  return createOp<ops::ElementwiseOp>(descriptors, op_type);
+std::vector<IODescriptor>
+ONNXOpCreator::convertElementwise(const std::vector<mir::IODescriptor>& inputs,
+                                  mir::ops::ElementwiseOp::OpType op_type) {
+  auto result = createOp<ops::ElementwiseOp>(inputs, op_type);
+  return {result->getOutput(0)};
 }
-std::vector<Operation*> ONNXOpCreator::convertBatchNorm(InputOps& inputs,
-                                                        const onnx::NodeProto& onnx_node,
-                                                        InputTensors& input_tensors) {
+
+std::vector<IODescriptor>
+ONNXOpCreator::convertBatchNorm(const std::vector<mir::IODescriptor>& inputs,
+                                const onnx::NodeProto& onnx_node,
+                                InputTensors& input_tensors) {
   // overall_res = (X - mean) / sqrt(var + epsilon) * scale + bias
   bool found;
   float value;
   std::tie(found, value) = getFloatAttribute(onnx_node, "epsilon");
   float epsilon = found ? value : 1e-05f;
 
-  const auto& scale = input_tensors.at(inputs[1]->getName());
-  const auto& bias = input_tensors.at(inputs[2]->getName());
-  const auto& mean = input_tensors.at(inputs[3]->getName());
-  const auto& var = input_tensors.at(inputs[4]->getName());
+  const auto& scale = input_tensors.at(inputs[1].op->getName());
+  const auto& bias = input_tensors.at(inputs[2].op->getName());
+  const auto& mean = input_tensors.at(inputs[3].op->getName());
+  const auto& var = input_tensors.at(inputs[4].op->getName());
 
   // res1 = X - mean
   Tensor<float> bias_data(mean);
   for (auto& idx: ShapeRange(bias_data.getShape()))
     bias_data.at(idx) *= -1;
 
-  auto data = convertONNXToMIR(inputs[0]->getOutput(0));
-  auto bias_add_1 = createOp<ops::BiasAddOp>(data[0]->getOutput(0), mean);
+  auto data = convertONNXToMIR(inputs[0]);
+  auto bias_add_1 = createOp<ops::BiasAddOp>(data, mean);
 
   // res2 = res1 * scale / (var + epsilon)
   Tensor<float> multiplier(scale);
   Tensor<float> var_accessor(var);
   for (auto& idx: ShapeRange(scale.getShape()))
     multiplier.at(idx) /= std::sqrt(var_accessor.at(idx) + epsilon);
-  auto scale_op = createOp<ops::ScaleOp>(bias_add_1[0]->getOutput(0), scale);
+  auto scale_op = createOp<ops::ScaleOp>(bias_add_1->getOutput(0), scale);
 
   // overall_res = res2 + bias
-  auto bias_add_2 = createOp<ops::BiasAddOp>(scale_op[0]->getOutput(0), bias);
+  auto bias_add_2 = createOp<ops::BiasAddOp>(scale_op->getOutput(0), bias);
 
-  return {convertMIRToONNX(bias_add_2[0]->getOutput(0))};
+  return {convertMIRToONNX(bias_add_2->getOutput(0))};
 }
 
-std::vector<Operation*> ONNXOpCreator::convertDropout(InputOps& inputs,
-                                                      const onnx::NodeProto& onnx_node) {
+std::vector<IODescriptor>
+ONNXOpCreator::convertDropout(const std::vector<mir::IODescriptor>& inputs,
+                              const onnx::NodeProto& onnx_node) {
   bool found;
   float value;
   std::tie(found, value) = getFloatAttribute(onnx_node, "ratio");
   float ratio = found ? value : 1.0;
-  return createOp<ops::SoftmaxOp>(inputs[0]->getOutput(0), ratio);
+  auto result = createOp<ops::SoftmaxOp>(inputs[0], ratio);
+  return {result->getOutput(0)};
 }
 
-std::vector<Operation*> ONNXOpCreator::convertScale(InputOps& inputs,
-                                                   const onnx::NodeProto& onnx_node) {
+std::vector<IODescriptor>
+ONNXOpCreator::convertScale(const std::vector<mir::IODescriptor>& inputs,
+                            const onnx::NodeProto& onnx_node) {
   bool found;
   float value;
   std::tie(found, value) = getFloatAttribute(onnx_node, "scale");
   float scale = found ? value : 1.0;
-  auto outputs = createOp<ops::ScaleOp>(inputs[0]->getOutput(0),
-                                        createTensor(scale, inputs[0]->getOutputShape(0)));
-  return outputs;
+  const auto& shape = inputs[0].op->getOutputShape(inputs[0].index);
+  auto result = createOp<ops::ScaleOp>(inputs[0], createTensor(scale, shape));
+  return {result->getOutput(0)};
 }
 
-std::vector<Operation*> ONNXOpCreator::convertGemm(InputOps& inputs,
-                                                   const onnx::NodeProto& onnx_node) {
+std::vector<IODescriptor>
+ONNXOpCreator::convertGemm(const std::vector<mir::IODescriptor>& inputs,
+                           const onnx::NodeProto& onnx_node) {
   bool  found;
   int   ivalue;
   float fvalue;
@@ -385,57 +395,49 @@ std::vector<Operation*> ONNXOpCreator::convertGemm(InputOps& inputs,
 
   // 1. Prepare input matrix A
   // Flatten the shape by dim(0)
-  mir::Shape shape0 ({inputs[0]->getOutputShape(0).dim(0),
-                      inputs[0]->getOutputShape(0).numElements() /
-                                                        inputs[0]->getOutputShape(0).dim(0)});
-  auto input_a = createOp<ops::ReshapeOp>(inputs[0]->getOutput(0), shape0);
+  const auto& in_shape = inputs[0].op->getOutputShape(inputs[0].index);
+  mir::Shape shape0{in_shape.dim(0), in_shape.numElements() / in_shape.dim(0)};
+  auto input_a = createOp<ops::ReshapeOp>(inputs[0], shape0);
   if (trans_a)
-    input_a = createOp<ops::TransposeOp>(input_a[0]->getOutput(0), std::vector<std::size_t>{1, 0});
+    input_a = createOp<ops::TransposeOp>(input_a->getOutput(0), std::vector<std::size_t>{1, 0});
   if (alpha != 1.0)
-    input_a = createOp<ops::ScaleOp>(input_a[0]->getOutput(0),
-                                          createTensor(alpha, input_a[0]->getOutputShape(0)));
+    input_a = createOp<ops::ScaleOp>(input_a->getOutput(0),
+                                     createTensor(alpha, input_a->getOutputShape(0)));
 
   // 2. Prepare input matrix B
   //
-  auto input_b = inputs[1]->getOutput(0);
+  auto input_b = inputs[1];
   if (trans_b)
-    input_b = createOp<ops::TransposeOp>(input_b, std::vector<std::size_t>{1, 0})[0]->getOutput(0);
+    input_b = createOp<ops::TransposeOp>(input_b, std::vector<std::size_t>{1, 0})->getOutput(0);
   // Number of cols in tensor A must be equal to number of rows in tensor B
-  assert(input_a[0]->getOutput(0).op->getOutputShape(0).dim(1) ==
+  assert(input_a->getOutput(0).op->getOutputShape(0).dim(1) ==
          input_b.op->getOutputShape(0).dim(0));
-  Shape mult_a_b({input_a[0]->getOutput(0).op->getOutputShape(0).dim(0),
-                  input_b.op->getOutputShape(0).dim(1)});
+  Shape mult_a_b({input_a->getOutputShape(0).dim(0),
+                  input_b.op->getOutputShape(input_b.index).dim(1)});
 
   // 3. Prepare input matrix C
   //
-  auto input_c = inputs[2]->getOutput(0);
+  auto input_c = inputs[2];
   auto beta_tensor = createTensor(beta, input_c.op->getOutputShape(0));
   // TODO: check 'broadcast' attribute here
   if ((mult_a_b.rank() == 2) && (input_c.op->getOutputShape(0).rank() == 1)) {
     beta_tensor = TensorVariant(beta_tensor, mult_a_b);
   }
-  auto constant = createOp<ops::ConstantOp>(beta_tensor)[0]->getOutput(0);
+  auto constant = createOp<ops::ConstantOp>(beta_tensor)->getOutput(0);
   std::vector<IODescriptor> descriptors = {constant, input_c};
   auto c_mult = createOp<ops::ElementwiseOp>(descriptors, ops::ElementwiseOp::OpType::mul);
-  assert(c_mult[0]->getOutput(0).op->getOutputShape(0) == mult_a_b);
-  return createOp<ops::GemmOp>(input_a[0]->getOutput(0), input_b, c_mult[0]->getOutput(0));
-}
-
-std::vector<Operation*>
-ONNXOpCreator::createInput(const std::string& input_name, const mir::Shape& input_shape) {
-  // TODO For now we only support convolutional networks with one element per batch.
-  assert(input_shape.rank() == 4 && input_shape.dim(0) == 1);
-  auto variable = _graph->create<ops::VariableOp>(input_name, input_shape);
-  return {variable};
+  assert(c_mult->getOutputShape(0) == mult_a_b);
+  auto result = createOp<ops::GemmOp>(input_a->getOutput(0), input_b, c_mult->getOutput(0));
+  return {result->getOutput(0)};
 }
 
-std::vector<Operation*> ONNXOpCreator::convertONNXToMIR(const mir::IODescriptor& arg) {
+mir::IODescriptor ONNXOpCreator::convertONNXToMIR(mir::IODescriptor arg) {
   // NCHW -> NHWC
-  return createOp<ops::TransposeOp>(arg, std::vector<std::size_t>{0, 2, 3, 1});
+  return createOp<ops::TransposeOp>(arg, std::vector<std::size_t>{0, 2, 3, 1})->getOutput(0);
 }
 
-std::vector<Operation*> ONNXOpCreator::convertMIRToONNX(const mir::IODescriptor& arg) {
+mir::IODescriptor ONNXOpCreator::convertMIRToONNX(mir::IODescriptor arg) {
   // NHWC -> NCHW
-  return createOp<ops::TransposeOp>(arg, std::vector<std::size_t>{0, 3, 1, 2});
+  return createOp<ops::TransposeOp>(arg, std::vector<std::size_t>{0, 3, 1, 2})->getOutput(0);
 }
 } // namespace nnc
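
The two helpers at the end reduce the layout round-trip to single
TransposeOps. As a quick sanity check of the permutations (illustrative
shapes, not from the diff):

    // convertONNXToMIR: NCHW --{0, 2, 3, 1}--> NHWC
    //   e.g. (1, 3, 224, 224) -> (1, 224, 224, 3)
    // convertMIRToONNX: NHWC --{0, 3, 1, 2}--> NCHW (the inverse permutation)
    //   e.g. (1, 224, 224, 3) -> (1, 3, 224, 224)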
diff --git a/contrib/nnc/passes/onnx_frontend/ONNXOpCreator.h b/contrib/nnc/passes/onnx_frontend/ONNXOpCreator.h
index 30999b5..7ae9144 100644
--- a/contrib/nnc/passes/onnx_frontend/ONNXOpCreator.h
+++ b/contrib/nnc/passes/onnx_frontend/ONNXOpCreator.h
@@ -33,46 +33,80 @@ namespace nnc {
 
 class ONNXOpCreator {
 public:
-  using InputOps = std::vector<mir::Operation*>;
   using InputTensors = std::map<std::string, mir::TensorVariant>;
 
   ONNXOpCreator() = default;
-  void setMirGraph(mir::Graph* g) {_graph = g;};
-  std::vector<mir::Operation*> convertConv2D(InputOps& inputs, const onnx::NodeProto& node);
-  std::vector<mir::Operation*> convertConcat(InputOps& inputs, const onnx::NodeProto& onnx_node);
-  std::vector<mir::Operation*> convertPool(InputOps& inputs, ONNXOpCode op_code,
-                                          const onnx::NodeProto& onnx_node);
-  std::vector<mir::Operation*> convertSoftmax(InputOps& inputs, const onnx::NodeProto& onnx_node);
-  std::vector<mir::Operation*> convertReshape(InputOps& inputs);
-  std::vector<mir::Operation*> convertRelu(InputOps& inputs);
-  std::vector<mir::Operation*> convertSigmoid(InputOps& inputs);
-
-  std::vector<mir::Operation*>
-  convertUnsqueeze(mir::Operation* inputs, const onnx::NodeProto& onnx_node);
-  std::vector<mir::Operation*> convertElementwise(InputOps& inputs,
-                                                 mir::ops::ElementwiseOp::OpType op_type);
-  std::vector<mir::Operation*> convertScale(InputOps& inputs, const onnx::NodeProto& node);
-  std::vector<mir::Operation*> convertBatchNorm(InputOps& inputs, const onnx::NodeProto& node,
-                                               InputTensors& input_tensors);
-  std::vector<mir::Operation*> convertDropout(InputOps& inputs, const onnx::NodeProto& onnx_node);
-  std::vector<mir::Operation*> convertGather(InputOps& inputs, const onnx::NodeProto& onnx_node);
-  std::vector<mir::Operation*> convertGemm(InputOps& inputs, const onnx::NodeProto& onnx_node);
-
-  std::vector<mir::Operation*> createInput(const std::string&, const mir::Shape&);
-  std::vector<mir::Operation*> convertONNXToMIR(const mir::IODescriptor& arg);
-  std::vector<mir::Operation*> convertMIRToONNX(const mir::IODescriptor& arg);
+
+  void setMirGraph(mir::Graph* g) { _graph = g; };
+
+  std::vector<mir::IODescriptor>
+  convertConv2D(const std::vector<mir::IODescriptor>& inputs,
+                const onnx::NodeProto& onnx_node);
+
+  std::vector<mir::IODescriptor>
+  convertConcat(const std::vector<mir::IODescriptor>& inputs,
+                const onnx::NodeProto& onnx_node);
+
+  std::vector<mir::IODescriptor>
+  convertPool(const std::vector<mir::IODescriptor>& inputs,
+              ONNXOpCode op_code,
+              const onnx::NodeProto& onnx_node);
+
+  std::vector<mir::IODescriptor>
+  convertSoftmax(const std::vector<mir::IODescriptor>& inputs,
+                 const onnx::NodeProto& onnx_node);
+
+  std::vector<mir::IODescriptor>
+  convertReshape(const std::vector<mir::IODescriptor>& inputs);
+
+  std::vector<mir::IODescriptor>
+  convertRelu(const std::vector<mir::IODescriptor>& inputs);
+
+  std::vector<mir::IODescriptor>
+  convertSigmoid(const std::vector<mir::IODescriptor>& inputs);
+
+  std::vector<mir::IODescriptor>
+  convertUnsqueeze(const std::vector<mir::IODescriptor>& inputs,
+                   const onnx::NodeProto& onnx_node);
+
+  std::vector<mir::IODescriptor>
+  convertElementwise(const std::vector<mir::IODescriptor>& inputs,
+                     mir::ops::ElementwiseOp::OpType op_type);
+
+  std::vector<mir::IODescriptor>
+  convertScale(const std::vector<mir::IODescriptor>& inputs,
+               const onnx::NodeProto& onnx_node);
+
+  std::vector<mir::IODescriptor>
+  convertBatchNorm(const std::vector<mir::IODescriptor>& inputs,
+                   const onnx::NodeProto& onnx_node,
+                   InputTensors& input_tensors);
+
+  std::vector<mir::IODescriptor>
+  convertDropout(const std::vector<mir::IODescriptor>& inputs,
+                 const onnx::NodeProto& onnx_node);
+
+  std::vector<mir::IODescriptor>
+  convertGather(const std::vector<mir::IODescriptor>& inputs,
+                const onnx::NodeProto& onnx_node);
+
+  std::vector<mir::IODescriptor>
+  convertGemm(const std::vector<mir::IODescriptor>& inputs,
+              const onnx::NodeProto& onnx_node);
+
+  mir::IODescriptor convertONNXToMIR(mir::IODescriptor arg);
+  mir::IODescriptor convertMIRToONNX(mir::IODescriptor arg);
 
 private:
   template <typename OpType, typename ...Types>
-  std::vector<nnc::mir::Operation*> createOp(Types&&... args);
+  nnc::mir::Operation* createOp(Types&&... args);
   mir::Graph* _graph = nullptr;
 };
 
 template<typename OpType, typename ...Types>
-std::vector<nnc::mir::Operation*> ONNXOpCreator::createOp(Types&&... args) {
+nnc::mir::Operation* ONNXOpCreator::createOp(Types&&... args) {
   // TODO: set operation names
-  auto op = _graph->create<OpType>("", std::forward<Types>(args)...);
-  return {op};
+  return _graph->create<OpType>("", std::forward<Types>(args)...);
 }
 } // namespace nnc
 #endif //NNCC_ONNX_OP_CREATOR_H
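
Since createOp() now returns a bare Operation* instead of a single-element
vector, call sites select outputs explicitly. A minimal usage sketch under
the new convention (mirroring convertRelu in the .cpp diff above):

    auto* op = createOp<ops::ReluOp>(inputs[0]);  // nnc::mir::Operation*
    return {op->getOutput(0)};                    // std::vector<mir::IODescriptor>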