[nnc] Added Unsqueeze to ONNX (#2729)
author Andrei Shedko/AI Tools Lab /SRR/Engineer/Samsung Electronics <a.shedko@samsung.com>
Wed, 19 Dec 2018 18:32:09 +0000 (21:32 +0300)
committer Roman Mikhailovich Rusyaev/AI Tools Lab /SRR/Staff Engineer/Samsung Electronics <r.rusyaev@samsung.com>
Wed, 19 Dec 2018 18:32:09 +0000 (21:32 +0300)
Added support for the ONNX Unsqueeze operation by lowering it to Reshape.
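
Unsqueeze is lowered to a ReshapeOp: the output shape is built by inserting a
dimension of size 1 at each index listed in the operation's "axes" attribute
and copying the input dimensions into the remaining positions in order. For
example, an input of shape (10, 2, 10) with axes = [0, 3, 5] yields
(1, 10, 2, 1, 10, 1), which is the case covered by the added shape-inference
test.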

Signed-off-by: Andrei Shedko <a.shedko@samsung.com>
contrib/nnc/passes/onnx_frontend/ONNXImporterImpl.cpp
contrib/nnc/passes/onnx_frontend/ONNXOpCreator.cpp
contrib/nnc/passes/onnx_frontend/ONNXOpCreator.h
contrib/nnc/unittests/core/ShapeInference.cpp

diff --git a/contrib/nnc/passes/onnx_frontend/ONNXImporterImpl.cpp b/contrib/nnc/passes/onnx_frontend/ONNXImporterImpl.cpp
index 79c8e8a..d3e9e64 100644
@@ -62,6 +62,7 @@ static void collectUnsupportedOps(std::unique_ptr<onnx::ModelProto>& model) {
       case ONNXOpCode::opMul:
       case ONNXOpCode::opRelu:
       case ONNXOpCode::opReshape:
+      case ONNXOpCode::opUnsqueeze:
       case ONNXOpCode::opSigmoid:
       case ONNXOpCode::opScale:
       case ONNXOpCode::opSoftmax:
@@ -287,6 +288,9 @@ mir::Graph *ONNXImporterImpl::createIR() {
       case ONNXOpCode::opRelu:
         outputs = _opCreator.convertRelu(input_nodes);
         break;
+      case ONNXOpCode::opUnsqueeze:
+        outputs = _opCreator.convertUnsqueeze(input_nodes[0], onnx_node);
+        break;
       case ONNXOpCode::opSigmoid:
         outputs = _opCreator.convertSigmoid(input_nodes);
         break;
diff --git a/contrib/nnc/passes/onnx_frontend/ONNXOpCreator.cpp b/contrib/nnc/passes/onnx_frontend/ONNXOpCreator.cpp
index aebe6cd..987ff29 100644
@@ -228,6 +228,28 @@ std::vector<Operation*> ONNXOpCreator::convertReshape(Operation* inputData, Shap
   return outputs;
 }
 
+std::vector<mir::Operation*>
+ONNXOpCreator::convertUnsqueeze(Operation* input_data, const onnx::NodeProto& onnx_node) {
+  auto* axes = findAttribute(onnx_node, "axes");
+  assert(axes && axes->ints_size());
+  const int out_rank = input_data->getOutputShape(0).rank() + axes->ints_size();
+  Shape out_shape(out_rank);
+  const Shape& input_shape = input_data->getOutputShape(0);
+  auto ints_iterator = axes->ints().begin();
+  int j = 0;
+  for (int i = 0; i < out_rank; i++) {
+    if (ints_iterator < axes->ints().end() && i == *ints_iterator) {
+      out_shape.dim(i) = 1;
+      ints_iterator++;
+    } else {
+      out_shape.dim(i) = input_shape.dim(j);
+      j++;
+    }
+  }
+  auto outputs = createOp<ops::ReshapeOp>(input_data->getOutput(0), out_shape);
+  return outputs;
+}
+
 std::vector<Operation*> ONNXOpCreator::convertRelu(InputOps& inputs) {
   assert(inputs.size() == 1);
   return createOp<ops::ReluOp>(inputs[0]->getOutput(0));
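
For reference, below is a self-contained sketch of the same axis-insertion
logic using plain std::vector and a hypothetical helper name (it is not part
of the patch), assuming the ONNX convention that "axes" is sorted in ascending
order and contains no duplicates:

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Builds the Unsqueeze output shape: a 1 is inserted at every index listed
// in axes, and the input dimensions fill the remaining positions in order.
std::vector<std::int64_t> unsqueezeShape(const std::vector<std::int64_t>& input_shape,
                                         const std::vector<std::int64_t>& axes) {
  const std::size_t out_rank = input_shape.size() + axes.size();
  std::vector<std::int64_t> out_shape(out_rank);
  auto axes_it = axes.begin();
  std::size_t j = 0;  // index of the next input dimension to copy
  for (std::size_t i = 0; i < out_rank; ++i) {
    if (axes_it != axes.end() && static_cast<std::int64_t>(i) == *axes_it) {
      out_shape[i] = 1;                 // dimension inserted by Unsqueeze
      ++axes_it;
    } else {
      out_shape[i] = input_shape[j++];  // dimension carried over from the input
    }
  }
  assert(j == input_shape.size());
  return out_shape;
}

For example, unsqueezeShape({10, 2, 10}, {0, 3, 5}) returns {1, 10, 2, 1, 10, 1},
which matches the ReshapeAutoDimensionUnsqueeze shape-inference test added below.
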
diff --git a/contrib/nnc/passes/onnx_frontend/ONNXOpCreator.h b/contrib/nnc/passes/onnx_frontend/ONNXOpCreator.h
index 4dc4650..479492a 100644
@@ -46,6 +46,9 @@ public:
   std::vector<mir::Operation*> convertReshape(mir::Operation* input_data, mir::Shape output_shape);
   std::vector<mir::Operation*> convertRelu(InputOps& inputs);
   std::vector<mir::Operation*> convertSigmoid(InputOps& inputs);
+
+  std::vector<mir::Operation*>
+  convertUnsqueeze(mir::Operation* input_data, const onnx::NodeProto& onnx_node);
   std::vector<mir::Operation*> convertElementwise(InputOps& inputs,
                                                  mir::ops::ElementwiseOp::OpType op_type);
   std::vector<mir::Operation*> convertScale(InputOps& inputs, const onnx::NodeProto& node);
diff --git a/contrib/nnc/unittests/core/ShapeInference.cpp b/contrib/nnc/unittests/core/ShapeInference.cpp
index ad9f10c..8905bc8 100644
@@ -107,6 +107,19 @@ TEST(ShapeInferenceTest, ReshapeAutoDimensionExpand) {
   ASSERT_EQ(result_shape_expand, op->getOutputShape(0));
 }
 
+TEST(ShapeInferenceTest, ReshapeAutoDimensionUnsqueeze) {
+  Graph g;
+
+  Shape input_shape{10, 2, 10};
+  Shape result_shape_expand{1, 10, 2, 1, 10, 1};
+
+  auto input = g.create<ops::VariableOp>("input", input_shape);
+  auto op = g.create<ops::ReshapeOp>("reshape", input->getOutput(0),
+                                     Shape{1, Shape::autoDim, 2, 1, 10, 1});
+
+  ASSERT_EQ(result_shape_expand, op->getOutputShape(0));
+}
+
 TEST(ShapeInferenceTest, SqueezeTestAllDims) {
   Graph g;