2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "ONNXHelpers.h"
18 #include "AttributeHelpers.h"
20 #include "MirInterpreter.h"
21 #include "mir/ops/ConstantOp.h"
23 #include "mir/ShapeRange.h"
24 #include "mir/Tensor.h"
25 #include "mir/TensorVariant.h"
26 #include "mir/Index.h"
// First ONNX opset version that this importer does not support; models
// declaring an opset >= this value are rejected.
const int64_t firstUnknownOpset = 13;
33 template <typename T> static mir::Shape constantToShapeT(const mir::TensorVariant &t)
35 const mir::Shape &t_shape = t.getShape();
36 mir::Tensor<T> input(t);
37 if (t_shape.rank() != 1)
38 throw std::runtime_error("only 1-d tensors supported as a shape input");
40 mir::Shape target_shape;
41 std::int32_t rank = t_shape.dim(0);
42 target_shape.resize(rank);
43 for (int i = 0; i < rank; ++i)
44 target_shape.dim(i) = static_cast<std::int32_t>(input.at(mir::Index{i}));
48 mir::Shape constantToShape(const mir::ops::ConstantOp *op)
50 const auto &t = op->getValue();
51 mir::DataType d_type = t.getElementType();
53 if (t.getType().isQuantized())
54 throw std::runtime_error("unsupported data type of shape operator");
58 case mir::DataType::FLOAT32:
59 return constantToShapeT<float>(t);
61 case mir::DataType::FLOAT64:
62 return constantToShapeT<double>(t);
64 case mir::DataType::INT32:
65 return constantToShapeT<int32_t>(t);
67 case mir::DataType::INT64:
68 return constantToShapeT<int64_t>(t);
70 case mir::DataType::UINT8:
71 return constantToShapeT<uint8_t>(t);
74 throw std::runtime_error{"Unknown datatype in constant"};
79 mir::DataType onnxDataTypeToMirDataType(onnx::TensorProto::DataType type)
83 case onnx::TensorProto_DataType_UINT8:
84 return mir::DataType::UINT8;
86 case onnx::TensorProto_DataType_INT32:
87 return mir::DataType::INT32;
89 case onnx::TensorProto_DataType_INT64:
90 return mir::DataType::INT64;
92 case onnx::TensorProto_DataType_DOUBLE:
93 return mir::DataType::FLOAT64;
95 case onnx::TensorProto_DataType_FLOAT:
96 return mir::DataType::FLOAT32;
98 case onnx::TensorProto_DataType_UNDEFINED:
99 throw std::runtime_error{"Undefined input data type not supported"};
102 throw std::runtime_error{"Unsupported tensor element data type"};
106 mir::TensorVariant createTensor(const onnx::TensorProto *tensor)
109 const void *src_data;
110 mir::Shape shape(tensor->dims_size());
111 for (int i = 0; i < tensor->dims_size(); ++i)
113 shape.dim(i) = tensor->dims(i);
116 if (tensor->float_data_size() != 0)
118 assert(tensor->data_type() == onnx::TensorProto::FLOAT);
119 type = mir::DataType::FLOAT32;
120 src_data = tensor->float_data().data();
122 else if (tensor->double_data_size() != 0)
124 assert(tensor->data_type() == onnx::TensorProto::DOUBLE);
125 type = mir::DataType::FLOAT64;
126 src_data = tensor->double_data().data();
128 else if (tensor->int32_data_size() != 0)
130 assert(tensor->data_type() == onnx::TensorProto::INT32);
131 type = mir::DataType::INT32;
132 src_data = tensor->int32_data().data();
134 else if (tensor->int64_data_size() != 0)
136 assert(tensor->data_type() == onnx::TensorProto::INT64);
137 type = mir::DataType::INT64;
138 src_data = tensor->int64_data().data();
140 else if (tensor->has_raw_data())
142 type = onnxDataTypeToMirDataType((onnx::TensorProto_DataType)tensor->data_type());
143 src_data = tensor->raw_data().data();
147 throw std::runtime_error("Invalid data in Proto file, investigate");
150 return mir::TensorVariant({type, shape}, src_data);
153 mir::Operation *foldConstants(mir::Graph *graph, mir::Operation *op)
155 if (op->getType() == mir::Operation::Type::constant ||
156 op->getType() == mir::Operation::Type::input || op->getType() == mir::Operation::Type::output)
158 // don't fold input, output and constant nodes
162 if (op->getNumOutputs() != 1)
164 // this operation either have more than 1 output or none at all
169 std::all_of(op->getInputs().begin(), op->getInputs().end(), [](mir::Operation::Output *out) {
170 return out->getNode()->getType() == mir::Operation::Type::constant;
176 mir_interpreter::MIRInterpreter interpreter;
177 for (mir::Operation::Output *out : op->getInputs())
179 auto *constant = static_cast<mir::ops::ConstantOp *>(out->getNode());
180 interpreter.setTensor(out, constant->getValue());
182 op->accept(&interpreter);
183 const mir::TensorVariant &output = interpreter.getTensor(op->getOutput(0));
185 return graph->create<mir::ops::ConstantOp>(output);
188 } // namespace mir_onnx