Imported Upstream version 1.4.0
[platform/core/ml/nnfw.git] / compiler / mir / src / mir_onnx_importer / ONNXHelpers.cpp
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "ONNXHelpers.h"
18 #include "AttributeHelpers.h"
19
20 #include "MirInterpreter.h"
21 #include "mir/ops/ConstantOp.h"
22
23 #include "mir/ShapeRange.h"
24 #include "mir/Tensor.h"
25 #include "mir/TensorVariant.h"
26 #include "mir/Index.h"
27
28 namespace mir_onnx
29 {
30
31 const int64_t firstUnknownOpset = 13;
32
33 template <typename T> static mir::Shape constantToShapeT(const mir::TensorVariant &t)
34 {
35   const mir::Shape &t_shape = t.getShape();
36   mir::Tensor<T> input(t);
37   if (t_shape.rank() != 1)
38     throw std::runtime_error("only 1-d tensors supported as a shape input");
39
40   mir::Shape target_shape;
41   std::int32_t rank = t_shape.dim(0);
42   target_shape.resize(rank);
43   for (int i = 0; i < rank; ++i)
44     target_shape.dim(i) = static_cast<std::int32_t>(input.at(mir::Index{i}));
45   return target_shape;
46 }
47
48 mir::Shape constantToShape(const mir::ops::ConstantOp *op)
49 {
50   const auto &t = op->getValue();
51   mir::DataType d_type = t.getElementType();
52
53   if (t.getType().isQuantized())
54     throw std::runtime_error("unsupported data type of shape operator");
55
56   switch (d_type)
57   {
58     case mir::DataType::FLOAT32:
59       return constantToShapeT<float>(t);
60       break;
61     case mir::DataType::FLOAT64:
62       return constantToShapeT<double>(t);
63       break;
64     case mir::DataType::INT32:
65       return constantToShapeT<int32_t>(t);
66       break;
67     case mir::DataType::INT64:
68       return constantToShapeT<int64_t>(t);
69       break;
70     case mir::DataType::UINT8:
71       return constantToShapeT<uint8_t>(t);
72       break;
73     default:
74       throw std::runtime_error{"Unknown datatype in constant"};
75       break;
76   }
77 }
78
79 mir::DataType onnxDataTypeToMirDataType(onnx::TensorProto::DataType type)
80 {
81   switch (type)
82   {
83     case onnx::TensorProto_DataType_UINT8:
84       return mir::DataType::UINT8;
85       break;
86     case onnx::TensorProto_DataType_INT32:
87       return mir::DataType::INT32;
88       break;
89     case onnx::TensorProto_DataType_INT64:
90       return mir::DataType::INT64;
91       break;
92     case onnx::TensorProto_DataType_DOUBLE:
93       return mir::DataType::FLOAT64;
94       break;
95     case onnx::TensorProto_DataType_FLOAT:
96       return mir::DataType::FLOAT32;
97       break;
98     case onnx::TensorProto_DataType_UNDEFINED:
99       throw std::runtime_error{"Undefined input data type not supported"};
100       break;
101     default:
102       throw std::runtime_error{"Unsupported tensor element data type"};
103   }
104 }
105
106 mir::TensorVariant createTensor(const onnx::TensorProto *tensor)
107 {
108   mir::DataType type;
109   const void *src_data;
110   mir::Shape shape(tensor->dims_size());
111   for (int i = 0; i < tensor->dims_size(); ++i)
112   {
113     shape.dim(i) = tensor->dims(i);
114   }
115
116   if (tensor->float_data_size() != 0)
117   {
118     assert(tensor->data_type() == onnx::TensorProto::FLOAT);
119     type = mir::DataType::FLOAT32;
120     src_data = tensor->float_data().data();
121   }
122   else if (tensor->double_data_size() != 0)
123   {
124     assert(tensor->data_type() == onnx::TensorProto::DOUBLE);
125     type = mir::DataType::FLOAT64;
126     src_data = tensor->double_data().data();
127   }
128   else if (tensor->int32_data_size() != 0)
129   {
130     assert(tensor->data_type() == onnx::TensorProto::INT32);
131     type = mir::DataType::INT32;
132     src_data = tensor->int32_data().data();
133   }
134   else if (tensor->int64_data_size() != 0)
135   {
136     assert(tensor->data_type() == onnx::TensorProto::INT64);
137     type = mir::DataType::INT64;
138     src_data = tensor->int64_data().data();
139   }
140   else if (tensor->has_raw_data())
141   {
142     type = onnxDataTypeToMirDataType((onnx::TensorProto_DataType)tensor->data_type());
143     src_data = tensor->raw_data().data();
144   }
145   else
146   {
147     throw std::runtime_error("Invalid data in Proto file, investigate");
148   }
149
150   return mir::TensorVariant({type, shape}, src_data);
151 }
152
153 mir::Operation *foldConstants(mir::Graph *graph, mir::Operation *op)
154 {
155   if (op->getType() == mir::Operation::Type::constant ||
156       op->getType() == mir::Operation::Type::input || op->getType() == mir::Operation::Type::output)
157   {
158     // don't fold input, output and constant nodes
159     return op;
160   }
161
162   if (op->getNumOutputs() != 1)
163   {
164     // this operation either have more than 1 output or none at all
165     return op;
166   }
167
168   bool is_foldable =
169       std::all_of(op->getInputs().begin(), op->getInputs().end(), [](mir::Operation::Output *out) {
170         return out->getNode()->getType() == mir::Operation::Type::constant;
171       });
172
173   if (!is_foldable)
174     return op;
175
176   mir_interpreter::MIRInterpreter interpreter;
177   for (mir::Operation::Output *out : op->getInputs())
178   {
179     auto *constant = static_cast<mir::ops::ConstantOp *>(out->getNode());
180     interpreter.setTensor(out, constant->getValue());
181   }
182   op->accept(&interpreter);
183   const mir::TensorVariant &output = interpreter.getTensor(op->getOutput(0));
184
185   return graph->create<mir::ops::ConstantOp>(output);
186 }
187
188 } // namespace mir_onnx