2 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
9 * http://www.apache.org/licenses/LICENSE-2.0
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
18 #ifndef __BASE_LOADER_BASE_LOADER_H__
19 #define __BASE_LOADER_BASE_LOADER_H__
23 #include "ir/Operations.Include.h"
25 #include "flatbuffers/flexbuffers.h"
35 #include <util/logging.h>
// Generic flatbuffers model loader (TFLite/Circle) that builds onert IR.
// LoaderDomain supplies the schema-generated types of the concrete format.
template <typename LoaderDomain> class BaseLoader
  // Convenience aliases for the schema-generated (flatbuffers) types.
  using Verifier = typename LoaderDomain::Verifier;
  using ActivationFunctionType = typename LoaderDomain::ActivationFunctionType;
  using Buffer = typename LoaderDomain::Buffer;
  using BuiltinOperator = typename LoaderDomain::BuiltinOperator;
  using CustomOptionsFormat = typename LoaderDomain::CustomOptionsFormat;
  using Model = typename LoaderDomain::Model;
  using Operator = typename LoaderDomain::Operator;
  using Padding = typename LoaderDomain::Padding;
  using Pool2DOptions = typename LoaderDomain::Pool2DOptions;
  using SubGraph = typename LoaderDomain::SubGraph;
  using Tensor = typename LoaderDomain::Tensor;
  using TensorType = typename LoaderDomain::TensorType;
  using DimensionType = typename LoaderDomain::DimensionType;
  using SparseIndexVector = typename LoaderDomain::SparseIndexVector;

  // In the schema, index -1 marks an omitted (optional) operator input.
  bool isOptionalInputTensor(std::int32_t idx) { return idx == -1; }
  // Format-specific hook: whether the given operator may legally carry optional inputs.
  virtual bool allowOptionalInputTensor(BuiltinOperator) = 0;

  /**
   * @brief Construct a new Loader object
   *
   * @param graph reference on subgraphs
   */
  explicit BaseLoader(std::unique_ptr<ir::Subgraphs> &subgs)
    : _base{nullptr}, _pagesize(getpagesize()), _fd(-1), _subgraphs(subgs), _model{nullptr},
      _tensor_names(std::make_shared<std::unordered_map<ir::OperandIndex, std::string>>())
    // Whether constant data should be kept as MMapedData instead of copied (CachedData).
    _use_mmaped_data = util::getConfigBool(util::config::USE_MMAPED_DATA);

  /**
   * @brief Load a model from file
   */
  void loadFromFile(const std::string &file_path);
  /**
   * @brief Load a model from a buffer
   *
   * @param buffer buffer pointer
   * @param size buffer size
   */
  void loadFromBuffer(uint8_t *buffer, size_t size);

  ~BaseLoader() = default;

  // Conversion helpers: schema enums/indices -> onert IR equivalents.
  ir::Activation convertActivation(ActivationFunctionType type);
  ir::DataType tensorTypeToDataType(TensorType type);
  ir::OperandIndex tensorIdxToOperandIdx(int32_t tensorIdx);

  // Create operands from tflite::Tensor
  ir::OperandIndex loadOperand(const Tensor *tensor, ir::Graph &subg);
  void loadQuantization(const Tensor *tensor, ir::TypeInfo &typeInfo);
  void loadSparsity(const Tensor *tensor, ir::TypeInfo &typeInfo);
  void loadOperationIO(const Operator *op, ir::OperandIndexSequence &inputs,
                       ir::OperandIndexSequence &outputs);
  // Create operations from Operator
  void loadOperation(const Operator *op, ir::Graph &subg);
  // Load Strides and Paddings from options to param
  template <typename Param, typename OptionsType>
  void loadStridesAndPaddings(Param &param, const OptionsType *options);
  // Load filter size, strides, paddings and fused activation for pooling ops.
  template <typename Param> void loadPool2DOptions(Param &param, const Pool2DOptions *options);

  // Format-specific subgraph loading hook implemented by derived loaders.
  virtual std::unique_ptr<ir::Graph> loadSubgraph(const SubGraph *subg) = 0;

  // Common creation path: resolves IO indices, builds an OpIR node from `args`,
  // adds it to `subg`, and returns a non-owning pointer to the added node.
  template <typename OpIR, typename... Args>
  const OpIR *loadOperationTo(const Operator *op, ir::Graph &subg, Args &&... args);

  // One loader per supported operator; each reads its builtin/custom options.
  void loadAddV2(const Operator *op, ir::Graph &subg);
  void loadArgMinMax(const Operator *op, ir::Graph &subg, bool is_argmax);
  void loadBatchMatMul(const Operator *op, ir::Graph &subg);
  void loadBinaryArithmetic(const Operator *op, ir::Graph &subg,
                            ir::operation::BinaryArithmetic::ArithmeticType op_type);
  void loadComparison(const Operator *op, ir::Graph &subg);
  void loadConcatenation(const Operator *op, ir::Graph &subg);
  void loadConv2D(const Operator *op, ir::Graph &subg);
  void loadCustom(const Operator *op, ir::Graph &subg);
  void loadDepthToSpace(const Operator *op, ir::Graph &subg);
  void loadDepthwiseConv2D(const Operator *op, ir::Graph &subg);
  void loadEinsum(const Operator *op, ir::Graph &subg);
  void loadElementwiseActivation(const Operator *op, ir::Graph &subg,
                                 ir::operation::ElementwiseActivation::Type op_type,
                                 float alpha = 0.f, float beta = 0.f);
  void loadElementwiseBinary(const Operator *op, ir::Graph &subg,
                             ir::operation::ElementwiseBinary::ElementwiseBinaryType op_type);
  void loadElementwiseUnary(const Operator *op, ir::Graph &subg,
                            ir::operation::ElementwiseUnary::Type op_type);
  void loadFC(const Operator *op, ir::Graph &subg);
  void loadFusedBatchNorm(const Operator *op, ir::Graph &subg);
  void loadGather(const Operator *op, ir::Graph &subg);
  void loadIf(const Operator *op, ir::Graph &subg);
  void loadLeakyRelu(const Operator *op, ir::Graph &subg);
  void loadLogSoftmax(const Operator *op, ir::Graph &subg);
  void loadDetectionPostProcess(const Operator *op, ir::Graph &subg);
  void loadOneHot(const Operator *op, ir::Graph &subg);
  void loadPack(const Operator *op, ir::Graph &subg);
  void loadPool2D(const Operator *op, ir::Graph &subg, ir::operation::Pool2D::PoolType op_type);
  void loadReduce(const Operator *op, ir::Graph &subg,
                  ir::operation::Reduce::ReduceType reduce_type);
  void loadReduceAll(const Operator *op, ir::Graph &subg);
  void loadReshape(const Operator *op, ir::Graph &subg);
  void loadResizeBilinear(const Operator *op, ir::Graph &subg);
  void loadResizeNearestNeighbor(const Operator *op, ir::Graph &subg);
  void loadSoftmax(const Operator *op, ir::Graph &subg);
  void loadSpaceToDepth(const Operator *op, ir::Graph &subg);
  void loadSplit(const Operator *op, ir::Graph &subg);
  void loadSplitV(const Operator *op, ir::Graph &subg);
  void loadSqueeze(const Operator *op, ir::Graph &subg);
  void loadStridedSlice(const Operator *op, ir::Graph &subg);
  void loadTransposeConv(const Operator *op, ir::Graph &subg);
  void loadUnidirectionalSequenceLSTM(const Operator *op, ir::Graph &subg);
  void loadUnpack(const Operator *op, ir::Graph &subg);
  void loadWhile(const Operator *op, ir::Graph &subg);

  // Throws std::runtime_error when subg_index is outside [0, num_subgraphs).
  void verifySubgraphIndex(int subg_index)
    const auto num_subgraphs = _model->subgraphs()->size();
    if (subg_index < 0 || subg_index >= static_cast<int32_t>(num_subgraphs))
      throw std::runtime_error{std::string{"Invalid subgraph index - "} +
                               std::to_string(subg_index)};

  // Base address for mapped region for loading (if needed)
  // loaded file description
  // Reference on loadable subgraphs
  std::unique_ptr<ir::Subgraphs> &_subgraphs;
  // Maps Tensor indices to onert Operands.
  std::vector<ir::OperandIndex> _tensor_to_operand;
  std::shared_ptr<std::unordered_map<ir::OperandIndex, std::string>> _tensor_names;
  // Flatbuffers verifier created in loadFromFile/loadFromBuffer.
  std::unique_ptr<Verifier> _verifier;
  // Boolean flag to use MMAPED_DATA
  bool _use_mmaped_data = false;

  // Cache: flatbuffer buffer index -> shared Data, so tensors sharing one buffer
  // share one Data object instead of mapping/copying it twice.
  std::unordered_map<uint32_t /* Buffer Index in circle file */, std::shared_ptr<ir::Data>>
// Open, mmap and verify a model file. _fd is kept open so constant tensors can
// be re-mapped lazily in loadOperand.
template <typename LoaderDomain>
void BaseLoader<LoaderDomain>::BaseLoader::loadFromFile(const std::string &file_path)
  _fd = open(file_path.c_str(), O_RDONLY);
    throw std::runtime_error("Failed to open file " + file_path);

  // File size is needed both for mmap and for the flatbuffers verifier.
  struct stat file_stat;
  if (fstat(_fd, &file_stat) != 0)
    throw std::runtime_error("Fstat failed or file " + file_path + " is not a regular file");
  // NOTE(review): `int` truncates st_size for models >= 2GB — presumably model
  // files are small enough, but worth confirming (off_t would be safer).
  int size = file_stat.st_size;

  // Map model file into memory region
  _base = static_cast<uint8_t *>(mmap(NULL, size, PROT_READ, MAP_PRIVATE, _fd, 0));
  if (_base == MAP_FAILED)
    throw std::runtime_error("mmap failed - " + std::string(strerror(errno)));

  _verifier = std::make_unique<Verifier>(reinterpret_cast<const std::uint8_t *>(_base), size);
// Verify and load a model that already resides in memory (no file/mmap involved;
// loadOperand detects this via _fd == -1 and wraps data as ExternalData).
template <typename LoaderDomain>
void BaseLoader<LoaderDomain>::BaseLoader::loadFromBuffer(uint8_t *buffer, size_t size)
  // NOTE(review): the verifier reads from _base — presumably _base is assigned
  // from `buffer` just before this line; confirm against the full source.
  _verifier = std::make_unique<Verifier>(reinterpret_cast<const std::uint8_t *>(_base), size);
// Map a schema fused-activation enum onto onert's ir::Activation.
// Throws std::runtime_error for any value not listed below.
template <typename LoaderDomain>
BaseLoader<LoaderDomain>::BaseLoader::convertActivation(const ActivationFunctionType type)
    case ActivationFunctionType::ActivationFunctionType_NONE:
      return ir::Activation::NONE;
    case ActivationFunctionType::ActivationFunctionType_RELU:
      return ir::Activation::RELU;
    case ActivationFunctionType::ActivationFunctionType_RELU_N1_TO_1:
      return ir::Activation::RELU1;
    case ActivationFunctionType::ActivationFunctionType_RELU6:
      return ir::Activation::RELU6;
    case ActivationFunctionType::ActivationFunctionType_TANH:
      return ir::Activation::TANH;
      throw std::runtime_error(std::string("Unsupported or invalid activation type: ") +
                               std::to_string(static_cast<int>(type)));
// Map a schema TensorType onto onert's ir::DataType. Integer types map to their
// quantized counterparts (e.g. UINT8 -> QUANT_UINT8_ASYMM). Unhandled types throw.
template <typename LoaderDomain>
ir::DataType BaseLoader<LoaderDomain>::BaseLoader::tensorTypeToDataType(const TensorType type)
    case TensorType::TensorType_FLOAT32:
      return ir::DataType::FLOAT32;
    case TensorType::TensorType_FLOAT16:
      return ir::DataType::FLOAT16;
    case TensorType::TensorType_INT32:
      return ir::DataType::INT32;
    case TensorType::TensorType_UINT8:
      return ir::DataType::QUANT_UINT8_ASYMM;
    case TensorType::TensorType_INT64:
      return ir::DataType::INT64;
    // case TensorType::TensorType_STRING:
    case TensorType::TensorType_BOOL:
      return ir::DataType::BOOL8;
    case TensorType::TensorType_INT16:
      return ir::DataType::QUANT_INT16_ASYMM;
    // case TensorType::TensorType_COMPLEX64
    case TensorType::TensorType_INT8:
      return ir::DataType::QUANT_INT8_ASYMM;
    // case TensorType::TensorType_FLOAT64
      throw std::runtime_error(
        std::string("Unsupported tensor type: ").append(EnumNameTensorType(type)));
// Translate a schema tensor index into the onert operand index recorded by
// loadOperand. Index -1 (optional input) yields a default OperandIndex.
template <typename LoaderDomain>
ir::OperandIndex BaseLoader<LoaderDomain>::BaseLoader::tensorIdxToOperandIdx(int32_t tensorIdx)
  // NOTE(review): no bounds check against _tensor_to_operand.size(); relies on
  // the flatbuffers verifier having validated tensor indices.
  return isOptionalInputTensor(tensorIdx) ? ir::OperandIndex() : _tensor_to_operand[tensorIdx];
/* Copy() is adapted from TensorFlow Lite's sparse-index parsing code. */
/// @brief Copy the values of a flatbuffers integer vector into @p arr as uint16_t.
/// @param data_ptr flatbuffers vector wrapper exposing values()->size()/Get(i)
/// @param arr      destination; converted values are appended
/// @return false when the vector carries no values, true on success
template <typename T> bool Copy(const T *data_ptr, std::vector<uint16_t> &arr)
{
  if (data_ptr->values() == nullptr)
    return false; // nothing to copy — caller treats this as a parse failure

  const int size = data_ptr->values()->size();
  arr.reserve(arr.size() + size); // single allocation instead of repeated growth
  for (int i = 0; i < size; i++)
  {
    // Values wider than 16 bits are truncated by the cast (schema guarantees range).
    arr.emplace_back(static_cast<uint16_t>(data_ptr->values()->Get(i)));
  }
  return true;
}
// Build one onert operand from a schema Tensor: shape, type (incl. quantization
// and sparsity), and — for constant tensors — the backing data. Returns the new
// operand's index in `subg`.
template <typename LoaderDomain>
ir::OperandIndex BaseLoader<LoaderDomain>::loadOperand(const Tensor *tensor, ir::Graph &subg)
  const auto *tensor_shape = tensor->shape();
  if (tensor_shape != nullptr)
    for (const auto &dim : *tensor_shape)
  // Note for tensor->shape_signature()
  // We don't handle shape signature
  //    If shape_signature[k] == -1, we will use tensor->shape()[k] == 1
  //    If app wants to change the input shape, call nnfw_apply_input_tensorinfo() can
  // Type: base dtype plus quantization/sparsity annotations.
  ir::TypeInfo type_info(tensorTypeToDataType(tensor->type()));
  loadQuantization(tensor, type_info);
  loadSparsity(tensor, type_info);

  const auto operand_index = subg.addOperand(shape, type_info);

  // Constant tensors are indicated by non-empty data.
  const auto *data = _model->buffers()->Get(tensor->buffer())->data();
  using std::ptrdiff_t;
  std::shared_ptr<ir::Data> data_obj;

  if (_fd == -1) // Model is from memory
    // Memory-backed model: the buffer outlives the loader, so reference it directly.
    data_obj = std::make_shared<ir::ExternalData>(data->data(), data->size());
  else // Model is loaded(mmap'd) from a file
    size_t data_size = data->size();
    ptrdiff_t unaligned_offset_start = data->data() - _base;
    ptrdiff_t offset_end = unaligned_offset_start + data_size;

    // Calculated aligned offset from base address of mapped region
    // munmap accepts memory address which is a multiple of the pagesize
    ptrdiff_t aligned_offset_start = (unaligned_offset_start / _pagesize) * _pagesize;
    size_t mmap_size = offset_end - aligned_offset_start;

    uint32_t buf_idx = tensor->buffer();
    auto buffer_found = _buf_to_data.find(buf_idx);

    if (buffer_found != _buf_to_data.end())
      // Another tensor points this buffer and its matching Data(either CachedData or MMapedData)
      // was already created. Let's reuse the Data
      data_obj = buffer_found->second;
    else if (_use_mmaped_data)
      // Keep the region mapped for the operand's lifetime (no copy).
      data_obj = std::make_shared<ir::MMapedData>(_fd, aligned_offset_start, mmap_size,
                                                  unaligned_offset_start, data_size);
      _buf_to_data[buf_idx] = data_obj;
      // Default path: map temporarily, copy into CachedData, then unmap.
      size_t offset = unaligned_offset_start - aligned_offset_start;
      uint8_t *mmap_base = static_cast<uint8_t *>(
        mmap(NULL, mmap_size, PROT_READ, MAP_PRIVATE, _fd, aligned_offset_start));

      data_obj = std::make_shared<ir::CachedData>(mmap_base + offset, data_size);
      _buf_to_data[buf_idx] = data_obj;

      munmap(mmap_base, mmap_size);
  subg.setOperandValue(operand_index, std::move(data_obj));

  _tensor_names->emplace(operand_index, tensor->name()->str());

  // Variable tensor
  if (tensor->is_variable())
      // Variable tensors may not carry constant data.
      throw std::runtime_error("Variable tensor with buffer is not supported!");

    subg.operands().at(operand_index).info().setAsVariable();

  return operand_index;
// Transfer the schema's (possibly per-channel) quantization parameters into
// `typeInfo`. Missing/empty scale data is treated as "not quantized" (0., 0).
template <typename LoaderDomain>
void BaseLoader<LoaderDomain>::loadQuantization(const Tensor *tensor, ir::TypeInfo &typeInfo)
  auto q_params = tensor->quantization();
  if (q_params == nullptr || q_params->scale() == nullptr || q_params->scale()->size() == 0)
    typeInfo.quantization(0., 0);

  // Scales and zero-points must be paired 1:1.
  if (q_params->zero_point() == nullptr)
    throw std::runtime_error("Quantization params: scale is not null, but zero_point is null.");

  const size_t num_scales = q_params->scale()->size();
  if (num_scales != q_params->zero_point()->size())
    throw std::runtime_error("Quantization params: scale size != zero_point size");

  std::vector<float> scales;
  std::vector<int32_t> zero_points;
  scales.resize(num_scales);
  zero_points.resize(num_scales);
  for (size_t i = 0; i < num_scales; ++i)
    scales[i] = q_params->scale()->Get(i);
    // zero_point is defined as long (i64) in schema while TypeInfo's zero_point is int32_t.
    // int64_t is used instead of long because long is 4 byte in most 32bit architecture.
    int64_t zero_point = q_params->zero_point()->Get(i);
    if (zero_point < std::numeric_limits<int32_t>::min() ||
        zero_point > std::numeric_limits<int32_t>::max())
      throw std::runtime_error("Zero_point is out of int32 range.");
    zero_points[i] = static_cast<int32_t>(zero_point);
  auto details = q_params->details_as_CustomQuantization();
  if (details != nullptr)
    throw std::runtime_error("Custom Quantization is not supported");

  typeInfo.quantization(std::move(scales), std::move(zero_points));
445 template <typename LoaderDomain>
446 void BaseLoader<LoaderDomain>::loadSparsity(const Tensor *tensor, ir::TypeInfo &typeInfo)
448 auto src_sparsity = tensor->sparsity();
449 if (src_sparsity != nullptr)
451 std::vector<uint16_t> w1_segments;
452 std::vector<uint16_t> w1_indices;
453 // check traversal_order
454 if (src_sparsity->traversal_order())
456 const int traversal_order_size = src_sparsity->traversal_order()->size();
457 for (int i = 0; i < traversal_order_size; ++i)
459 if (i != src_sparsity->traversal_order()->Get(i))
460 throw std::runtime_error("traversal_order [0, 1, ..., n-1] is only supported.");
465 if (src_sparsity->block_map())
467 block_rank = src_sparsity->block_map()->size();
468 for (int i = 0; i < block_rank; ++i)
470 if (i != src_sparsity->block_map()->Get(i))
471 throw std::runtime_error("block_map [0, 1, ..., n-1] is only supported.");
475 const auto dim_metadata_size = src_sparsity->dim_metadata()->size();
476 const auto dense_rank = tensor->shape() ? tensor->shape()->size() : 0;
477 if (dense_rank + block_rank != dim_metadata_size)
478 throw std::runtime_error("sparsity dim_metadata length is wrong.");
479 bool random_sparsity = dim_metadata_size == 2 && block_rank == 0;
480 bool block2D_sparsity = dim_metadata_size == 4 && block_rank == 2;
481 if (dim_metadata_size != !random_sparsity && !block2D_sparsity)
482 throw std::runtime_error(
483 "sparsity is supported only for 2D tensor with random or 16x1 block sparsity.");
485 const auto *src_metadata = src_sparsity->dim_metadata()->Get(0);
486 if (src_metadata->format() != DimensionType::DimensionType_DENSE)
487 throw std::runtime_error("sparse tensor dim[0] is not DENSE");
488 src_metadata = src_sparsity->dim_metadata()->Get(1);
489 if (src_metadata->format() != DimensionType::DimensionType_SPARSE_CSR)
490 throw std::runtime_error("sparse tensor dim[0] is not SPARSE_CSR");
491 auto ParseSparseIndexVector = [src_metadata, &w1_segments, &w1_indices]() {
492 if (src_metadata->array_segments() == nullptr || src_metadata->array_indices() == nullptr)
495 switch (src_metadata->array_segments_type())
497 case SparseIndexVector::SparseIndexVector_Int32Vector:
498 status = Copy(src_metadata->array_segments_as_Int32Vector(), w1_segments);
500 case SparseIndexVector::SparseIndexVector_Uint16Vector:
501 status = Copy(src_metadata->array_segments_as_Uint16Vector(), w1_segments);
503 case SparseIndexVector::SparseIndexVector_Uint8Vector:
504 status = Copy(src_metadata->array_segments_as_Uint8Vector(), w1_segments);
511 switch (src_metadata->array_indices_type())
513 case SparseIndexVector::SparseIndexVector_Int32Vector:
514 return Copy(src_metadata->array_indices_as_Int32Vector(), w1_indices);
515 case SparseIndexVector::SparseIndexVector_Uint16Vector:
516 return Copy(src_metadata->array_indices_as_Uint16Vector(), w1_indices);
517 case SparseIndexVector::SparseIndexVector_Uint8Vector:
518 return Copy(src_metadata->array_indices_as_Uint8Vector(), w1_indices);
524 if (ParseSparseIndexVector() == false)
525 throw std::runtime_error("Error during parsing sparsity index information");
527 std::vector<int32_t> block_size;
528 for (int i = 0; i < block_rank; ++i)
530 auto block_metadata = src_sparsity->dim_metadata()->Get(dense_rank + i);
531 if (block_metadata->format() != DimensionType::DimensionType_DENSE)
532 throw std::runtime_error("block dimension must be DENSE.");
533 block_size.push_back(block_metadata->dense_size());
535 typeInfo.sparsity(std::make_shared<ir::Sparsity>(std::move(w1_segments), std::move(w1_indices),
536 std::move(block_size)));
// Resolve an operator's input/output tensor indices into operand index sequences.
// Optional (-1) inputs are rejected unless the derived loader allows them for
// this builtin operator.
template <typename LoaderDomain>
void BaseLoader<LoaderDomain>::loadOperationIO(const Operator *op, ir::OperandIndexSequence &inputs,
                                               ir::OperandIndexSequence &outputs)
  for (const std::int32_t idx : *op->inputs())
    // Optional tensors are not supported yet except for FULLY_CONNECTED and BCQ_FULLY_CONNECTED
    auto check_optional_input = [&]() {
      auto builtin_code = _model->operator_codes()->Get(op->opcode_index())->builtin_code();
      if (isOptionalInputTensor(idx) && !allowOptionalInputTensor(builtin_code))
        throw std::runtime_error(
          std::string("loader doesn't support optional input tensor yet for ")
            .append(EnumNameBuiltinOperator(builtin_code)));
    check_optional_input();
    inputs.append(tensorIdxToOperandIdx(idx));

  for (const std::int32_t idx : *op->outputs())
    outputs.append(tensorIdxToOperandIdx(idx));
// Copy stride (h/w) and padding type from builtin options into a convolution/
// pooling Param. Explicit padding values are not read here (see note below).
template <typename LoaderDomain>
template <typename Param, typename OptionsType>
void BaseLoader<LoaderDomain>::loadStridesAndPaddings(Param &param, const OptionsType *options)
  // Strides
  param.stride.vertical = options->stride_h();
  param.stride.horizontal = options->stride_w();
  // Paddings
  switch (options->padding())
    case Padding::Padding_SAME:
      param.padding.type = ir::PaddingType::SAME;
    case Padding::Padding_VALID:
      param.padding.type = ir::PaddingType::VALID;
      throw std::runtime_error{"Invalid padding type"};
  // param paddings indexes unused
// Fill a Pool2D Param from builtin Pool2DOptions: strides/padding, filter size,
// and fused activation. Non-positive strides or filter sizes are rejected.
template <typename LoaderDomain>
template <typename Param>
void BaseLoader<LoaderDomain>::loadPool2DOptions(Param &param, const Pool2DOptions *options)
  // Strides and Paddings
  if (options->stride_h() <= 0 || options->stride_w() <= 0)
    throw std::runtime_error{"Invalid stride vertical or horizontal - both must be bigger than 0"};
  loadStridesAndPaddings(param, options);
  // Filter width and height
  // Strides and Paddings
  if (options->filter_width() <= 0 || options->filter_height() <= 0)
    throw std::runtime_error{"Invalid filter width or height - both must be bigger than 0"};
  param.kw = options->filter_width();
  param.kh = options->filter_height();
  // Activation
  param.activation = convertActivation(options->fused_activation_function());
// Shared creation path for all operator loaders: resolve IO, construct the IR
// node (forwarding at most one Param), add it to the graph, and return a
// non-owning pointer so callers can post-process operands.
template <typename LoaderDomain>
template <typename OpIR, typename... Args>
const OpIR *BaseLoader<LoaderDomain>::loadOperationTo(const Operator *op, ir::Graph &subg,
  static_assert(sizeof...(args) <= 1, "You can't have more than 1 arguments!");
  ir::OperandIndexSequence inputs;
  ir::OperandIndexSequence outputs;

  loadOperationIO(op, inputs, outputs);

  std::unique_ptr<OpIR> new_op(new OpIR(inputs, outputs, std::forward<Args>(args)...));
  auto ret = new_op.get();
  subg.addOperation(std::move(new_op));
622 template <typename LoaderDomain>
623 void BaseLoader<LoaderDomain>::loadConv2D(const Operator *op, ir::Graph &subg)
625 ir::operation::Conv2D::Param param;
626 const auto *options = op->builtin_options_as_Conv2DOptions();
627 param.activation = convertActivation(options->fused_activation_function());
628 loadStridesAndPaddings(param, options);
629 param.dilation.width_factor = options->dilation_w_factor();
630 param.dilation.height_factor = options->dilation_h_factor();
632 loadOperationTo<ir::operation::Conv2D>(op, subg, param);
// Build a DepthwiseConv2D node from DepthwiseConv2DOptions (fused activation,
// strides/padding, depth multiplier, dilation factors).
template <typename LoaderDomain>
void BaseLoader<LoaderDomain>::loadDepthwiseConv2D(const Operator *op, ir::Graph &subg)
  ir::operation::DepthwiseConv2D::Param param;
  const auto *options = op->builtin_options_as_DepthwiseConv2DOptions();
  param.activation = convertActivation(options->fused_activation_function());
  loadStridesAndPaddings(param, options);
  param.multiplier = options->depth_multiplier();
  // Dilation factors (stale "unused" note removed — they are loaded below).
  param.dilation.width_factor = options->dilation_w_factor();
  param.dilation.height_factor = options->dilation_h_factor();

  loadOperationTo<ir::operation::DepthwiseConv2D>(op, subg, param);
650 template <typename LoaderDomain>
651 void BaseLoader<LoaderDomain>::loadTransposeConv(const Operator *op, ir::Graph &subg)
653 ir::operation::TransposeConv::Param param;
654 const auto *options = op->builtin_options_as_TransposeConvOptions();
655 loadStridesAndPaddings(param, options);
657 loadOperationTo<ir::operation::TransposeConv>(op, subg, param);
660 template <typename LoaderDomain>
661 void BaseLoader<LoaderDomain>::loadPool2D(const Operator *op, ir::Graph &subg,
662 ir::operation::Pool2D::PoolType op_type)
664 ir::operation::Pool2D::Param param;
665 param.op_type = op_type;
666 const auto *options = op->builtin_options_as_Pool2DOptions();
668 loadPool2DOptions(param, options);
670 loadOperationTo<ir::operation::Pool2D>(op, subg, param);
// Build a Reshape node. The target shape may come either from ReshapeOptions'
// new_shape attribute (copied into param) or from a second input tensor, in
// which case options may be absent and param.new_shape stays empty.
template <typename LoaderDomain>
void BaseLoader<LoaderDomain>::loadReshape(const Operator *op, ir::Graph &subg)
  ir::operation::Reshape::Param param{};
  const auto *options = op->builtin_options_as_ReshapeOptions();
  if (options != nullptr)
    const auto *new_shape = options->new_shape();
      for (uint i = 0; i < new_shape->size(); ++i)
        param.new_shape.push_back(new_shape->Get(i));

  loadOperationTo<ir::operation::Reshape>(op, subg, param);
693 template <typename LoaderDomain>
694 void BaseLoader<LoaderDomain>::loadSoftmax(const Operator *op, ir::Graph &subg)
696 ir::operation::Softmax::Param param;
697 const auto *options = op->builtin_options_as_SoftmaxOptions();
699 param.beta = options->beta();
701 loadOperationTo<ir::operation::Softmax>(op, subg, param);
704 template <typename LoaderDomain>
705 void BaseLoader<LoaderDomain>::loadConcatenation(const Operator *op, ir::Graph &subg)
707 ir::operation::Concat::Param param;
708 const auto *options = op->builtin_options_as_ConcatenationOptions();
710 param.axis = options->axis();
713 loadOperationTo<ir::operation::Concat>(op, subg, param);
// Build a FullyConnected node from FullyConnectedOptions (fused activation and
// weights format), then patch the weight operand's type for the hybrid case.
template <typename LoaderDomain>
void BaseLoader<LoaderDomain>::loadFC(const Operator *op, ir::Graph &subg)
  ir::operation::FullyConnected::Param param;
  const auto *options = op->builtin_options_as_FullyConnectedOptions();

  param.activation = convertActivation(options->fused_activation_function());
  param.weights_format = static_cast<ir::FullyConnectedWeightsFormat>(options->weights_format());

  const auto fc = loadOperationTo<ir::operation::FullyConnected>(op, subg, param);

  // Hybrid FC: float32 input with (u)int8 asymmetric-typed weights gets the
  // weight operand retyped to QUANT_INT8_SYMM — presumably because hybrid
  // kernels expect symmetric int8 weights; confirm against backend kernels.
  const auto &input_operand =
    subg.operands().at(fc->getInputs().at(ir::operation::FullyConnected::INPUT));
  auto &weights_operand =
    subg.operands().at(fc->getInputs().at(ir::operation::FullyConnected::WEIGHT));
  if (input_operand.typeInfo().type() == ir::DataType::FLOAT32 &&
      ((weights_operand.typeInfo().type() == ir::DataType::QUANT_UINT8_ASYMM) ||
       weights_operand.typeInfo().type() == ir::DataType::QUANT_INT8_ASYMM))
    weights_operand.type(ir::DataType::QUANT_INT8_SYMM);
// Build a BinaryArithmetic(ADD) node for the custom AddV2 operator. The fused
// activation, when present, is carried in flexbuffers custom options.
template <typename LoaderDomain>
void BaseLoader<LoaderDomain>::loadAddV2(const Operator *op, ir::Graph &subg)
  ir::operation::BinaryArithmetic::Param param;
  param.arithmetic_type = ir::operation::BinaryArithmetic::ArithmeticType::ADD;

  if (op->custom_options() == nullptr)
    // No custom options -> no fused activation.
    param.activation = ir::Activation::NONE;
    // Decode "fused_activation_function" from the flexbuffers-encoded custom options.
    size_t custom_op_data_size = op->custom_options()->size();
    auto custom_op_data = op->custom_options()->Data();
    auto data_root = flexbuffers::GetRoot(custom_op_data, custom_op_data_size);
    auto attr_map = data_root.AsMap();
    const auto fused_activation_func = static_cast<typename LoaderDomain::ActivationFunctionType>(
      attr_map["fused_activation_function"].AsInt8());
    param.activation = convertActivation(fused_activation_func);

  loadOperationTo<ir::operation::BinaryArithmetic>(op, subg, param);
763 template <typename LoaderDomain>
764 void BaseLoader<LoaderDomain>::loadDepthToSpace(const Operator *op, ir::Graph &subg)
766 ir::operation::DepthToSpace::Param param;
767 const auto *options = op->builtin_options_as_DepthToSpaceOptions();
768 param.block_size = options->block_size();
770 loadOperationTo<ir::operation::DepthToSpace>(op, subg, param);
// Build a BinaryArithmetic node (ADD/SUB/MUL/DIV). Each arithmetic kind reads
// its fused activation from the matching builtin options table.
template <typename LoaderDomain>
void BaseLoader<LoaderDomain>::loadBinaryArithmetic(
  const Operator *op, ir::Graph &subg, ir::operation::BinaryArithmetic::ArithmeticType op_type)
  ir::operation::BinaryArithmetic::Param param;
  param.arithmetic_type = op_type;
    case ir::operation::BinaryArithmetic::ArithmeticType::ADD:
      const auto *add_options = op->builtin_options_as_AddOptions();
      param.activation = convertActivation(add_options->fused_activation_function());
    case ir::operation::BinaryArithmetic::ArithmeticType::SUB:
      const auto *sub_options = op->builtin_options_as_SubOptions();
      param.activation = convertActivation(sub_options->fused_activation_function());
    case ir::operation::BinaryArithmetic::ArithmeticType::MUL:
      const auto *mul_options = op->builtin_options_as_MulOptions();
      param.activation = convertActivation(mul_options->fused_activation_function());
    case ir::operation::BinaryArithmetic::ArithmeticType::DIV:
      const auto *div_options = op->builtin_options_as_DivOptions();
      param.activation = convertActivation(div_options->fused_activation_function());
      "The function 'loadBinaryArithmetic' supports only BinaryArithmetic operations");

  loadOperationTo<ir::operation::BinaryArithmetic>(op, subg, param);
814 template <typename LoaderDomain>
815 void BaseLoader<LoaderDomain>::loadPack(const Operator *op, ir::Graph &subg)
817 ir::operation::Pack::Param param;
818 const auto *options = op->builtin_options_as_PackOptions();
819 param.num = options->values_count();
820 param.axis = options->axis();
822 loadOperationTo<ir::operation::Pack>(op, subg, param);
// Build an ElementwiseActivation node of the given kind. alpha/beta defaults
// come from the declaration; callers override them per activation (e.g. ReLU6).
template <typename LoaderDomain>
void BaseLoader<LoaderDomain>::loadElementwiseActivation(
  const Operator *op, ir::Graph &subg, ir::operation::ElementwiseActivation::Type op_type,
  float alpha, float beta)
  ir::operation::ElementwiseActivation::Param param;
  param.op_type = op_type;
  // NOTE(review): param.alpha/param.beta assignments are presumably set from the
  // alpha/beta arguments in omitted lines here — confirm in the full source.

  loadOperationTo<ir::operation::ElementwiseActivation>(op, subg, param);
838 template <typename LoaderDomain>
839 void BaseLoader<LoaderDomain>::loadResizeBilinear(const Operator *op, ir::Graph &subg)
841 ir::operation::ResizeBilinear::Param param;
842 param.align_corners = op->builtin_options_as_ResizeBilinearOptions()->align_corners();
843 param.half_pixel_centers = op->builtin_options_as_ResizeBilinearOptions()->half_pixel_centers();
845 loadOperationTo<ir::operation::ResizeBilinear>(op, subg, param);
848 template <typename LoaderDomain>
849 void BaseLoader<LoaderDomain>::loadResizeNearestNeighbor(const Operator *op, ir::Graph &subg)
851 ir::operation::ResizeNearestNeighbor::Param param;
852 param.align_corners = op->builtin_options_as_ResizeNearestNeighborOptions()->align_corners();
854 loadOperationTo<ir::operation::ResizeNearestNeighbor>(op, subg, param);
857 template <typename LoaderDomain>
858 void BaseLoader<LoaderDomain>::loadReduce(const Operator *op, ir::Graph &subg,
859 ir::operation::Reduce::ReduceType reduce_type)
861 ir::operation::Reduce::Param param;
862 param.reduce_type = reduce_type;
863 param.keep_dims = op->builtin_options_as_ReducerOptions()->keep_dims();
865 loadOperationTo<ir::operation::Reduce>(op, subg, param);
// Build a Reduce(ALL) node for the custom REDUCE_ALL operator. keep_dims is
// carried in flexbuffers custom options; absent options default it to false.
template <typename LoaderDomain>
void BaseLoader<LoaderDomain>::loadReduceAll(const Operator *op, ir::Graph &subg)
  ir::operation::Reduce::Param param;
  param.reduce_type = ir::operation::Reduce::ReduceType::ALL;
  if (op->custom_options() == nullptr)
    param.keep_dims = false;
    // Decode "keep_dims" from the flexbuffers-encoded custom options.
    size_t custom_op_data_size = op->custom_options()->size();
    auto custom_op_data = op->custom_options()->Data();
    auto data_root = flexbuffers::GetRoot(custom_op_data, custom_op_data_size);
    auto attr_map = data_root.AsMap();
    param.keep_dims = attr_map["keep_dims"].AsBool();

  loadOperationTo<ir::operation::Reduce>(op, subg, param);
889 template <typename LoaderDomain>
890 void BaseLoader<LoaderDomain>::loadElementwiseBinary(
891 const Operator *op, ir::Graph &subg,
892 ir::operation::ElementwiseBinary::ElementwiseBinaryType op_type)
894 ir::operation::ElementwiseBinary::Param param;
895 param.op_type = op_type;
897 loadOperationTo<ir::operation::ElementwiseBinary>(op, subg, param);
// Build an ElementwiseUnary node of the given kind. For CAST, operands typed
// QUANT_UINT8_ASYMM are retyped to plain UINT8 on both input and output.
template <typename LoaderDomain>
void BaseLoader<LoaderDomain>::loadElementwiseUnary(const Operator *op, ir::Graph &subg,
                                                    ir::operation::ElementwiseUnary::Type op_type)
  ir::operation::ElementwiseUnary::Param param;
  param.op_type = op_type;

  const auto eu = loadOperationTo<ir::operation::ElementwiseUnary>(op, subg, param);
  if (op_type == ir::operation::ElementwiseUnary::Type::CAST)
    // CAST treats quantized uint8 as raw uint8 (no dequantization semantics).
    auto qasymm8ToUint8 = [](ir::Operand &operand) {
      if (operand.typeInfo().type() == ir::DataType::QUANT_UINT8_ASYMM)
        operand.type(ir::DataType::UINT8);
      subg.operands().at(eu->getInputs().at(ir::operation::ElementwiseUnary::Input::INPUT)));
    qasymm8ToUint8(subg.operands().at(eu->getOutputs().at(0)));
922 template <typename LoaderDomain>
923 void BaseLoader<LoaderDomain>::loadGather(const Operator *op, ir::Graph &subg)
925 ir::operation::Gather::Param param;
926 param.axis = op->builtin_options_as_GatherOptions()->axis();
928 loadOperationTo<ir::operation::Gather>(op, subg, param);
// Build a DetectionPostProcess node for the TFLite custom SSD post-processing
// op. All attributes are decoded from flexbuffers custom options; the two
// optional ones (detections_per_class, use_regular_nms) have documented defaults.
template <typename LoaderDomain>
void BaseLoader<LoaderDomain>::loadDetectionPostProcess(const Operator *op, ir::Graph &subg)
  const flexbuffers::Map &m =
    flexbuffers::GetRoot(op->custom_options()->data(), op->custom_options()->size()).AsMap();

  ir::operation::DetectionPostProcess::Param param;

  param.max_detections = m["max_detections"].AsInt32();

  param.max_classes_per_detection = m["max_classes_per_detection"].AsInt32();
  // Optional attribute: default 100 boxes per class when absent.
  if (m["detections_per_class"].IsNull())
    param.max_boxes_per_class = 100;
    param.max_boxes_per_class = m["detections_per_class"].AsInt32();
  // Optional attribute: fast (non-regular) NMS is the default when absent.
  if (m["use_regular_nms"].IsNull())
    param.do_fast_eval = true;
    param.do_fast_eval = !m["use_regular_nms"].AsBool();

  param.score_threshold = m["nms_score_threshold"].AsFloat();
  param.iou_threshold = m["nms_iou_threshold"].AsFloat();

  // TODO add num classes support
  param.num_classes = m["num_classes"].AsInt32();

  // Box decoding scales (y, x, h, w).
  param.scale.y_scale = m["y_scale"].AsFloat();
  param.scale.x_scale = m["x_scale"].AsFloat();
  param.scale.h_scale = m["h_scale"].AsFloat();
  param.scale.w_scale = m["w_scale"].AsFloat();

  // TODO depends on input model framework
  param.center_size_boxes = true;

  loadOperationTo<ir::operation::DetectionPostProcess>(op, subg, param);
970 template <typename LoaderDomain>
971 void BaseLoader<LoaderDomain>::loadBatchMatMul(const Operator *op, ir::Graph &subg)
973 ir::operation::BatchMatMul::Param param;
975 const auto builtin_op = _model->operator_codes()->Get(op->opcode_index())->builtin_code();
979 case BuiltinOperator::BuiltinOperator_BATCH_MATMUL:
980 param.adj_x = op->builtin_options_as_BatchMatMulOptions()->adjoint_lhs();
981 param.adj_y = op->builtin_options_as_BatchMatMulOptions()->adjoint_rhs();
983 case BuiltinOperator::BuiltinOperator_CUSTOM:
984 if (op->custom_options() == nullptr)
991 size_t custom_op_data_size = op->custom_options()->size();
992 auto custom_op_data = op->custom_options()->Data();
993 auto data_root = flexbuffers::GetRoot(custom_op_data, custom_op_data_size);
994 auto attr_map = data_root.AsMap();
995 param.adj_x = attr_map["adj_x"].AsBool();
996 param.adj_y = attr_map["adj_y"].AsBool();
1000 throw std::runtime_error(
1001 std::string("Wrong loaded operation: ").append(EnumNameBuiltinOperator(builtin_op)) +
1002 " as " + EnumNameBuiltinOperator(BuiltinOperator::BuiltinOperator_BATCH_MATMUL));
1005 loadOperationTo<ir::operation::BatchMatMul>(op, subg, param);
1008 template <typename LoaderDomain>
1009 void BaseLoader<LoaderDomain>::loadSpaceToDepth(const Operator *op, ir::Graph &subg)
1011 ir::operation::SpaceToDepth::Param param;
1012 const auto *options = op->builtin_options_as_SpaceToDepthOptions();
1013 param.block_size = options->block_size();
1015 loadOperationTo<ir::operation::SpaceToDepth>(op, subg, param);
1018 template <typename LoaderDomain>
1019 void BaseLoader<LoaderDomain>::loadCustom(const Operator *op, ir::Graph &subg)
1021 ir::OperandIndexSequence inputs;
1022 ir::OperandIndexSequence outputs;
1024 assert(op->custom_options_format() == CustomOptionsFormat::CustomOptionsFormat_FLEXBUFFERS &&
1025 "Unsupported custom operation options format");
1027 auto *op_code = _model->operator_codes()->Get(op->opcode_index());
1028 auto custom_op_name = op_code->custom_code()->str();
1030 enum class BuiltinOP
1039 StatelessRandomUniform,
1041 DetectionPostProcess
1044 // Mapping from custom op name string to BuiltinOP enum
1045 std::map<std::string, BuiltinOP> builtin_map = {
1046 {"AddV2", BuiltinOP::AddV2},
1047 {"All", BuiltinOP::ReduceAll},
1048 {"MatrixBandPart", BuiltinOP::MatrixBandPart},
1049 {"BatchMatMulV2", BuiltinOP::BatchMatMul},
1050 {"Einsum", BuiltinOP::Einsum},
1051 {"FusedBatchNormV3", BuiltinOP::FusedBatchNorm},
1052 {"BroadcastTo", BuiltinOP::BroadcastTo},
1053 {"StatelessRandomUniform", BuiltinOP::StatelessRandomUniform},
1054 {"Erf", BuiltinOP::Erf},
1055 {"TFLite_Detection_PostProcess", BuiltinOP::DetectionPostProcess},
1060 // Throw out_of_range if it is unknown custom op
1061 auto custom_op_id = builtin_map.at(custom_op_name);
1062 switch (custom_op_id)
1064 case BuiltinOP::AddV2:
1065 loadAddV2(op, subg);
1067 case BuiltinOP::ReduceAll:
1068 loadReduceAll(op, subg);
1070 case BuiltinOP::MatrixBandPart:
1071 loadOperationTo<ir::operation::MatrixBandPart>(op, subg);
1073 case BuiltinOP::BatchMatMul:
1074 loadBatchMatMul(op, subg);
1076 case BuiltinOP::Einsum:
1077 loadEinsum(op, subg);
1079 case BuiltinOP::BroadcastTo:
1080 loadOperationTo<ir::operation::BroadcastTo>(op, subg);
1082 case BuiltinOP::FusedBatchNorm:
1083 loadFusedBatchNorm(op, subg);
1085 case BuiltinOP::StatelessRandomUniform:
1086 loadOperationTo<ir::operation::StatelessRandomUniform>(op, subg);
1088 case BuiltinOP::Erf:
1089 loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::ERF);
1091 case BuiltinOP::DetectionPostProcess:
1092 loadDetectionPostProcess(op, subg);
1095 throw std::runtime_error{
1096 "Loader: Custom OP map is defined but operation loader function is not defined"};
1103 loadOperationIO(op, inputs, outputs);
1105 auto constraint = ir::OperandConstraint::createExact(inputs.size());
1107 size_t custom_op_data_size = op->custom_options()->size();
1108 auto custom_op_data = new char[custom_op_data_size];
1109 std::copy(op->custom_options()->begin(), op->custom_options()->end(), custom_op_data);
1111 ir::operation::Custom::Userdata userdata{};
1112 userdata.data = custom_op_data;
1113 userdata.size = custom_op_data_size;
1115 auto new_op = std::make_unique<ir::operation::Custom>(constraint, inputs, outputs,
1116 custom_op_name, userdata);
1118 subg.addOperation(std::move(new_op));
1122 template <typename LoaderDomain>
1123 void BaseLoader<LoaderDomain>::loadSqueeze(const Operator *op, ir::Graph &subg)
1125 ir::operation::Squeeze::Param param;
1126 const auto *options = op->builtin_options_as_SqueezeOptions();
1127 const auto *dims = options->squeeze_dims();
1130 if (dims->size() > sizeof(param.dims) / sizeof(param.dims[0]))
1131 throw std::runtime_error("Squeeze: 'param.ndims' is out of range.");
1132 param.ndim = dims->size();
1133 for (int i = 0; i < param.ndim; ++i)
1134 param.dims[i] = dims->Get(i);
1137 loadOperationTo<ir::operation::Squeeze>(op, subg, param);
1140 template <typename LoaderDomain>
1141 void BaseLoader<LoaderDomain>::loadSplit(const Operator *op, ir::Graph &subg)
1143 ir::operation::Split::Param param;
1144 const auto *options = op->builtin_options_as_SplitOptions();
1145 param.num_splits = options->num_splits();
1147 loadOperationTo<ir::operation::Split>(op, subg, param);
1150 template <typename LoaderDomain>
1151 void BaseLoader<LoaderDomain>::loadSplitV(const Operator *op, ir::Graph &subg)
1153 ir::operation::SplitV::Param param;
1154 const auto *options = op->builtin_options_as_SplitVOptions();
1155 param.num_splits = options->num_splits();
1157 loadOperationTo<ir::operation::SplitV>(op, subg, param);
1160 template <typename LoaderDomain>
1161 void BaseLoader<LoaderDomain>::loadStridedSlice(const Operator *op, ir::Graph &subg)
1163 ir::operation::StridedSlice::Param param;
1164 const auto *options = op->builtin_options_as_StridedSliceOptions();
1165 param.begin_mask = options->begin_mask();
1166 param.end_mask = options->end_mask();
1167 param.shrink_axis_mask = options->shrink_axis_mask();
1169 loadOperationTo<ir::operation::StridedSlice>(op, subg, param);
1172 template <typename LoaderDomain>
1173 void BaseLoader<LoaderDomain>::loadUnpack(const Operator *op, ir::Graph &subg)
1175 ir::operation::Unpack::Param param;
1176 const auto *options = op->builtin_options_as_UnpackOptions();
1177 param.num = options->num();
1178 param.axis = options->axis();
1180 loadOperationTo<ir::operation::Unpack>(op, subg, param);
1183 template <typename LoaderDomain>
1184 void BaseLoader<LoaderDomain>::loadComparison(const Operator *op, ir::Graph &subg)
1186 ir::operation::Comparison::Param param;
1187 const auto builtin_op = _model->operator_codes()->Get(op->opcode_index())->builtin_code();
1191 case BuiltinOperator::BuiltinOperator_EQUAL:
1192 param.comparison_type = ir::operation::Comparison::ComparisonType::Equal;
1194 case BuiltinOperator::BuiltinOperator_NOT_EQUAL:
1195 param.comparison_type = ir::operation::Comparison::ComparisonType::NotEqual;
1197 case BuiltinOperator::BuiltinOperator_GREATER_EQUAL:
1198 param.comparison_type = ir::operation::Comparison::ComparisonType::GreaterEqual;
1200 case BuiltinOperator::BuiltinOperator_GREATER:
1201 param.comparison_type = ir::operation::Comparison::ComparisonType::Greater;
1203 case BuiltinOperator::BuiltinOperator_LESS_EQUAL:
1204 param.comparison_type = ir::operation::Comparison::ComparisonType::LessEqual;
1206 case BuiltinOperator::BuiltinOperator_LESS:
1207 param.comparison_type = ir::operation::Comparison::ComparisonType::Less;
1210 throw std::runtime_error(
1211 std::string("Unsupported operation: ").append(EnumNameBuiltinOperator(builtin_op)));
1214 loadOperationTo<ir::operation::Comparison>(op, subg, param);
1217 template <typename LoaderDomain>
1218 void BaseLoader<LoaderDomain>::loadEinsum(const Operator *op, ir::Graph &subg)
1220 ir::operation::Einsum::Param param;
1221 if (op->custom_options() == nullptr)
1223 throw std::runtime_error{"Einsum: empty equation"};
1227 size_t custom_op_data_size = op->custom_options()->size();
1228 auto custom_op_data = op->custom_options()->Data();
1229 auto data_root = flexbuffers::GetRoot(custom_op_data, custom_op_data_size);
1230 auto attr_map = data_root.AsMap();
1231 param.equation = attr_map["equation"].ToString();
1234 const auto es = loadOperationTo<ir::operation::Einsum>(op, subg, param);
1235 if (es->getInputs().size() != 2)
1237 throw std::runtime_error{"Einsum: NYI input - only support two inputs"};
1240 template <typename LoaderDomain>
1241 void BaseLoader<LoaderDomain>::loadFusedBatchNorm(const Operator *op, ir::Graph &subg)
1243 ir::operation::FusedBatchNorm::Param param;
1244 if (op->custom_options() == nullptr)
1246 throw std::runtime_error{"FusedBatchNorm: empty option"};
1250 size_t custom_op_data_size = op->custom_options()->size();
1251 auto custom_op_data = op->custom_options()->Data();
1252 auto data_root = flexbuffers::GetRoot(custom_op_data, custom_op_data_size);
1253 auto attr_map = data_root.AsMap();
1254 param.is_training = attr_map["is_training"].AsBool();
1255 param.epsilon = attr_map["epsilon"].AsFloat();
1256 param.data_format = attr_map["data_format"].ToString();
1259 const auto fbn = loadOperationTo<ir::operation::FusedBatchNorm>(op, subg, param);
1261 if (fbn->getInputs().size() != 5)
1263 throw std::runtime_error{"FusedBatchNorm: NYI input - only support five inputs"};
1267 template <typename LoaderDomain>
1268 void BaseLoader<LoaderDomain>::loadOneHot(const Operator *op, ir::Graph &subg)
1270 if (op->inputs()->size() != 4 || op->outputs()->size() != 1)
1271 throw std::runtime_error("OneHot Op has wrong number of input or output tensors.");
1274 ir::operation::OneHot::Param param;
1275 param.axis = op->builtin_options_as_OneHotOptions()->axis();
1277 loadOperationTo<ir::operation::OneHot>(op, subg, param);
1280 template <typename LoaderDomain>
1281 void BaseLoader<LoaderDomain>::loadIf(const Operator *op, ir::Graph &subg)
1283 const auto *options = op->builtin_options_as_IfOptions();
1284 const int32_t then_index = options->then_subgraph_index();
1285 const int32_t else_index = options->else_subgraph_index();
1287 verifySubgraphIndex(then_index);
1288 verifySubgraphIndex(else_index);
1290 ir::operation::If::Param param;
1291 param.then_subg_index = ir::SubgraphIndex{static_cast<uint32_t>(then_index)};
1292 param.else_subg_index = ir::SubgraphIndex{static_cast<uint32_t>(else_index)};
1294 loadOperationTo<ir::operation::If>(op, subg, param);
1297 template <typename LoaderDomain>
1298 void BaseLoader<LoaderDomain>::loadWhile(const Operator *op, ir::Graph &subg)
1300 const auto *options = op->builtin_options_as_WhileOptions();
1301 const int32_t cond_index = options->cond_subgraph_index();
1302 const int32_t body_index = options->body_subgraph_index();
1304 verifySubgraphIndex(cond_index);
1305 verifySubgraphIndex(body_index);
1307 ir::operation::While::Param param;
1308 param.cond_subg_index = ir::SubgraphIndex{static_cast<uint32_t>(cond_index)};
1309 param.body_subg_index = ir::SubgraphIndex{static_cast<uint32_t>(body_index)};
1311 loadOperationTo<ir::operation::While>(op, subg, param);
1314 template <typename LoaderDomain>
1315 void BaseLoader<LoaderDomain>::loadArgMinMax(const Operator *op, ir::Graph &subg, bool is_argmax)
1317 ir::operation::ArgMinMax::Param param;
1318 const auto output_type = is_argmax ? op->builtin_options_as_ArgMaxOptions()->output_type()
1319 : op->builtin_options_as_ArgMinOptions()->output_type();
1320 param.output_type = tensorTypeToDataType(output_type);
1321 param.is_arg_max = is_argmax;
1323 loadOperationTo<ir::operation::ArgMinMax>(op, subg, param);
1326 template <typename LoaderDomain>
1327 void BaseLoader<LoaderDomain>::loadLogSoftmax(const Operator *op, ir::Graph &subg)
1329 ir::operation::LogSoftmax::Param param;
1330 // In tflite, beta is fixed to 1.0 and axis is fixed to -1.
1334 loadOperationTo<ir::operation::LogSoftmax>(op, subg, param);
1337 template <typename LoaderDomain>
1338 void BaseLoader<LoaderDomain>::loadLeakyRelu(const Operator *op, ir::Graph &subg)
1340 float alpha = op->builtin_options_as_LeakyReluOptions()->alpha();
1341 loadElementwiseActivation(op, subg, ir::operation::ElementwiseActivation::Type::LEAKY_RELU, alpha,
1345 template <typename LoaderDomain>
1346 void BaseLoader<LoaderDomain>::loadUnidirectionalSequenceLSTM(const Operator *op, ir::Graph &subg)
1348 ir::operation::LSTM::Param param;
1349 const auto *options = op->builtin_options_as_UnidirectionalSequenceLSTMOptions();
1350 param.activation = convertActivation(options->fused_activation_function());
1351 param.cell_threshold = options->cell_clip();
1352 param.projection_threshold = options->proj_clip();
1353 param.time_major = options->time_major();
1354 // The asymmetric_quantize_inputs option is unused yet
1356 ir::OperandIndexSequence inputs;
1357 for (const std::int32_t idx : *op->inputs())
1359 inputs.append(tensorIdxToOperandIdx(idx));
1362 ir::OperandIndexSequence outputs;
1363 // loader doesn't support optional output tensor yet
1364 if (op->outputs()->size() != 1)
1366 auto builtin_code = _model->operator_codes()->Get(op->opcode_index())->builtin_code();
1367 throw std::runtime_error(std::string("loader doesn't support optional output tensor yet for ")
1368 .append(EnumNameBuiltinOperator(builtin_code)));
1370 for (size_t i = 0; i < ir::operation::LSTM::Output::OUTPUT; ++i)
1372 // Add optional outputs
1373 outputs.append(ir::OperandIndex());
1375 outputs.append(tensorIdxToOperandIdx(op->outputs()->Get(0)));
1377 std::unique_ptr<ir::operation::LSTM> new_op(new ir::operation::LSTM(inputs, outputs, param));
1378 subg.addOperation(std::move(new_op));
1381 template <typename LoaderDomain>
1382 void BaseLoader<LoaderDomain>::loadOperation(const Operator *op, ir::Graph &subg)
1384 const auto builtin_op = _model->operator_codes()->Get(op->opcode_index())->builtin_code();
1388 case BuiltinOperator::BuiltinOperator_ADD_N:
1389 loadOperationTo<ir::operation::AddN>(op, subg);
1391 case BuiltinOperator::BuiltinOperator_CONV_2D:
1392 loadConv2D(op, subg);
1394 case BuiltinOperator::BuiltinOperator_AVERAGE_POOL_2D:
1395 loadPool2D(op, subg, ir::operation::Pool2D::PoolType::AVG);
1397 case BuiltinOperator::BuiltinOperator_DEPTHWISE_CONV_2D:
1398 loadDepthwiseConv2D(op, subg);
1400 case BuiltinOperator::BuiltinOperator_TRANSPOSE_CONV:
1401 loadTransposeConv(op, subg);
1403 case BuiltinOperator::BuiltinOperator_RESHAPE:
1404 loadReshape(op, subg);
1406 case BuiltinOperator::BuiltinOperator_SOFTMAX:
1407 loadSoftmax(op, subg);
1409 case BuiltinOperator::BuiltinOperator_MAX_POOL_2D:
1410 loadPool2D(op, subg, ir::operation::Pool2D::PoolType::MAX);
1412 case BuiltinOperator::BuiltinOperator_CONCATENATION:
1413 loadConcatenation(op, subg);
1415 case BuiltinOperator::BuiltinOperator_FLOOR:
1416 loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::FLOOR);
1418 case BuiltinOperator::BuiltinOperator_FULLY_CONNECTED:
1421 case BuiltinOperator::BuiltinOperator_ADD:
1422 loadBinaryArithmetic(op, subg, ir::operation::BinaryArithmetic::ArithmeticType::ADD);
1424 case BuiltinOperator::BuiltinOperator_SUB:
1425 loadBinaryArithmetic(op, subg, ir::operation::BinaryArithmetic::ArithmeticType::SUB);
1427 case BuiltinOperator::BuiltinOperator_MUL:
1428 loadBinaryArithmetic(op, subg, ir::operation::BinaryArithmetic::ArithmeticType::MUL);
1430 case BuiltinOperator::BuiltinOperator_DIV:
1431 loadBinaryArithmetic(op, subg, ir::operation::BinaryArithmetic::ArithmeticType::DIV);
1433 case BuiltinOperator::BuiltinOperator_PACK:
1436 case BuiltinOperator::BuiltinOperator_ELU:
1437 loadElementwiseActivation(op, subg, ir::operation::ElementwiseActivation::Type::ELU);
1439 case BuiltinOperator::BuiltinOperator_RELU:
1440 loadElementwiseActivation(op, subg, ir::operation::ElementwiseActivation::Type::RELU,
1441 ir::operation::ElementwiseActivation::infinity, 0.f);
1443 case BuiltinOperator::BuiltinOperator_RELU_N1_TO_1:
1444 loadElementwiseActivation(op, subg, ir::operation::ElementwiseActivation::Type::RELU, 1.f,
1447 case BuiltinOperator::BuiltinOperator_RELU6:
1448 loadElementwiseActivation(op, subg, ir::operation::ElementwiseActivation::Type::RELU, 6.f,
1451 case BuiltinOperator::BuiltinOperator_RESIZE_BILINEAR:
1452 loadResizeBilinear(op, subg);
1454 case BuiltinOperator::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR:
1455 loadResizeNearestNeighbor(op, subg);
1457 case BuiltinOperator::BuiltinOperator_RSQRT:
1458 loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::RSQRT);
1460 case BuiltinOperator::BuiltinOperator_SELECT:
1461 case BuiltinOperator::BuiltinOperator_SELECT_V2:
1462 loadOperationTo<ir::operation::Select>(op, subg);
1464 case BuiltinOperator::BuiltinOperator_SQRT:
1465 loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::SQRT);
1467 case BuiltinOperator::BuiltinOperator_SQUARE:
1468 loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::SQUARE);
1470 case BuiltinOperator::BuiltinOperator_SQUARED_DIFFERENCE:
1471 loadOperationTo<ir::operation::SquaredDifference>(op, subg);
1473 case BuiltinOperator::BuiltinOperator_TANH:
1474 loadElementwiseActivation(op, subg, ir::operation::ElementwiseActivation::Type::TANH, 1.f,
1477 case BuiltinOperator::BuiltinOperator_TRANSPOSE:
1478 loadOperationTo<ir::operation::Transpose>(op, subg);
1480 case BuiltinOperator::BuiltinOperator_MEAN:
1481 loadReduce(op, subg, ir::operation::Reduce::ReduceType::MEAN);
1483 case BuiltinOperator::BuiltinOperator_REDUCE_ANY:
1484 loadReduce(op, subg, ir::operation::Reduce::ReduceType::ANY);
1486 case BuiltinOperator::BuiltinOperator_REDUCE_MAX:
1487 loadReduce(op, subg, ir::operation::Reduce::ReduceType::MAX);
1489 case BuiltinOperator::BuiltinOperator_REVERSE_V2:
1490 loadOperationTo<ir::operation::Reverse>(op, subg);
1492 case BuiltinOperator::BuiltinOperator_PAD:
1493 case BuiltinOperator::BuiltinOperator_PADV2:
1494 loadOperationTo<ir::operation::Pad>(op, subg);
1496 case BuiltinOperator::BuiltinOperator_LOGISTIC:
1497 loadElementwiseActivation(op, subg, ir::operation::ElementwiseActivation::Type::LOGISTIC);
1499 case BuiltinOperator::BuiltinOperator_EXP:
1500 loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::EXP);
1502 case BuiltinOperator::BuiltinOperator_EXPAND_DIMS:
1503 loadOperationTo<ir::operation::ExpandDims>(op, subg);
1505 case BuiltinOperator::BuiltinOperator_GATHER:
1506 loadGather(op, subg);
1508 case BuiltinOperator::BuiltinOperator_SPACE_TO_BATCH_ND:
1509 loadOperationTo<ir::operation::SpaceToBatchND>(op, subg);
1511 case BuiltinOperator::BuiltinOperator_BATCH_TO_SPACE_ND:
1512 loadOperationTo<ir::operation::BatchToSpaceND>(op, subg);
1514 case BuiltinOperator::BuiltinOperator_SUM:
1515 loadReduce(op, subg, ir::operation::Reduce::ReduceType::SUM);
1517 case BuiltinOperator::BuiltinOperator_CUSTOM:
1518 loadCustom(op, subg);
1520 case BuiltinOperator::BuiltinOperator_SQUEEZE:
1521 loadSqueeze(op, subg);
1523 case BuiltinOperator::BuiltinOperator_PRELU:
1524 loadOperationTo<ir::operation::PReLU>(op, subg);
1526 case BuiltinOperator::BuiltinOperator_SPLIT:
1527 loadSplit(op, subg);
1529 case BuiltinOperator::BuiltinOperator_SPLIT_V:
1530 loadSplitV(op, subg);
1532 case BuiltinOperator::BuiltinOperator_SLICE:
1533 loadOperationTo<ir::operation::Slice>(op, subg);
1535 case BuiltinOperator::BuiltinOperator_STRIDED_SLICE:
1536 loadStridedSlice(op, subg);
1538 case BuiltinOperator::BuiltinOperator_UNPACK:
1539 loadUnpack(op, subg);
1541 case BuiltinOperator::BuiltinOperator_FLOOR_DIV:
1542 loadElementwiseBinary(op, subg,
1543 ir::operation::ElementwiseBinary::ElementwiseBinaryType::FLOOR_DIV);
1545 case BuiltinOperator::BuiltinOperator_MINIMUM:
1546 loadElementwiseBinary(op, subg, ir::operation::ElementwiseBinary::ElementwiseBinaryType::MIN);
1548 case BuiltinOperator::BuiltinOperator_MAXIMUM:
1549 loadElementwiseBinary(op, subg, ir::operation::ElementwiseBinary::ElementwiseBinaryType::MAX);
1551 case BuiltinOperator::BuiltinOperator_CAST:
1552 loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::CAST);
1554 case BuiltinOperator::BuiltinOperator_EQUAL:
1555 case BuiltinOperator::BuiltinOperator_NOT_EQUAL:
1556 case BuiltinOperator::BuiltinOperator_GREATER_EQUAL:
1557 case BuiltinOperator::BuiltinOperator_GREATER:
1558 case BuiltinOperator::BuiltinOperator_LESS_EQUAL:
1559 case BuiltinOperator::BuiltinOperator_LESS:
1560 loadComparison(op, subg);
1562 case BuiltinOperator::BuiltinOperator_ONE_HOT:
1563 loadOneHot(op, subg);
1565 case BuiltinOperator::BuiltinOperator_ABS:
1566 loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::ABS);
1568 case BuiltinOperator::BuiltinOperator_COS:
1569 loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::COS);
1571 case BuiltinOperator::BuiltinOperator_SIN:
1572 loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::SIN);
1574 case BuiltinOperator::BuiltinOperator_SHAPE:
1575 loadOperationTo<ir::operation::Shape>(op, subg);
1577 case BuiltinOperator::BuiltinOperator_REDUCE_PROD:
1578 loadReduce(op, subg, ir::operation::Reduce::ReduceType::PROD);
1580 case BuiltinOperator::BuiltinOperator_IF:
1583 case BuiltinOperator::BuiltinOperator_WHILE:
1584 loadWhile(op, subg);
1586 case BuiltinOperator::BuiltinOperator_NEG:
1587 loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::NEG);
1589 case BuiltinOperator::BuiltinOperator_ARG_MAX:
1590 loadArgMinMax(op, subg, true);
1592 case BuiltinOperator::BuiltinOperator_ARG_MIN:
1593 loadArgMinMax(op, subg, false);
1595 case BuiltinOperator::BuiltinOperator_LOG:
1596 loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::LOG);
1598 case BuiltinOperator::BuiltinOperator_ROUND:
1599 loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::ROUND);
1601 case BuiltinOperator::BuiltinOperator_POW:
1602 loadOperationTo<ir::operation::Pow>(op, subg);
1604 case BuiltinOperator::BuiltinOperator_LOGICAL_NOT:
1605 loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::LOGICAL_NOT);
1607 case BuiltinOperator::BuiltinOperator_LOGICAL_AND:
1608 loadElementwiseBinary(op, subg,
1609 ir::operation::ElementwiseBinary::ElementwiseBinaryType::LOGICAL_AND);
1611 case BuiltinOperator::BuiltinOperator_LOGICAL_OR:
1612 loadElementwiseBinary(op, subg,
1613 ir::operation::ElementwiseBinary::ElementwiseBinaryType::LOGICAL_OR);
1615 case BuiltinOperator::BuiltinOperator_FILL:
1616 loadOperationTo<ir::operation::Fill>(op, subg);
1618 case BuiltinOperator::BuiltinOperator_ZEROS_LIKE:
1619 loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::ZEROS_LIKE);
1621 case BuiltinOperator::BuiltinOperator_TILE:
1622 loadOperationTo<ir::operation::Tile>(op, subg);
1624 case BuiltinOperator::BuiltinOperator_RANGE:
1625 loadOperationTo<ir::operation::Range>(op, subg);
1627 case BuiltinOperator::BuiltinOperator_BATCH_MATMUL:
1628 loadBatchMatMul(op, subg);
1630 case BuiltinOperator::BuiltinOperator_LOG_SOFTMAX:
1631 loadLogSoftmax(op, subg);
1633 case BuiltinOperator::BuiltinOperator_QUANTIZE:
1634 loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::QUANTIZE);
1636 case BuiltinOperator::BuiltinOperator_DEQUANTIZE:
1637 loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::DEQUANTIZE);
1639 case BuiltinOperator::BuiltinOperator_SPACE_TO_DEPTH:
1640 loadSpaceToDepth(op, subg);
1642 case BuiltinOperator::BuiltinOperator_L2_NORMALIZATION:
1643 loadOperationTo<ir::operation::L2Normalization>(op, subg);
1645 case BuiltinOperator::BuiltinOperator_LEAKY_RELU:
1646 loadLeakyRelu(op, subg);
1648 case BuiltinOperator::BuiltinOperator_RANK:
1649 loadOperationTo<ir::operation::Rank>(op, subg);
1651 case BuiltinOperator::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
1652 loadUnidirectionalSequenceLSTM(op, subg);
1654 case BuiltinOperator::BuiltinOperator_DEPTH_TO_SPACE:
1655 loadDepthToSpace(op, subg);
1658 throw std::runtime_error(
1659 std::string("Unsupported operation: ").append(EnumNameBuiltinOperator(builtin_op)));
1663 template <typename LoaderDomain> void BaseLoader<LoaderDomain>::loadModel()
1665 LoaderDomain::VerifyModelBuffer(*_verifier.get());
1666 _model = LoaderDomain::GetModel(_base);
1668 // const auto version = _model->version();
1669 // Description unused
1670 // const auto *description = _model->description();
1671 // Metabuffer unsued
1672 // const auto *metadata_buffer = _model->metadata_buffer();
1673 // Load subgraphs and map operations on subgraph
1674 const auto domain_subgraphs = _model->subgraphs();
1675 auto subgraphs = std::make_unique<ir::Subgraphs>();
1676 for (uint32_t subgraph_index = 0; subgraph_index < domain_subgraphs->size(); ++subgraph_index)
1678 auto subg = loadSubgraph((*_model->subgraphs())[subgraph_index]);
1679 subgraphs->push(ir::SubgraphIndex{subgraph_index}, std::move(subg));
1681 _subgraphs = std::move(subgraphs);
1684 } // namespace base_loader
1685 } // namespace onert
1687 #endif //__BASE_LOADER_BASE_LOADER_H__