2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
#include "ONNXHelpers.h"
#include "AttributeHelpers.h"

#include "mir/Tensor.h"
#include "mir/ShapeRange.h"

#include "mir/ops/ConstantOp.h"
#include "mir/ops/ReshapeOp.h"

#include <stdexcept>
#include <string>
31 void convertReshapeV1(const onnx::NodeProto &onnx_node, ConverterContext *context)
33 std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
34 mir::Graph *graph = context->getGraph();
35 // consumed_inputs attribute not used
36 const auto *shape_attr = findAttribute(onnx_node, "shape");
37 if (shape_attr && shape_attr->ints_size() > 0)
39 mir::Shape in_shape = inputs[0]->getShape();
40 mir::Shape out_shape(shape_attr->ints_size());
41 for (int32_t index = 0; index < out_shape.rank(); index++)
43 const auto dim_value = shape_attr->ints(index);
45 out_shape.dim(index) = in_shape.dim(index);
47 out_shape.dim(index) = dim_value;
50 auto result = createOp<mir::ops::ReshapeOp>(graph, inputs[0], out_shape)->getOutput(0);
52 context->setNodeOutputs(onnx_node, {result});
54 else // dimension value is unchanged
56 context->setNodeOutputs(onnx_node, {inputs[0]});
60 void convertReshapeV5(const onnx::NodeProto &onnx_node, ConverterContext *context)
62 std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
63 mir::Graph *graph = context->getGraph();
65 const auto &in_shape = inputs[0]->getShape();
67 // Input tensor describing the new shape
68 auto *op = dynamic_cast<mir::ops::ConstantOp *>(inputs[1]->getNode());
69 assert(op && "We support only constant shape input");
70 auto shape_tensor = op->getValue();
71 mir::Shape shape_tensor_shape = (shape_tensor).getShape();
72 assert(shape_tensor_shape.rank() == 1);
73 // The rank of the new shape
74 auto cnt = shape_tensor_shape.numElements();
75 // The vector to build the new shape from
76 std::vector<int32_t> shape_vector(cnt);
77 mir::ShapeRange out_range(shape_tensor_shape);
78 mir::Tensor<int64_t> tensor_accessor(shape_tensor);
81 for (auto idx : out_range)
83 if (tensor_accessor.at(idx) == 0)
84 shape_vector[i] = in_shape.dim(i);
85 else if (tensor_accessor.at(idx) == -1)
86 shape_vector[i] = mir::Shape::autoDim;
88 shape_vector[i] = tensor_accessor.at(idx);
91 auto out_shape = mir::Shape(shape_vector);
92 auto result = createOp<mir::ops::ReshapeOp>(graph, inputs[0], out_shape)->getOutput(0);
94 context->setNodeOutputs(onnx_node, {result});
97 } // namespace mir_onnx