2 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
9 * http://www.apache.org/licenses/LICENSE-2.0
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
18 #ifndef __BASE_LOADER_BASE_LOADER_H__
19 #define __BASE_LOADER_BASE_LOADER_H__
23 #include "ir/Operations.Include.h"
25 #include "flatbuffers/flexbuffers.h"
35 #include <util/logging.h>
// BaseLoader holds model-loading logic shared by concrete loaders. LoaderDomain supplies the
// schema types aliased below; derived loaders implement the pure virtual hooks
// allowOptionalInputTensor() and loadSubgraph().
42 template <typename LoaderDomain> class BaseLoader
45 using Verifier = typename LoaderDomain::Verifier;
46 using ActivationFunctionType = typename LoaderDomain::ActivationFunctionType;
47 using Buffer = typename LoaderDomain::Buffer;
48 using BuiltinOperator = typename LoaderDomain::BuiltinOperator;
49 using CustomOptionsFormat = typename LoaderDomain::CustomOptionsFormat;
50 using Model = typename LoaderDomain::Model;
51 using Operator = typename LoaderDomain::Operator;
52 using Padding = typename LoaderDomain::Padding;
53 using Pool2DOptions = typename LoaderDomain::Pool2DOptions;
54 using SubGraph = typename LoaderDomain::SubGraph;
55 using Tensor = typename LoaderDomain::Tensor;
56 using TensorType = typename LoaderDomain::TensorType;
57 using DimensionType = typename LoaderDomain::DimensionType;
58 using SparseIndexVector = typename LoaderDomain::SparseIndexVector;
// A tensor index of -1 marks an omitted (optional) operator input.
61 bool isOptionalInputTensor(std::int32_t idx) { return idx == -1; }
// Derived loaders decide which builtin operators may take omitted (-1) inputs.
62 virtual bool allowOptionalInputTensor(BuiltinOperator) = 0;
66 * @brief Construct a new Loader object
68 * @param graph reference on subgraphs
// NOTE(review): the @param above says "graph", but the constructor parameter is named subgs.
// The USE_MMAPED_DATA config value selects how loadOperand keeps constant data of mmap'd files.
70 explicit BaseLoader(std::unique_ptr<ir::Subgraphs> &subgs)
71 : _base{nullptr}, _pagesize(getpagesize()), _fd(-1), _subgraphs(subgs), _model{nullptr}
73 _use_mmaped_data = util::getConfigBool(util::config::USE_MMAPED_DATA);
77 * @brief Load a model from file
81 void loadFromFile(const char *file_path);
83 * @brief Load a model from a buffer
85 * @param buffer buffer pointer
86 * @param size buffer size
88 void loadFromBuffer(uint8_t *buffer, size_t size);
// NOTE(review): this destructor is not virtual although the class has pure virtual members;
// deleting a derived loader through a BaseLoader pointer would be undefined behavior. Confirm
// loaders are never destroyed polymorphically, or make the destructor virtual or protected.
91 ~BaseLoader() = default;
95 ir::Activation convertActivation(ActivationFunctionType type);
96 ir::DataType tensorTypeToDataType(TensorType type);
97 ir::OperandIndex tensorIdxToOperandIdx(int32_t tensorIdx);
99 // Create operands from tflite::Tensor
100 ir::OperandIndex loadOperand(const Tensor *tensor, ir::Graph &subg);
101 void loadSparsity(const Tensor *tensor, const ir::Shape &shape, ir::TypeInfo &typeInfo);
102 void loadOperationIO(const Operator *op, ir::OperandIndexSequence &inputs,
103 ir::OperandIndexSequence &outputs);
104 // Create operations from Operator
105 void loadOperation(const Operator *op, ir::Graph &subg);
106 // Load Strides and Paddings from options to param
// NOTE(review): "¶m" in the next two declarations looks like a mis-decoded "&param"
// (HTML entity "&para;"); both should read "Param &param".
107 template <typename Param, typename OptionsType>
108 void loadStridesAndPaddings(Param ¶m, const OptionsType *options);
110 template <typename Param> void loadPool2DOptions(Param ¶m, const Pool2DOptions *options);
113 virtual std::unique_ptr<ir::Graph> loadSubgraph(const SubGraph *subg) = 0;
// Loads op's inputs/outputs, constructs OpIR with at most one extra argument (typically a Param),
// adds it to subg and returns a non-owning pointer to the new operation.
115 template <typename OpIR, typename... Args>
116 const OpIR *loadOperationTo(const Operator *op, ir::Graph &subg, Args &&... args);
117 void loadConv2D(const Operator *op, ir::Graph &subg);
118 void loadDepthwiseConv2D(const Operator *op, ir::Graph &subg);
119 void loadTransposeConv(const Operator *op, ir::Graph &subg);
120 void loadPool2D(const Operator *op, ir::Graph &subg, ir::operation::Pool2D::PoolType op_type);
121 void loadReshape(const Operator *op, ir::Graph &subg);
122 void loadSoftmax(const Operator *op, ir::Graph &subg);
123 void loadConcatenation(const Operator *op, ir::Graph &subg);
124 void loadFC(const Operator *op, ir::Graph &subg);
125 void loadBinaryArithmetic(const Operator *op, ir::Graph &subg,
126 ir::operation::BinaryArithmetic::ArithmeticType op_type);
127 void loadAddV2(const Operator *op, ir::Graph &subg);
128 void loadPack(const Operator *op, ir::Graph &subg);
129 void loadResizeBilinear(const Operator *op, ir::Graph &subg);
130 void loadResizeNearestNeighbor(const Operator *op, ir::Graph &subg);
131 void loadReduce(const Operator *op, ir::Graph &subg,
132 ir::operation::Reduce::ReduceType reduce_type);
133 void loadReduceAll(const Operator *op, ir::Graph &subg);
134 void loadElementwiseActivation(const Operator *op, ir::Graph &subg,
135 ir::operation::ElementwiseActivation::Type op_type,
136 float alpha = 0.f, float beta = 0.f);
137 void loadElementwiseBinary(const Operator *op, ir::Graph &subg,
138 ir::operation::ElementwiseBinary::ElementwiseBinaryType op_type);
139 void loadElementwiseUnary(const Operator *op, ir::Graph &subg,
140 ir::operation::ElementwiseUnary::Type op_type);
141 void loadGather(const Operator *op, ir::Graph &subg);
142 void loadCustom(const Operator *op, ir::Graph &subg);
143 void loadBatchMatMul(const Operator *op, ir::Graph &subg);
144 void loadSqueeze(const Operator *op, ir::Graph &subg);
145 void loadSplit(const Operator *op, ir::Graph &subg);
146 void loadSplitV(const Operator *op, ir::Graph &subg);
147 void loadStridedSlice(const Operator *op, ir::Graph &subg);
148 void loadUnpack(const Operator *op, ir::Graph &subg);
149 void loadComparison(const Operator *op, ir::Graph &subg);
150 void loadEinsum(const Operator *op, ir::Graph &subg);
151 void loadOneHot(const Operator *op, ir::Graph &subg);
152 void loadIf(const Operator *op, ir::Graph &subg);
153 void loadWhile(const Operator *op, ir::Graph &subg);
154 void loadArgMax(const Operator *op, ir::Graph &subg);
155 void loadFusedBatchNorm(const Operator *op, ir::Graph &subg);
156 void loadLogSoftmax(const Operator *op, ir::Graph &subg);
157 void loadSpaceToDepth(const Operator *op, ir::Graph &subg);
158 void loadLeakyRelu(const Operator *op, ir::Graph &subg);
159 void loadUnidirectionalSequenceLSTM(const Operator *op, ir::Graph &subg);
// Throws std::runtime_error unless 0 <= subg_index < number of subgraphs in _model.
161 void verifySubgraphIndex(int subg_index)
163 const auto num_subgraphs = _model->subgraphs()->size();
164 if (subg_index < 0 || subg_index >= static_cast<int32_t>(num_subgraphs))
165 throw std::runtime_error{std::string{"Invalid subgraph index - "} +
166 std::to_string(subg_index)};
170 // Base address for mapped region for loading (if needed)
174 // loaded file description
176 // Reference on loadable subgraphs
// Held by reference: the caller's Subgraphs pointer must outlive this loader.
177 std::unique_ptr<ir::Subgraphs> &_subgraphs;
179 // Maps Tensor indices to onert Operands.
180 std::vector<ir::OperandIndex> _tensor_to_operand;
// Tensor names keyed by the operand index they were loaded into (filled in loadOperand).
181 std::unordered_map<ir::OperandIndex, std::string> _tensor_names;
// Created over the loaded model bytes by loadFromFile() / loadFromBuffer().
183 std::unique_ptr<Verifier> _verifier;
184 // Boolean flag to use MMAPED_DATA
185 bool _use_mmaped_data = false;
188 template <typename LoaderDomain>
// Opens file_path read-only, maps the whole file with mmap(PROT_READ, MAP_PRIVATE) and creates a
// Verifier over the mapped bytes. Throws std::runtime_error when opening, fstat or mmap fails.
189 void BaseLoader<LoaderDomain>::BaseLoader::loadFromFile(const char *file_path)
191 _fd = open(file_path, O_RDONLY);
194 throw std::runtime_error("Failed to open file " + std::string(file_path));
197 struct stat file_stat;
// NOTE(review): as written, _fd does not appear to be closed before this throw (descriptor
// leak), and the "not a regular file" case in the message is never tested (no S_ISREG check).
198 if (fstat(_fd, &file_stat) != 0)
200 throw std::runtime_error("Fstat failed or file " + std::string(file_path) +
201 " is not a regular file");
// NOTE(review): st_size is an off_t; storing it in int truncates models of 2 GiB or more.
203 int size = file_stat.st_size;
205 // Map model file into memory region
206 _base = static_cast<uint8_t *>(mmap(NULL, size, PROT_READ, MAP_PRIVATE, _fd, 0));
207 if (_base == MAP_FAILED)
210 throw std::runtime_error("mmap failed - " + std::string(strerror(errno)));
// The verifier is presumably used to validate the mapped bytes before parsing — confirm.
213 _verifier = std::make_unique<Verifier>(reinterpret_cast<const std::uint8_t *>(_base), size);
221 template <typename LoaderDomain>
// Loads a model from an in-memory buffer of `size` bytes. _fd is not touched here (it keeps the
// -1 set by the constructor), which loadOperand treats as "model is from memory": constant data
// is wrapped as ExternalData pointing into the model bytes.
// NOTE(review): because ExternalData refers to the bytes in place, the buffer presumably must
// outlive the loaded graph — confirm against the public API contract.
222 void BaseLoader<LoaderDomain>::BaseLoader::loadFromBuffer(uint8_t *buffer, size_t size)
// NOTE(review): the verifier is built from _base rather than buffer; this assumes _base was set
// to buffer beforehand — confirm.
225 _verifier = std::make_unique<Verifier>(reinterpret_cast<const std::uint8_t *>(_base), size);
229 template <typename LoaderDomain>
231 BaseLoader<LoaderDomain>::BaseLoader::convertActivation(const ActivationFunctionType type)
235 case ActivationFunctionType::ActivationFunctionType_NONE:
236 return ir::Activation::NONE;
237 case ActivationFunctionType::ActivationFunctionType_RELU:
238 return ir::Activation::RELU;
239 case ActivationFunctionType::ActivationFunctionType_RELU_N1_TO_1:
240 return ir::Activation::RELU1;
241 case ActivationFunctionType::ActivationFunctionType_RELU6:
242 return ir::Activation::RELU6;
243 case ActivationFunctionType::ActivationFunctionType_TANH:
244 return ir::Activation::TANH;
246 throw std::runtime_error(std::string("Unsupported or invalid activation type: ") +
247 std::to_string(static_cast<int>(type)));
251 template <typename LoaderDomain>
252 ir::DataType BaseLoader<LoaderDomain>::BaseLoader::tensorTypeToDataType(const TensorType type)
256 case TensorType::TensorType_FLOAT32:
257 return ir::DataType::FLOAT32;
258 case TensorType::TensorType_INT32:
259 return ir::DataType::INT32;
260 case TensorType::TensorType_BOOL:
261 return ir::DataType::BOOL8;
262 case TensorType::TensorType_UINT8:
263 return ir::DataType::QUANT_UINT8_ASYMM;
264 case TensorType::TensorType_INT8:
265 return ir::DataType::QUANT_INT8_ASYMM;
266 case TensorType::TensorType_INT64:
267 return ir::DataType::INT64;
269 throw std::runtime_error(
270 std::string("Unsupported tensor type: ").append(EnumNameTensorType(type)));
274 template <typename LoaderDomain>
275 ir::OperandIndex BaseLoader<LoaderDomain>::BaseLoader::tensorIdxToOperandIdx(int32_t tensorIdx)
277 return isOptionalInputTensor(tensorIdx) ? ir::OperandIndex() : _tensor_to_operand[tensorIdx];
/* Copy is copied from tensorflow lite */
// Appends data_ptr->values() to arr, converting each element to uint16_t.
// Returns false if data_ptr has no values vector, or if any element does not fit into uint16_t.
// The TFLite original targets int, so narrowing an Int32Vector here could otherwise silently
// corrupt sparse segment/index values.
template <typename T> bool Copy(const T *data_ptr, std::vector<uint16_t> &arr)
{
  const auto *values = data_ptr->values();
  if (values == nullptr)
    return false;

  arr.reserve(arr.size() + values->size());
  for (decltype(values->size()) i = 0; i < values->size(); ++i)
  {
    const auto value = values->Get(i);
    const auto narrowed = static_cast<uint16_t>(value);
    // A value survives the round trip only if it is representable as uint16_t.
    if (static_cast<decltype(value)>(narrowed) != value)
      return false;
    arr.emplace_back(narrowed);
  }
  return true;
}
297 template <typename LoaderDomain>
// Creates one operand in subg for a schema Tensor: shape, data type, per-tensor quantization
// (a single scale / zero_point), optional sparsity and, for constant tensors, the backing data.
// Also records the tensor name and flags variable tensors. Returns the new operand's index.
// Throws std::runtime_error for per-channel or custom quantization and for variable tensors that
// carry a buffer.
298 ir::OperandIndex BaseLoader<LoaderDomain>::loadOperand(const Tensor *tensor, ir::Graph &subg)
302 const auto *tensor_shape = tensor->shape();
303 if (tensor_shape != nullptr)
305 for (const auto &dim : *tensor_shape)
311 // Note for tensor->shape_signature()
312 // We don't handle shape signature
314 // If shape_signature[k] == -1, we will use tensor->shape()[k] == 1
315 // If app wants to change the input shape, call nnfw_apply_input_tensorinfo() can
319 ir::DataType data_type = tensorTypeToDataType(tensor->type());
321 auto q_params = tensor->quantization();
324 if (q_params != nullptr)
326 if (q_params->scale())
328 if (q_params->scale()->size() != 1)
330 throw std::runtime_error("Only 1 scale for a tensor is supported.");
332 scale = q_params->scale()->Get(0);
335 if (q_params->zero_point())
337 if (q_params->zero_point()->size() != 1)
339 throw std::runtime_error("Only 1 zero_point value for a tensor is supported.");
341 zero_point = q_params->zero_point()->Get(0);
342 // zero_point is long while TypeInfo.zero_point is defined as int32_t.
343 assert(zero_point >= std::numeric_limits<int32_t>::min());
344 assert(zero_point <= std::numeric_limits<int32_t>::max());
// NOTE(review): assert() is compiled out under NDEBUG, so in release builds an out-of-range
// zero_point would be silently narrowed when building TypeInfo — consider a runtime check.
346 auto details = q_params->details_as_CustomQuantization();
347 if (details != nullptr)
348 throw std::runtime_error("Custom Quantization is not supported");
351 ir::TypeInfo type_info(data_type, scale, zero_point);
353 loadSparsity(tensor, shape, type_info);
// Register the operand first; constant data, if any, is attached to it below.
356 const auto operand_index = subg.addOperand(shape, type_info);
358 // Constant tensors are indicated by non-empty data.
359 const auto *data = _model->buffers()->Get(tensor->buffer())->data();
362 using std::ptrdiff_t;
363 std::unique_ptr<ir::Data> data_obj;
364 if (_fd == -1) // Model is from memory
366 data_obj = std::make_unique<ir::ExternalData>(data->data(), data->size());
368 else // Model is loaded(mmap'd) from a file
370 size_t data_size = data->size();
371 ptrdiff_t unaligned_offset_start = data->data() - _base;
372 ptrdiff_t offset_end = unaligned_offset_start + data_size;
374 // Calculated aligned offset from base address of mapped region
375 // munmap accepts memory address which is a multiple of the pagesize
// Round the start offset down to a multiple of _pagesize; mmap_size then spans from that
// boundary to the end of this tensor's data.
376 ptrdiff_t aligned_offset_start = (unaligned_offset_start / _pagesize) * _pagesize;
377 size_t mmap_size = offset_end - aligned_offset_start;
379 if (_use_mmaped_data)
381 data_obj = std::make_unique<ir::MMapedData>(_fd, aligned_offset_start, mmap_size,
382 unaligned_offset_start, data_size);
// Otherwise map the page-aligned window, copy the tensor bytes into CachedData, and unmap.
386 size_t offset = unaligned_offset_start - aligned_offset_start;
// NOTE(review): the mmap result is not compared against MAP_FAILED before CachedData reads
// from mmap_base + offset.
387 uint8_t *mmap_base = static_cast<uint8_t *>(
388 mmap(NULL, mmap_size, PROT_READ, MAP_PRIVATE, _fd, aligned_offset_start));
389 data_obj = std::make_unique<ir::CachedData>(mmap_base + offset, data_size);
390 munmap(mmap_base, mmap_size);
393 subg.setOperandValue(operand_index, std::move(data_obj));
// NOTE(review): tensor->name() is dereferenced without a null check here — confirm every tensor
// is guaranteed to have a name.
396 _tensor_names.emplace(operand_index, tensor->name()->str());
// Variable tensors must not carry constant data; they are marked via setAsVariable().
399 if (tensor->is_variable())
402 throw std::runtime_error("Variable tensor with buffer is not supported!");
404 subg.operands().at(operand_index).info().setAsVariable();
407 return operand_index;
410 template <typename LoaderDomain>
411 void BaseLoader<LoaderDomain>::loadSparsity(const Tensor *tensor, const ir::Shape &shape,
412 ir::TypeInfo &typeInfo)
414 auto src_sparsity = tensor->sparsity();
415 if (src_sparsity != nullptr)
417 std::vector<uint16_t> w1_segments;
418 std::vector<uint16_t> w1_indices;
419 // check traversal_order
420 if (src_sparsity->traversal_order())
422 const int traversal_order_size = src_sparsity->traversal_order()->size();
423 for (int i = 0; i < traversal_order_size; ++i)
425 if (i != src_sparsity->traversal_order()->Get(i))
426 throw std::runtime_error("traversal_order [0, 1, ..., n-1] is only supported.");
431 if (src_sparsity->block_map())
433 block_rank = src_sparsity->block_map()->size();
434 for (int i = 0; i < block_rank; ++i)
436 if (i != src_sparsity->block_map()->Get(i))
437 throw std::runtime_error("block_map [0, 1, ..., n-1] is only supported.");
441 const int dim_metadata_size = src_sparsity->dim_metadata()->size();
442 auto dense_rank = shape.rank();
443 if (dense_rank + block_rank != dim_metadata_size)
444 throw std::runtime_error("sparsity dim_metadata length is wrong.");
445 bool random_sparsity = dim_metadata_size == 2 && block_rank == 0;
446 bool block2D_sparsity = dim_metadata_size == 4 && block_rank == 2;
447 if (dim_metadata_size != !random_sparsity && !block2D_sparsity)
448 throw std::runtime_error(
449 "sparsity is supported only for 2D tensor with random or 16x1 block sparsity.");
451 const auto *src_metadata = src_sparsity->dim_metadata()->Get(0);
452 if (src_metadata->format() != DimensionType::DimensionType_DENSE)
453 throw std::runtime_error("sparse tensor dim[0] is not DENSE");
454 src_metadata = src_sparsity->dim_metadata()->Get(1);
455 if (src_metadata->format() != DimensionType::DimensionType_SPARSE_CSR)
456 throw std::runtime_error("sparse tensor dim[0] is not SPARSE_CSR");
457 auto ParseSparseIndexVector = [src_metadata, &w1_segments, &w1_indices]() {
458 if (src_metadata->array_segments() == nullptr || src_metadata->array_indices() == nullptr)
461 switch (src_metadata->array_segments_type())
463 case SparseIndexVector::SparseIndexVector_Int32Vector:
464 status = Copy(src_metadata->array_segments_as_Int32Vector(), w1_segments);
466 case SparseIndexVector::SparseIndexVector_Uint16Vector:
467 status = Copy(src_metadata->array_segments_as_Uint16Vector(), w1_segments);
469 case SparseIndexVector::SparseIndexVector_Uint8Vector:
470 status = Copy(src_metadata->array_segments_as_Uint8Vector(), w1_segments);
477 switch (src_metadata->array_indices_type())
479 case SparseIndexVector::SparseIndexVector_Int32Vector:
480 return Copy(src_metadata->array_indices_as_Int32Vector(), w1_indices);
481 case SparseIndexVector::SparseIndexVector_Uint16Vector:
482 return Copy(src_metadata->array_indices_as_Uint16Vector(), w1_indices);
483 case SparseIndexVector::SparseIndexVector_Uint8Vector:
484 return Copy(src_metadata->array_indices_as_Uint8Vector(), w1_indices);
490 if (ParseSparseIndexVector() == false)
491 throw std::runtime_error("Error during parsing sparsity index information");
493 std::vector<int32_t> block_size;
494 for (int i = 0; i < block_rank; ++i)
496 auto block_metadata = src_sparsity->dim_metadata()->Get(dense_rank + i);
497 if (block_metadata->format() != DimensionType::DimensionType_DENSE)
498 throw std::runtime_error("block dimension must be DENSE.");
499 block_size.push_back(block_metadata->dense_size());
501 typeInfo.sparsity(std::make_shared<ir::Sparsity>(std::move(w1_segments), std::move(w1_indices),
502 std::move(block_size)));
506 template <typename LoaderDomain>
507 void BaseLoader<LoaderDomain>::loadOperationIO(const Operator *op, ir::OperandIndexSequence &inputs,
508 ir::OperandIndexSequence &outputs)
510 for (const std::int32_t idx : *op->inputs())
512 // Optional tensors are not supported yet except for FULLY_CONNECTED and BCQ_FULLY_CONNECTED
513 auto check_optional_input = [&]() {
514 auto builtin_code = _model->operator_codes()->Get(op->opcode_index())->builtin_code();
515 if (isOptionalInputTensor(idx) && !allowOptionalInputTensor(builtin_code))
516 throw std::runtime_error(
517 std::string("loader doesn't support optional input tensor yet for ")
518 .append(EnumNameBuiltinOperator(builtin_code)));
520 check_optional_input();
521 inputs.append(tensorIdxToOperandIdx(idx));
524 for (const std::int32_t idx : *op->outputs())
526 outputs.append(tensorIdxToOperandIdx(idx));
530 template <typename LoaderDomain>
531 template <typename Param, typename OptionsType>
532 void BaseLoader<LoaderDomain>::loadStridesAndPaddings(Param ¶m, const OptionsType *options)
535 param.stride.vertical = options->stride_h();
536 param.stride.horizontal = options->stride_w();
538 switch (options->padding())
540 case Padding::Padding_SAME:
541 param.padding.type = ir::PaddingType::SAME;
543 case Padding::Padding_VALID:
544 param.padding.type = ir::PaddingType::VALID;
547 throw std::runtime_error{"Invalid padding type"};
549 // param paddings indexes unused
552 template <typename LoaderDomain>
553 template <typename Param>
554 void BaseLoader<LoaderDomain>::loadPool2DOptions(Param ¶m, const Pool2DOptions *options)
556 // Strides and Paddings
557 if (options->stride_h() <= 0 || options->stride_w() <= 0)
558 throw std::runtime_error{"Invalid stride vertical or horizontal - both must be bigger than 0"};
559 loadStridesAndPaddings(param, options);
560 // Filter width and height
562 if (options->filter_width() <= 0 || options->filter_height() <= 0)
563 throw std::runtime_error{"Invalid filter width or height - both must be bigger than 0"};
564 param.kw = options->filter_width();
565 param.kh = options->filter_height();
567 param.activation = convertActivation(options->fused_activation_function());
570 template <typename LoaderDomain>
571 template <typename OpIR, typename... Args>
572 const OpIR *BaseLoader<LoaderDomain>::loadOperationTo(const Operator *op, ir::Graph &subg,
575 static_assert(sizeof...(args) <= 1, "You can't have more than 1 arguments!");
576 ir::OperandIndexSequence inputs;
577 ir::OperandIndexSequence outputs;
579 loadOperationIO(op, inputs, outputs);
581 std::unique_ptr<OpIR> new_op(new OpIR(inputs, outputs, std::forward<Args>(args)...));
582 auto ret = new_op.get();
583 subg.addOperation(std::move(new_op));
588 template <typename LoaderDomain>
589 void BaseLoader<LoaderDomain>::loadConv2D(const Operator *op, ir::Graph &subg)
591 ir::operation::Conv2D::Param param;
592 const auto *options = op->builtin_options_as_Conv2DOptions();
593 param.activation = convertActivation(options->fused_activation_function());
594 loadStridesAndPaddings(param, options);
595 param.dilation.width_factor = options->dilation_w_factor();
596 param.dilation.height_factor = options->dilation_h_factor();
598 loadOperationTo<ir::operation::Conv2D>(op, subg, param);
601 template <typename LoaderDomain>
602 void BaseLoader<LoaderDomain>::loadDepthwiseConv2D(const Operator *op, ir::Graph &subg)
604 ir::operation::DepthwiseConv2D::Param param;
605 const auto *options = op->builtin_options_as_DepthwiseConv2DOptions();
606 param.activation = convertActivation(options->fused_activation_function());
607 loadStridesAndPaddings(param, options);
608 param.multiplier = options->depth_multiplier();
609 // Dilation h/w factor unused
610 param.dilation.width_factor = options->dilation_w_factor();
611 param.dilation.height_factor = options->dilation_h_factor();
613 loadOperationTo<ir::operation::DepthwiseConv2D>(op, subg, param);
616 template <typename LoaderDomain>
617 void BaseLoader<LoaderDomain>::loadTransposeConv(const Operator *op, ir::Graph &subg)
619 ir::operation::TransposeConv::Param param;
620 const auto *options = op->builtin_options_as_TransposeConvOptions();
621 loadStridesAndPaddings(param, options);
623 loadOperationTo<ir::operation::TransposeConv>(op, subg, param);
626 template <typename LoaderDomain>
627 void BaseLoader<LoaderDomain>::loadPool2D(const Operator *op, ir::Graph &subg,
628 ir::operation::Pool2D::PoolType op_type)
630 ir::operation::Pool2D::Param param;
631 param.op_type = op_type;
632 const auto *options = op->builtin_options_as_Pool2DOptions();
634 loadPool2DOptions(param, options);
636 loadOperationTo<ir::operation::Pool2D>(op, subg, param);
639 template <typename LoaderDomain>
640 void BaseLoader<LoaderDomain>::loadReshape(const Operator *op, ir::Graph &subg)
642 ir::operation::Reshape::Param param{};
643 const auto *options = op->builtin_options_as_ReshapeOptions();
644 if (options != nullptr)
646 const auto *new_shape = options->new_shape();
649 for (uint i = 0; i < new_shape->size(); ++i)
651 param.new_shape.push_back(new_shape->Get(i));
656 loadOperationTo<ir::operation::Reshape>(op, subg, param);
659 template <typename LoaderDomain>
660 void BaseLoader<LoaderDomain>::loadSoftmax(const Operator *op, ir::Graph &subg)
662 ir::operation::Softmax::Param param;
663 const auto *options = op->builtin_options_as_SoftmaxOptions();
665 param.beta = options->beta();
667 loadOperationTo<ir::operation::Softmax>(op, subg, param);
670 template <typename LoaderDomain>
671 void BaseLoader<LoaderDomain>::loadConcatenation(const Operator *op, ir::Graph &subg)
673 ir::operation::Concat::Param param;
674 const auto *options = op->builtin_options_as_ConcatenationOptions();
676 param.axis = options->axis();
679 loadOperationTo<ir::operation::Concat>(op, subg, param);
682 template <typename LoaderDomain>
683 void BaseLoader<LoaderDomain>::loadFC(const Operator *op, ir::Graph &subg)
685 ir::operation::FullyConnected::Param param;
686 const auto *options = op->builtin_options_as_FullyConnectedOptions();
688 param.activation = convertActivation(options->fused_activation_function());
689 param.weights_format = static_cast<ir::FullyConnectedWeightsFormat>(options->weights_format());
691 const auto fc = loadOperationTo<ir::operation::FullyConnected>(op, subg, param);
693 const auto &input_operand =
694 subg.operands().at(fc->getInputs().at(ir::operation::FullyConnected::INPUT));
695 auto &weights_operand =
696 subg.operands().at(fc->getInputs().at(ir::operation::FullyConnected::WEIGHT));
697 if (input_operand.typeInfo().type() == ir::DataType::FLOAT32 &&
698 ((weights_operand.typeInfo().type() == ir::DataType::QUANT_UINT8_ASYMM) ||
699 weights_operand.typeInfo().type() == ir::DataType::QUANT_INT8_ASYMM))
701 weights_operand.type(ir::DataType::QUANT_INT8_SYMM);
705 template <typename LoaderDomain>
706 void BaseLoader<LoaderDomain>::loadAddV2(const Operator *op, ir::Graph &subg)
708 ir::operation::BinaryArithmetic::Param param;
709 param.arithmetic_type = ir::operation::BinaryArithmetic::ArithmeticType::ADD;
711 if (op->custom_options() == nullptr)
713 param.activation = ir::Activation::NONE;
717 size_t custom_op_data_size = op->custom_options()->size();
718 auto custom_op_data = op->custom_options()->Data();
719 auto data_root = flexbuffers::GetRoot(custom_op_data, custom_op_data_size);
720 auto attr_map = data_root.AsMap();
721 const auto fused_activation_func = static_cast<typename LoaderDomain::ActivationFunctionType>(
722 attr_map["fused_activation_function"].AsInt8());
723 param.activation = convertActivation(fused_activation_func);
726 loadOperationTo<ir::operation::BinaryArithmetic>(op, subg, param);
729 template <typename LoaderDomain>
730 void BaseLoader<LoaderDomain>::loadBinaryArithmetic(
731 const Operator *op, ir::Graph &subg, ir::operation::BinaryArithmetic::ArithmeticType op_type)
733 ir::operation::BinaryArithmetic::Param param;
734 param.arithmetic_type = op_type;
737 case ir::operation::BinaryArithmetic::ArithmeticType::ADD:
739 const auto *add_options = op->builtin_options_as_AddOptions();
740 param.activation = convertActivation(add_options->fused_activation_function());
743 case ir::operation::BinaryArithmetic::ArithmeticType::SUB:
745 const auto *sub_options = op->builtin_options_as_SubOptions();
746 param.activation = convertActivation(sub_options->fused_activation_function());
749 case ir::operation::BinaryArithmetic::ArithmeticType::MUL:
751 const auto *mul_options = op->builtin_options_as_MulOptions();
752 param.activation = convertActivation(mul_options->fused_activation_function());
755 case ir::operation::BinaryArithmetic::ArithmeticType::DIV:
757 const auto *div_options = op->builtin_options_as_DivOptions();
758 param.activation = convertActivation(div_options->fused_activation_function());
763 "The function 'loadBinaryArithmetic' supports only BinaryArithmetic operations");
767 loadOperationTo<ir::operation::BinaryArithmetic>(op, subg, param);
770 template <typename LoaderDomain>
771 void BaseLoader<LoaderDomain>::loadPack(const Operator *op, ir::Graph &subg)
773 ir::operation::Pack::Param param;
774 const auto *options = op->builtin_options_as_PackOptions();
775 param.num = options->values_count();
776 param.axis = options->axis();
778 loadOperationTo<ir::operation::Pack>(op, subg, param);
781 template <typename LoaderDomain>
782 void BaseLoader<LoaderDomain>::loadElementwiseActivation(
783 const Operator *op, ir::Graph &subg, ir::operation::ElementwiseActivation::Type op_type,
784 float alpha, float beta)
786 ir::operation::ElementwiseActivation::Param param;
787 param.op_type = op_type;
791 loadOperationTo<ir::operation::ElementwiseActivation>(op, subg, param);
794 template <typename LoaderDomain>
795 void BaseLoader<LoaderDomain>::loadResizeBilinear(const Operator *op, ir::Graph &subg)
797 ir::operation::ResizeBilinear::Param param;
798 param.align_corners = op->builtin_options_as_ResizeBilinearOptions()->align_corners();
799 param.half_pixel_centers = op->builtin_options_as_ResizeBilinearOptions()->half_pixel_centers();
801 loadOperationTo<ir::operation::ResizeBilinear>(op, subg, param);
804 template <typename LoaderDomain>
805 void BaseLoader<LoaderDomain>::loadResizeNearestNeighbor(const Operator *op, ir::Graph &subg)
807 ir::operation::ResizeNearestNeighbor::Param param;
808 param.align_corners = op->builtin_options_as_ResizeNearestNeighborOptions()->align_corners();
810 loadOperationTo<ir::operation::ResizeNearestNeighbor>(op, subg, param);
813 template <typename LoaderDomain>
814 void BaseLoader<LoaderDomain>::loadReduce(const Operator *op, ir::Graph &subg,
815 ir::operation::Reduce::ReduceType reduce_type)
817 ir::operation::Reduce::Param param;
818 param.reduce_type = reduce_type;
819 param.keep_dims = op->builtin_options_as_ReducerOptions()->keep_dims();
821 loadOperationTo<ir::operation::Reduce>(op, subg, param);
824 template <typename LoaderDomain>
825 void BaseLoader<LoaderDomain>::loadReduceAll(const Operator *op, ir::Graph &subg)
827 ir::operation::Reduce::Param param;
828 param.reduce_type = ir::operation::Reduce::ReduceType::ALL;
829 if (op->custom_options() == nullptr)
831 param.keep_dims = false;
835 size_t custom_op_data_size = op->custom_options()->size();
836 auto custom_op_data = op->custom_options()->Data();
837 auto data_root = flexbuffers::GetRoot(custom_op_data, custom_op_data_size);
838 auto attr_map = data_root.AsMap();
839 param.keep_dims = attr_map["keep_dims"].AsBool();
842 loadOperationTo<ir::operation::Reduce>(op, subg, param);
845 template <typename LoaderDomain>
846 void BaseLoader<LoaderDomain>::loadElementwiseBinary(
847 const Operator *op, ir::Graph &subg,
848 ir::operation::ElementwiseBinary::ElementwiseBinaryType op_type)
850 ir::operation::ElementwiseBinary::Param param;
851 param.op_type = op_type;
853 loadOperationTo<ir::operation::ElementwiseBinary>(op, subg, param);
856 template <typename LoaderDomain>
857 void BaseLoader<LoaderDomain>::loadElementwiseUnary(const Operator *op, ir::Graph &subg,
858 ir::operation::ElementwiseUnary::Type op_type)
860 ir::operation::ElementwiseUnary::Param param;
861 param.op_type = op_type;
863 const auto eu = loadOperationTo<ir::operation::ElementwiseUnary>(op, subg, param);
864 if (op_type == ir::operation::ElementwiseUnary::Type::CAST)
866 auto qasymm8ToUint8 = [](ir::Operand &operand) {
867 if (operand.typeInfo().type() == ir::DataType::QUANT_UINT8_ASYMM)
869 operand.type(ir::DataType::UINT8);
873 subg.operands().at(eu->getInputs().at(ir::operation::ElementwiseUnary::Input::INPUT)));
874 qasymm8ToUint8(subg.operands().at(eu->getOutputs().at(0)));
878 template <typename LoaderDomain>
879 void BaseLoader<LoaderDomain>::loadGather(const Operator *op, ir::Graph &subg)
881 ir::operation::Gather::Param param;
882 param.axis = op->builtin_options_as_GatherOptions()->axis();
884 loadOperationTo<ir::operation::Gather>(op, subg, param);
887 template <typename LoaderDomain>
// Loads BatchMatMul from either the builtin BATCH_MATMUL operator (adjoint_lhs / adjoint_rhs in
// BatchMatMulOptions) or a CUSTOM operator whose flexbuffer options map holds "adj_x" / "adj_y".
// The throw below reports any other opcode as a wrongly loaded operation.
888 void BaseLoader<LoaderDomain>::loadBatchMatMul(const Operator *op, ir::Graph &subg)
890 ir::operation::BatchMatMul::Param param;
892 const auto builtin_op = _model->operator_codes()->Get(op->opcode_index())->builtin_code();
896 case BuiltinOperator::BuiltinOperator_BATCH_MATMUL:
897 param.adj_x = op->builtin_options_as_BatchMatMulOptions()->adjoint_lhs();
898 param.adj_y = op->builtin_options_as_BatchMatMulOptions()->adjoint_rhs();
900 case BuiltinOperator::BuiltinOperator_CUSTOM:
// NOTE(review): confirm what adj_x / adj_y are set to when a CUSTOM BatchMatMul has no
// custom_options.
901 if (op->custom_options() == nullptr)
908 size_t custom_op_data_size = op->custom_options()->size();
909 auto custom_op_data = op->custom_options()->Data();
910 auto data_root = flexbuffers::GetRoot(custom_op_data, custom_op_data_size);
911 auto attr_map = data_root.AsMap();
912 param.adj_x = attr_map["adj_x"].AsBool();
913 param.adj_y = attr_map["adj_y"].AsBool();
917 throw std::runtime_error(
918 std::string("Wrong loaded operation: ").append(EnumNameBuiltinOperator(builtin_op)) +
919 " as " + EnumNameBuiltinOperator(BuiltinOperator::BuiltinOperator_BATCH_MATMUL));
922 loadOperationTo<ir::operation::BatchMatMul>(op, subg, param);
925 template <typename LoaderDomain>
926 void BaseLoader<LoaderDomain>::loadSpaceToDepth(const Operator *op, ir::Graph &subg)
928 ir::operation::SpaceToDepth::Param param;
929 const auto *options = op->builtin_options_as_SpaceToDepthOptions();
930 param.block_size = options->block_size();
932 loadOperationTo<ir::operation::SpaceToDepth>(op, subg, param);
935 template <typename LoaderDomain>
// Loads a CUSTOM operator. Names listed in builtin_map are dispatched to dedicated loaders that
// lower them onto onert operations. The generic path at the end builds an opaque
// ir::operation::Custom that carries a heap copy of the raw custom_options bytes.
// NOTE(review): builtin_map.at() throws std::out_of_range for unknown names, and the generic
// path presumably runs from the matching catch clause — confirm that clause catches only
// std::out_of_range, since a catch-all would also swallow errors raised by the dedicated loaders.
936 void BaseLoader<LoaderDomain>::loadCustom(const Operator *op, ir::Graph &subg)
938 ir::OperandIndexSequence inputs;
939 ir::OperandIndexSequence outputs;
// NOTE(review): this is only an assert, so release (NDEBUG) builds do not reject custom options
// in a non-flexbuffer format.
941 assert(op->custom_options_format() == CustomOptionsFormat::CustomOptionsFormat_FLEXBUFFERS &&
942 "Unsupported custom operation options format");
944 auto *op_code = _model->operator_codes()->Get(op->opcode_index());
945 auto custom_op_name = op_code->custom_code()->str();
956 StatelessRandomUniform,
960 // Mapping from custom op name string to BuiltinOP enum
961 std::map<std::string, BuiltinOP> builtin_map = {
962 {"AddV2", BuiltinOP::AddV2},
963 {"All", BuiltinOP::ReduceAll},
964 {"MatrixBandPart", BuiltinOP::MatrixBandPart},
965 {"BatchMatMulV2", BuiltinOP::BatchMatMul},
966 {"Einsum", BuiltinOP::Einsum},
967 {"FusedBatchNormV3", BuiltinOP::FusedBatchNorm},
968 {"BroadcastTo", BuiltinOP::BroadcastTo},
969 {"StatelessRandomUniform", BuiltinOP::StatelessRandomUniform},
970 {"Erf", BuiltinOP::Erf},
975 // Throw out_of_range if it is unknown custom op
976 auto custom_op_id = builtin_map.at(custom_op_name);
// NOTE(review): as written, no break/return separates the cases below, and the AddV2 case has no
// statement of its own (loadAddV2 is never called). Confirm each case ends with break and that
// AddV2 dispatches to loadAddV2(op, subg).
977 switch (custom_op_id)
979 case BuiltinOP::AddV2:
982 case BuiltinOP::ReduceAll:
983 loadReduceAll(op, subg);
985 case BuiltinOP::MatrixBandPart:
986 loadOperationTo<ir::operation::MatrixBandPart>(op, subg);
988 case BuiltinOP::BatchMatMul:
989 loadBatchMatMul(op, subg);
991 case BuiltinOP::Einsum:
992 loadEinsum(op, subg);
994 case BuiltinOP::BroadcastTo:
995 loadOperationTo<ir::operation::BroadcastTo>(op, subg);
997 case BuiltinOP::FusedBatchNorm:
998 loadFusedBatchNorm(op, subg);
1000 case BuiltinOP::StatelessRandomUniform:
1001 loadOperationTo<ir::operation::StatelessRandomUniform>(op, subg);
1003 case BuiltinOP::Erf:
1004 loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::ERF);
1007 throw std::runtime_error{
1008 "Loader: Custom OP map is defined but operation loader function is not defined"};
// Generic custom-op path: keep the operator opaque and hand its option bytes to the backend.
1015 loadOperationIO(op, inputs, outputs);
// The operation is constrained to exactly as many inputs as the model provides.
1017 auto constraint = ir::OperandConstraint::createExact(inputs.size());
// NOTE(review): op->custom_options() is dereferenced without a null check, and the raw new[]
// buffer is handed to Custom::Userdata — confirm which component delete[]s it.
1019 size_t custom_op_data_size = op->custom_options()->size();
1020 auto custom_op_data = new char[custom_op_data_size];
1021 std::copy(op->custom_options()->begin(), op->custom_options()->end(), custom_op_data);
1023 ir::operation::Custom::Userdata userdata{};
1024 userdata.data = custom_op_data;
1025 userdata.size = custom_op_data_size;
1027 auto new_op = std::make_unique<ir::operation::Custom>(constraint, inputs, outputs,
1028 custom_op_name, userdata);
1030 subg.addOperation(std::move(new_op));
// Load a SQUEEZE operator: copy SqueezeOptions::squeeze_dims into the fixed-capacity
// param.dims array and record how many entries were copied in param.ndim.
1034 template <typename LoaderDomain>
1035 void BaseLoader<LoaderDomain>::loadSqueeze(const Operator *op, ir::Graph &subg)
1037 ir::operation::Squeeze::Param param;
1038 const auto *options = op->builtin_options_as_SqueezeOptions();
// NOTE(review): flatbuffers accessors return nullptr for an absent vector field; dims is
// dereferenced below, so make sure a null guard on dims precedes this -- confirm.
1039 const auto *dims = options->squeeze_dims();
// Reject more squeeze axes than param.dims can hold (capacity = element count of the array).
1042 if (dims->size() > sizeof(param.dims) / sizeof(param.dims[0]))
1043 throw std::runtime_error("Squeeze: 'param.ndims' is out of range.");
1044 param.ndim = dims->size();
1045 for (int i = 0; i < param.ndim; ++i)
1046 param.dims[i] = dims->Get(i);
1049 loadOperationTo<ir::operation::Squeeze>(op, subg, param);
1052 template <typename LoaderDomain>
1053 void BaseLoader<LoaderDomain>::loadSplit(const Operator *op, ir::Graph &subg)
1055 ir::operation::Split::Param param;
1056 const auto *options = op->builtin_options_as_SplitOptions();
1057 param.num_splits = options->num_splits();
1059 loadOperationTo<ir::operation::Split>(op, subg, param);
1062 template <typename LoaderDomain>
1063 void BaseLoader<LoaderDomain>::loadSplitV(const Operator *op, ir::Graph &subg)
1065 ir::operation::SplitV::Param param;
1066 const auto *options = op->builtin_options_as_SplitVOptions();
1067 param.num_splits = options->num_splits();
1069 loadOperationTo<ir::operation::SplitV>(op, subg, param);
1072 template <typename LoaderDomain>
1073 void BaseLoader<LoaderDomain>::loadStridedSlice(const Operator *op, ir::Graph &subg)
1075 ir::operation::StridedSlice::Param param;
1076 const auto *options = op->builtin_options_as_StridedSliceOptions();
1077 param.begin_mask = options->begin_mask();
1078 param.end_mask = options->end_mask();
1079 param.shrink_axis_mask = options->shrink_axis_mask();
1081 loadOperationTo<ir::operation::StridedSlice>(op, subg, param);
1084 template <typename LoaderDomain>
1085 void BaseLoader<LoaderDomain>::loadUnpack(const Operator *op, ir::Graph &subg)
1087 ir::operation::Unpack::Param param;
1088 const auto *options = op->builtin_options_as_UnpackOptions();
1089 param.num = options->num();
1090 param.axis = options->axis();
1092 loadOperationTo<ir::operation::Unpack>(op, subg, param);
1095 template <typename LoaderDomain>
1096 void BaseLoader<LoaderDomain>::loadComparison(const Operator *op, ir::Graph &subg)
1098 ir::operation::Comparison::Param param;
1099 const auto builtin_op = _model->operator_codes()->Get(op->opcode_index())->builtin_code();
1103 case BuiltinOperator::BuiltinOperator_EQUAL:
1104 param.comparison_type = ir::operation::Comparison::ComparisonType::Equal;
1106 case BuiltinOperator::BuiltinOperator_NOT_EQUAL:
1107 param.comparison_type = ir::operation::Comparison::ComparisonType::NotEqual;
1109 case BuiltinOperator::BuiltinOperator_GREATER_EQUAL:
1110 param.comparison_type = ir::operation::Comparison::ComparisonType::GreaterEqual;
1112 case BuiltinOperator::BuiltinOperator_GREATER:
1113 param.comparison_type = ir::operation::Comparison::ComparisonType::Greater;
1115 case BuiltinOperator::BuiltinOperator_LESS_EQUAL:
1116 param.comparison_type = ir::operation::Comparison::ComparisonType::LessEqual;
1118 case BuiltinOperator::BuiltinOperator_LESS:
1119 param.comparison_type = ir::operation::Comparison::ComparisonType::Less;
1122 throw std::runtime_error(
1123 std::string("Unsupported operation: ").append(EnumNameBuiltinOperator(builtin_op)));
1126 loadOperationTo<ir::operation::Comparison>(op, subg, param);
1129 template <typename LoaderDomain>
1130 void BaseLoader<LoaderDomain>::loadEinsum(const Operator *op, ir::Graph &subg)
1132 ir::operation::Einsum::Param param;
1133 if (op->custom_options() == nullptr)
1135 throw std::runtime_error{"Einsum: empty equation"};
1139 size_t custom_op_data_size = op->custom_options()->size();
1140 auto custom_op_data = op->custom_options()->Data();
1141 auto data_root = flexbuffers::GetRoot(custom_op_data, custom_op_data_size);
1142 auto attr_map = data_root.AsMap();
1143 param.equation = attr_map["equation"].ToString();
1146 const auto es = loadOperationTo<ir::operation::Einsum>(op, subg, param);
1147 if (es->getInputs().size() != 2)
1149 throw std::runtime_error{"Einsum: NYI input - only support two inputs"};
1152 template <typename LoaderDomain>
1153 void BaseLoader<LoaderDomain>::loadFusedBatchNorm(const Operator *op, ir::Graph &subg)
1155 ir::operation::FusedBatchNorm::Param param;
1156 if (op->custom_options() == nullptr)
1158 throw std::runtime_error{"FusedBatchNorm: empty option"};
1162 size_t custom_op_data_size = op->custom_options()->size();
1163 auto custom_op_data = op->custom_options()->Data();
1164 auto data_root = flexbuffers::GetRoot(custom_op_data, custom_op_data_size);
1165 auto attr_map = data_root.AsMap();
1166 param.is_training = attr_map["is_training"].AsBool();
1167 param.epsilon = attr_map["epsilon"].AsFloat();
1168 param.data_format = attr_map["data_format"].ToString();
1171 const auto fbn = loadOperationTo<ir::operation::FusedBatchNorm>(op, subg, param);
1173 if (fbn->getInputs().size() != 5)
1175 throw std::runtime_error{"FusedBatchNorm: NYI input - only support five inputs"};
1179 template <typename LoaderDomain>
1180 void BaseLoader<LoaderDomain>::loadOneHot(const Operator *op, ir::Graph &subg)
1182 if (op->inputs()->size() != 4 || op->outputs()->size() != 1)
1183 throw std::runtime_error("OneHot Op has wrong number of input or output tensors.");
1186 ir::operation::OneHot::Param param;
1187 param.axis = op->builtin_options_as_OneHotOptions()->axis();
1189 loadOperationTo<ir::operation::OneHot>(op, subg, param);
1192 template <typename LoaderDomain>
1193 void BaseLoader<LoaderDomain>::loadIf(const Operator *op, ir::Graph &subg)
1195 const auto *options = op->builtin_options_as_IfOptions();
1196 const int32_t then_index = options->then_subgraph_index();
1197 const int32_t else_index = options->else_subgraph_index();
1199 verifySubgraphIndex(then_index);
1200 verifySubgraphIndex(else_index);
1202 ir::operation::If::Param param;
1203 param.then_subg_index = ir::SubgraphIndex{static_cast<uint32_t>(then_index)};
1204 param.else_subg_index = ir::SubgraphIndex{static_cast<uint32_t>(else_index)};
1206 loadOperationTo<ir::operation::If>(op, subg, param);
1209 template <typename LoaderDomain>
1210 void BaseLoader<LoaderDomain>::loadWhile(const Operator *op, ir::Graph &subg)
1212 const auto *options = op->builtin_options_as_WhileOptions();
1213 const int32_t cond_index = options->cond_subgraph_index();
1214 const int32_t body_index = options->body_subgraph_index();
1216 verifySubgraphIndex(cond_index);
1217 verifySubgraphIndex(body_index);
1219 ir::operation::While::Param param;
1220 param.cond_subg_index = ir::SubgraphIndex{static_cast<uint32_t>(cond_index)};
1221 param.body_subg_index = ir::SubgraphIndex{static_cast<uint32_t>(body_index)};
1223 loadOperationTo<ir::operation::While>(op, subg, param);
1226 template <typename LoaderDomain>
1227 void BaseLoader<LoaderDomain>::loadArgMax(const Operator *op, ir::Graph &subg)
1229 ir::operation::ArgMax::Param param;
1230 const auto output_type = op->builtin_options_as_ArgMaxOptions()->output_type();
1231 switch (output_type)
1233 case TensorType::TensorType_INT32:
1234 case TensorType::TensorType_INT64:
1235 param.output_type = tensorTypeToDataType(output_type);
1238 throw std::runtime_error("ArgMax: `output_type` must be either int32 or int64.");
1240 auto am = loadOperationTo<ir::operation::ArgMax>(op, subg, param);
1242 auto &axisOperand = subg.operands().at(am->getInputs().at(ir::operation::ArgMax::Input::AXIS));
1243 if (!(axisOperand.operandSize() == 4 && (axisOperand.typeInfo().type() == ir::DataType::INT32 ||
1244 axisOperand.typeInfo().type() == ir::DataType::INT64)))
1245 throw std::runtime_error("ArgMax: `axis` with an int32 or int64 element is only supported.");
// Load a LOG_SOFTMAX operator. No builtin options are read in the lines shown here; the
// IR parameter is passed straight to loadOperationTo.
// NOTE(review): per the comment below, param.beta and param.axis should be set explicitly to
// 1.0f and -1 before loadOperationTo -- confirm that these assignments are present.
1248 template <typename LoaderDomain>
1249 void BaseLoader<LoaderDomain>::loadLogSoftmax(const Operator *op, ir::Graph &subg)
1251 ir::operation::LogSoftmax::Param param;
1252 // In tflite, beta is fixed to 1.0 and axis is fixed to -1.
1256 loadOperationTo<ir::operation::LogSoftmax>(op, subg, param);
// Load a LEAKY_RELU operator as an ElementwiseActivation of type LEAKY_RELU.
1259 template <typename LoaderDomain>
1260 void BaseLoader<LoaderDomain>::loadLeakyRelu(const Operator *op, ir::Graph &subg)
// The negative-side slope comes from LeakyReluOptions and is forwarded as the first float
// argument of loadElementwiseActivation.
1262 float alpha = op->builtin_options_as_LeakyReluOptions()->alpha();
1263 loadElementwiseActivation(op, subg, ir::operation::ElementwiseActivation::Type::LEAKY_RELU, alpha,
1267 template <typename LoaderDomain>
1268 void BaseLoader<LoaderDomain>::loadUnidirectionalSequenceLSTM(const Operator *op, ir::Graph &subg)
1270 ir::operation::LSTM::Param param;
1271 const auto *options = op->builtin_options_as_UnidirectionalSequenceLSTMOptions();
1272 param.activation = convertActivation(options->fused_activation_function());
1273 param.cell_threshold = options->cell_clip();
1274 param.projection_threshold = options->proj_clip();
1275 param.time_major = options->time_major();
1276 // The asymmetric_quantize_inputs option is unused yet
1278 ir::OperandIndexSequence inputs;
1279 for (const std::int32_t idx : *op->inputs())
1281 inputs.append(tensorIdxToOperandIdx(idx));
1284 ir::OperandIndexSequence outputs;
1285 // loader doesn't support optional output tensor yet
1286 if (op->outputs()->size() != 1)
1288 auto builtin_code = _model->operator_codes()->Get(op->opcode_index())->builtin_code();
1289 throw std::runtime_error(std::string("loader doesn't support optional output tensor yet for ")
1290 .append(EnumNameBuiltinOperator(builtin_code)));
1292 for (size_t i = 0; i < ir::operation::LSTM::Output::OUTPUT; ++i)
1294 // Add optional outputs
1295 outputs.append(ir::OperandIndex());
1297 outputs.append(tensorIdxToOperandIdx(op->outputs()->Get(0)));
1299 std::unique_ptr<ir::operation::LSTM> new_op(new ir::operation::LSTM(inputs, outputs, param));
1300 subg.addOperation(std::move(new_op));
// Load one flatbuffer operator into `subg` by dispatching on the builtin_code of its opcode.
// Several opcodes share one loader that is parameterised by an IR sub-type, e.g.
// ADD/SUB/MUL/DIV -> loadBinaryArithmetic, MEAN/SUM/REDUCE_* -> loadReduce.
// Opcodes without a case here raise std::runtime_error("Unsupported operation: <name>").
1303 template <typename LoaderDomain>
1304 void BaseLoader<LoaderDomain>::loadOperation(const Operator *op, ir::Graph &subg)
1306 const auto builtin_op = _model->operator_codes()->Get(op->opcode_index())->builtin_code();
1310 case BuiltinOperator::BuiltinOperator_ADD_N:
1311 loadOperationTo<ir::operation::AddN>(op, subg);
1313 case BuiltinOperator::BuiltinOperator_CONV_2D:
1314 loadConv2D(op, subg);
1316 case BuiltinOperator::BuiltinOperator_AVERAGE_POOL_2D:
1317 loadPool2D(op, subg, ir::operation::Pool2D::PoolType::AVG);
1319 case BuiltinOperator::BuiltinOperator_DEPTHWISE_CONV_2D:
1320 loadDepthwiseConv2D(op, subg);
1322 case BuiltinOperator::BuiltinOperator_TRANSPOSE_CONV:
1323 loadTransposeConv(op, subg);
1325 case BuiltinOperator::BuiltinOperator_RESHAPE:
1326 loadReshape(op, subg);
1328 case BuiltinOperator::BuiltinOperator_SOFTMAX:
1329 loadSoftmax(op, subg);
1331 case BuiltinOperator::BuiltinOperator_MAX_POOL_2D:
1332 loadPool2D(op, subg, ir::operation::Pool2D::PoolType::MAX);
1334 case BuiltinOperator::BuiltinOperator_CONCATENATION:
1335 loadConcatenation(op, subg);
1337 case BuiltinOperator::BuiltinOperator_FLOOR:
1338 loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::FLOOR);
1340 case BuiltinOperator::BuiltinOperator_FULLY_CONNECTED:
1343 case BuiltinOperator::BuiltinOperator_ADD:
1344 loadBinaryArithmetic(op, subg, ir::operation::BinaryArithmetic::ArithmeticType::ADD);
1346 case BuiltinOperator::BuiltinOperator_SUB:
1347 loadBinaryArithmetic(op, subg, ir::operation::BinaryArithmetic::ArithmeticType::SUB);
1349 case BuiltinOperator::BuiltinOperator_MUL:
1350 loadBinaryArithmetic(op, subg, ir::operation::BinaryArithmetic::ArithmeticType::MUL);
1352 case BuiltinOperator::BuiltinOperator_DIV:
1353 loadBinaryArithmetic(op, subg, ir::operation::BinaryArithmetic::ArithmeticType::DIV);
1355 case BuiltinOperator::BuiltinOperator_PACK:
// RELU, RELU_N1_TO_1 and RELU6 all lower to ElementwiseActivation::RELU and differ only in the
// two float bounds after the type (RELU passes infinity then 0.f). NOTE(review): these appear
// to be the upper and lower clamp values -- confirm against loadElementwiseActivation.
1358 case BuiltinOperator::BuiltinOperator_RELU:
1359 loadElementwiseActivation(op, subg, ir::operation::ElementwiseActivation::Type::RELU,
1360 ir::operation::ElementwiseActivation::infinity, 0.f);
1362 case BuiltinOperator::BuiltinOperator_RELU_N1_TO_1:
1363 loadElementwiseActivation(op, subg, ir::operation::ElementwiseActivation::Type::RELU, 1.f,
1366 case BuiltinOperator::BuiltinOperator_RELU6:
1367 loadElementwiseActivation(op, subg, ir::operation::ElementwiseActivation::Type::RELU, 6.f,
1370 case BuiltinOperator::BuiltinOperator_RESIZE_BILINEAR:
1371 loadResizeBilinear(op, subg);
1373 case BuiltinOperator::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR:
1374 loadResizeNearestNeighbor(op, subg);
1376 case BuiltinOperator::BuiltinOperator_RSQRT:
1377 loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::RSQRT);
// SELECT and SELECT_V2 share the same IR Select operation.
1379 case BuiltinOperator::BuiltinOperator_SELECT:
1380 case BuiltinOperator::BuiltinOperator_SELECT_V2:
1381 loadOperationTo<ir::operation::Select>(op, subg);
1383 case BuiltinOperator::BuiltinOperator_SQRT:
1384 loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::SQRT);
1386 case BuiltinOperator::BuiltinOperator_SQUARED_DIFFERENCE:
1387 loadOperationTo<ir::operation::SquaredDifference>(op, subg);
1389 case BuiltinOperator::BuiltinOperator_TANH:
1390 loadElementwiseActivation(op, subg, ir::operation::ElementwiseActivation::Type::TANH, 1.f,
1393 case BuiltinOperator::BuiltinOperator_TRANSPOSE:
1394 loadOperationTo<ir::operation::Transpose>(op, subg);
1396 case BuiltinOperator::BuiltinOperator_MEAN:
1397 loadReduce(op, subg, ir::operation::Reduce::ReduceType::MEAN);
1399 case BuiltinOperator::BuiltinOperator_REDUCE_ANY:
1400 loadReduce(op, subg, ir::operation::Reduce::ReduceType::ANY);
1402 case BuiltinOperator::BuiltinOperator_REDUCE_MAX:
1403 loadReduce(op, subg, ir::operation::Reduce::ReduceType::MAX);
1405 case BuiltinOperator::BuiltinOperator_REVERSE_V2:
1406 loadOperationTo<ir::operation::Reverse>(op, subg);
// PAD and PADV2 share the same IR Pad operation.
1408 case BuiltinOperator::BuiltinOperator_PAD:
1409 case BuiltinOperator::BuiltinOperator_PADV2:
1410 loadOperationTo<ir::operation::Pad>(op, subg);
1412 case BuiltinOperator::BuiltinOperator_LOGISTIC:
1413 loadElementwiseActivation(op, subg, ir::operation::ElementwiseActivation::Type::LOGISTIC);
1415 case BuiltinOperator::BuiltinOperator_EXP:
1416 loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::EXP);
1418 case BuiltinOperator::BuiltinOperator_EXPAND_DIMS:
1419 loadOperationTo<ir::operation::ExpandDims>(op, subg);
1421 case BuiltinOperator::BuiltinOperator_GATHER:
1422 loadGather(op, subg);
1424 case BuiltinOperator::BuiltinOperator_SPACE_TO_BATCH_ND:
1425 loadOperationTo<ir::operation::SpaceToBatchND>(op, subg);
1427 case BuiltinOperator::BuiltinOperator_BATCH_TO_SPACE_ND:
1428 loadOperationTo<ir::operation::BatchToSpaceND>(op, subg);
1430 case BuiltinOperator::BuiltinOperator_SUM:
1431 loadReduce(op, subg, ir::operation::Reduce::ReduceType::SUM);
// Custom operators are dispatched further by their custom_code name in loadCustom.
1433 case BuiltinOperator::BuiltinOperator_CUSTOM:
1434 loadCustom(op, subg);
1436 case BuiltinOperator::BuiltinOperator_SQUEEZE:
1437 loadSqueeze(op, subg);
1439 case BuiltinOperator::BuiltinOperator_PRELU:
1440 loadOperationTo<ir::operation::PReLU>(op, subg);
1442 case BuiltinOperator::BuiltinOperator_SPLIT:
1443 loadSplit(op, subg);
1445 case BuiltinOperator::BuiltinOperator_SPLIT_V:
1446 loadSplitV(op, subg);
1448 case BuiltinOperator::BuiltinOperator_SLICE:
1449 loadOperationTo<ir::operation::Slice>(op, subg);
1451 case BuiltinOperator::BuiltinOperator_STRIDED_SLICE:
1452 loadStridedSlice(op, subg);
1454 case BuiltinOperator::BuiltinOperator_UNPACK:
1455 loadUnpack(op, subg);
1457 case BuiltinOperator::BuiltinOperator_MINIMUM:
1458 loadElementwiseBinary(op, subg, ir::operation::ElementwiseBinary::ElementwiseBinaryType::MIN);
1460 case BuiltinOperator::BuiltinOperator_MAXIMUM:
1461 loadElementwiseBinary(op, subg, ir::operation::ElementwiseBinary::ElementwiseBinaryType::MAX);
1463 case BuiltinOperator::BuiltinOperator_CAST:
1464 loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::CAST);
// All six comparison opcodes share loadComparison, which re-reads the opcode to pick the kind.
1466 case BuiltinOperator::BuiltinOperator_EQUAL:
1467 case BuiltinOperator::BuiltinOperator_NOT_EQUAL:
1468 case BuiltinOperator::BuiltinOperator_GREATER_EQUAL:
1469 case BuiltinOperator::BuiltinOperator_GREATER:
1470 case BuiltinOperator::BuiltinOperator_LESS_EQUAL:
1471 case BuiltinOperator::BuiltinOperator_LESS:
1472 loadComparison(op, subg);
1474 case BuiltinOperator::BuiltinOperator_ONE_HOT:
1475 loadOneHot(op, subg);
1477 case BuiltinOperator::BuiltinOperator_ABS:
1478 loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::ABS);
1480 case BuiltinOperator::BuiltinOperator_COS:
1481 loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::COS);
1483 case BuiltinOperator::BuiltinOperator_SIN:
1484 loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::SIN);
1486 case BuiltinOperator::BuiltinOperator_SHAPE:
1487 loadOperationTo<ir::operation::Shape>(op, subg);
1489 case BuiltinOperator::BuiltinOperator_REDUCE_PROD:
1490 loadReduce(op, subg, ir::operation::Reduce::ReduceType::PROD);
1492 case BuiltinOperator::BuiltinOperator_IF:
1495 case BuiltinOperator::BuiltinOperator_WHILE:
1496 loadWhile(op, subg);
1498 case BuiltinOperator::BuiltinOperator_NEG:
1499 loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::NEG);
1501 case BuiltinOperator::BuiltinOperator_ARG_MAX:
1502 loadArgMax(op, subg);
1504 case BuiltinOperator::BuiltinOperator_LOG:
1505 loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::LOG);
1507 case BuiltinOperator::BuiltinOperator_ROUND:
1508 loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::ROUND);
1510 case BuiltinOperator::BuiltinOperator_POW:
1511 loadOperationTo<ir::operation::Pow>(op, subg);
1513 case BuiltinOperator::BuiltinOperator_LOGICAL_NOT:
1514 loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::LOGICAL_NOT);
1516 case BuiltinOperator::BuiltinOperator_LOGICAL_OR:
1517 loadElementwiseBinary(op, subg,
1518 ir::operation::ElementwiseBinary::ElementwiseBinaryType::LOGICAL_OR);
1520 case BuiltinOperator::BuiltinOperator_FILL:
1521 loadOperationTo<ir::operation::Fill>(op, subg);
1523 case BuiltinOperator::BuiltinOperator_ZEROS_LIKE:
1524 loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::ZEROS_LIKE);
1526 case BuiltinOperator::BuiltinOperator_TILE:
1527 loadOperationTo<ir::operation::Tile>(op, subg);
1529 case BuiltinOperator::BuiltinOperator_RANGE:
1530 loadOperationTo<ir::operation::Range>(op, subg);
1532 case BuiltinOperator::BuiltinOperator_BATCH_MATMUL:
1533 loadBatchMatMul(op, subg);
1535 case BuiltinOperator::BuiltinOperator_LOG_SOFTMAX:
1536 loadLogSoftmax(op, subg);
1538 case BuiltinOperator::BuiltinOperator_QUANTIZE:
1539 loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::QUANTIZE);
1541 case BuiltinOperator::BuiltinOperator_DEQUANTIZE:
1542 loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::DEQUANTIZE);
1544 case BuiltinOperator::BuiltinOperator_SPACE_TO_DEPTH:
1545 loadSpaceToDepth(op, subg);
1547 case BuiltinOperator::BuiltinOperator_L2_NORMALIZATION:
1548 loadOperationTo<ir::operation::L2Normalization>(op, subg);
1550 case BuiltinOperator::BuiltinOperator_LEAKY_RELU:
1551 loadLeakyRelu(op, subg);
1553 case BuiltinOperator::BuiltinOperator_RANK:
1554 loadOperationTo<ir::operation::Rank>(op, subg);
1556 case BuiltinOperator::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
1557 loadUnidirectionalSequenceLSTM(op, subg);
// Any opcode not listed above is rejected with its enum name in the message.
1560 throw std::runtime_error(
1561 std::string("Unsupported operation: ").append(EnumNameBuiltinOperator(builtin_op)));
1565 template <typename LoaderDomain> void BaseLoader<LoaderDomain>::loadModel()
1567 LoaderDomain::VerifyModelBuffer(*_verifier.get());
1568 _model = LoaderDomain::GetModel(_base);
1570 // const auto version = _model->version();
1571 // Description unused
1572 // const auto *description = _model->description();
1573 // Metabuffer unsued
1574 // const auto *metadata_buffer = _model->metadata_buffer();
1575 // Load subgraphs and map operations on subgraph
1576 const auto domain_subgraphs = _model->subgraphs();
1577 auto subgraphs = std::make_unique<ir::Subgraphs>();
1578 for (uint32_t subgraph_index = 0; subgraph_index < domain_subgraphs->size(); ++subgraph_index)
1580 auto subg = loadSubgraph((*_model->subgraphs())[subgraph_index]);
1581 subgraphs->push(ir::SubgraphIndex{subgraph_index}, std::move(subg));
1583 _subgraphs = std::move(subgraphs);
1586 } // namespace base_loader
1587 } // namespace onert
1589 #endif //__BASE_LOADER_BASE_LOADER_H__