2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "tflite_loader.h"
18 #include "base_loader.h"
19 #include "tflite_schema_generated.h"
23 namespace tflite_loader
// Short aliases for the flatbuffers verifier and the onert_tflite schema types.
// NOTE(review): these presumably live inside the LoaderDomain type handed to
// base_loader::BaseLoader below — the enclosing declaration is not visible here; confirm.
31 using Verifier = flatbuffers::Verifier;
32 using ActivationFunctionType = onert_tflite::ActivationFunctionType;
33 using Buffer = onert_tflite::Buffer;
34 using BuiltinOperator = onert_tflite::BuiltinOperator;
35 using CustomOptionsFormat = onert_tflite::CustomOptionsFormat;
36 using Model = onert_tflite::Model;
37 using Operator = onert_tflite::Operator;
38 using Padding = onert_tflite::Padding;
39 using Pool2DOptions = onert_tflite::Pool2DOptions;
40 using Tensor = onert_tflite::Tensor;
41 using TensorType = onert_tflite::TensorType;
42 using SubGraph = onert_tflite::SubGraph;
43 using DimensionType = onert_tflite::DimensionType;
44 using SparseIndexVector = onert_tflite::SparseIndexVector;
46 static const char *EnumNameBuiltinOperator(BuiltinOperator e)
48 return onert_tflite::EnumNameBuiltinOperator(e);
50 static const char *EnumNameActivationFunctionType(ActivationFunctionType e)
52 return onert_tflite::EnumNameActivationFunctionType(e);
54 static const char *EnumNameTensorType(TensorType e)
56 return onert_tflite::EnumNameTensorType(e);
58 static const Model *GetModel(const void *buf) { return onert_tflite::GetModel(buf); }
59 static bool VerifyModelBuffer(Verifier &verifier)
61 return onert_tflite::VerifyModelBuffer(verifier);
// TFLite model loader: a final class deriving from base_loader::BaseLoader,
// parameterised on LoaderDomain and on TFLiteLoader itself.
65 class TFLiteLoader final : public base_loader::BaseLoader<LoaderDomain, TFLiteLoader>
// Inherit the constructors of BaseLoader.
68 using BaseLoader::BaseLoader;
// Overrides the base-class query for whether operator `op` may take optional input tensors.
// NOTE(review): only the FULLY_CONNECTED case label is visible in this listing; the
// surrounding switch and its return values are not shown — confirm against the full file.
70 bool allowOptionalInputTensor(BuiltinOperator op) override
74 case BuiltinOperator::BuiltinOperator_FULLY_CONNECTED:
// Builds an ir::Graph from a single TFLite subgraph: loads operands for every tensor,
// registers the subgraph inputs and outputs, loads every operator, then finalises the graph.
// tflite_subg: the flatbuffer subgraph to convert.
// NOTE(review): the declared return type is std::unique_ptr<ir::Graph>; the return
// statement itself is not visible in this listing.
81 std::unique_ptr<ir::Graph> loadSubgraph(const onert_tflite::SubGraph *tflite_subg)
83 auto subg = std::make_unique<ir::Graph>();
// One _tensor_to_operand slot per tensor; each slot receives the result of loadOperand.
85 _tensor_to_operand.resize(tflite_subg->tensors()->size());
86 for (flatbuffers::uoffset_t i = 0; i < tflite_subg->tensors()->size(); ++i)
88 _tensor_to_operand[i] = loadOperand(tflite_subg->tensors()->Get(i), *subg);
// Graph inputs: operand index from tensorIdxToOperandIdx, name looked up in _tensor_names
// by the operand stored for that tensor index.
91 for (const std::int32_t input_ind : *tflite_subg->inputs())
93 subg->addInput(tensorIdxToOperandIdx(input_ind),
94 _tensor_names.at(_tensor_to_operand[input_ind]));
// Graph outputs: same scheme as the inputs.
97 for (const std::int32_t output_ind : *tflite_subg->outputs())
99 subg->addOutput(tensorIdxToOperandIdx(output_ind),
100 _tensor_names.at(_tensor_to_operand[output_ind]));
// Load each operator of the subgraph into the graph.
103 for (const auto *op : *tflite_subg->operators())
105 loadOperation(op, *subg);
// Signal to the graph that construction is complete.
108 subg->finishBuilding();
// Entry point for loading a TFLite model file: creates an empty ir::Subgraphs container,
// constructs a TFLiteLoader over it and asks the loader to load `filename`.
// filename: path handed to TFLiteLoader::loadFromFile.
// NOTE(review): the declared return type is std::unique_ptr<ir::Subgraphs>; the return
// statement is not visible in this listing — presumably `subgraphs` is returned; confirm.
116 std::unique_ptr<ir::Subgraphs> loadModel(const char *filename)
118 auto subgraphs = std::make_unique<ir::Subgraphs>();
119 TFLiteLoader loader(subgraphs);
120 loader.loadFromFile(filename);
124 } // namespace tflite_loader