Imported Upstream version 1.4.0
[platform/core/ml/nnfw.git] / compiler / mir / src / mir_onnx_importer / Op / Reshape.cpp
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "Reshape.h"
18
19 #include "ONNXHelpers.h"
20 #include "AttributeHelpers.h"
21
22 #include "mir/Tensor.h"
23 #include "mir/ShapeRange.h"
24
25 #include "mir/ops/ConstantOp.h"
26 #include "mir/ops/ReshapeOp.h"
27
28 namespace mir_onnx
29 {
30
31 void convertReshapeV1(const onnx::NodeProto &onnx_node, ConverterContext *context)
32 {
33   std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
34   mir::Graph *graph = context->getGraph();
35   // consumed_inputs attribute not used
36   const auto *shape_attr = findAttribute(onnx_node, "shape");
37   if (shape_attr && shape_attr->ints_size() > 0)
38   {
39     mir::Shape in_shape = inputs[0]->getShape();
40     mir::Shape out_shape(shape_attr->ints_size());
41     for (int32_t index = 0; index < out_shape.rank(); index++)
42     {
43       const auto dim_value = shape_attr->ints(index);
44       if (dim_value == 0)
45         out_shape.dim(index) = in_shape.dim(index);
46       else
47         out_shape.dim(index) = dim_value;
48     }
49
50     auto result = createOp<mir::ops::ReshapeOp>(graph, inputs[0], out_shape)->getOutput(0);
51
52     context->setNodeOutputs(onnx_node, {result});
53   }
54   else // dimension value is unchanged
55   {
56     context->setNodeOutputs(onnx_node, {inputs[0]});
57   }
58 }
59
60 void convertReshapeV5(const onnx::NodeProto &onnx_node, ConverterContext *context)
61 {
62   std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
63   mir::Graph *graph = context->getGraph();
64   // The original shape
65   const auto &in_shape = inputs[0]->getShape();
66
67   // Input tensor describing the new shape
68   auto *op = dynamic_cast<mir::ops::ConstantOp *>(inputs[1]->getNode());
69   assert(op && "We support only constant shape input");
70   auto shape_tensor = op->getValue();
71   mir::Shape shape_tensor_shape = (shape_tensor).getShape();
72   assert(shape_tensor_shape.rank() == 1);
73   // The rank of the new shape
74   auto cnt = shape_tensor_shape.numElements();
75   // The vector to build the new shape from
76   std::vector<int32_t> shape_vector(cnt);
77   mir::ShapeRange out_range(shape_tensor_shape);
78   mir::Tensor<int64_t> tensor_accessor(shape_tensor);
79
80   int i = 0;
81   for (auto idx : out_range)
82   {
83     if (tensor_accessor.at(idx) == 0)
84       shape_vector[i] = in_shape.dim(i);
85     else if (tensor_accessor.at(idx) == -1)
86       shape_vector[i] = mir::Shape::autoDim;
87     else
88       shape_vector[i] = tensor_accessor.at(idx);
89     i++;
90   }
91   auto out_shape = mir::Shape(shape_vector);
92   auto result = createOp<mir::ops::ReshapeOp>(graph, inputs[0], out_shape)->getOutput(0);
93
94   context->setNodeOutputs(onnx_node, {result});
95 }
96
97 } // namespace mir_onnx