Imported Upstream version 1.9.0
[platform/core/ml/nnfw.git] / runtime / onert / frontend / tflite / src / tflite_loader.cc
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
#include "tflite_loader.h"
#include "base_loader.h"
#include "tflite_schema_generated.h"

#include <cstdint>
#include <memory>
#include <stdexcept>
#include <string>
20
21 namespace onert
22 {
23 namespace tflite_loader
24 {
25
26 namespace
27 {
28
29 struct LoaderDomain
30 {
31   using Verifier = flatbuffers::Verifier;
32   using ActivationFunctionType = onert_tflite::ActivationFunctionType;
33   using Buffer = onert_tflite::Buffer;
34   using BuiltinOperator = onert_tflite::BuiltinOperator;
35   using CustomOptionsFormat = onert_tflite::CustomOptionsFormat;
36   using Model = onert_tflite::Model;
37   using Operator = onert_tflite::Operator;
38   using Padding = onert_tflite::Padding;
39   using Pool2DOptions = onert_tflite::Pool2DOptions;
40   using Tensor = onert_tflite::Tensor;
41   using TensorType = onert_tflite::TensorType;
42   using SubGraph = onert_tflite::SubGraph;
43   using DimensionType = onert_tflite::DimensionType;
44   using SparseIndexVector = onert_tflite::SparseIndexVector;
45
46   static const char *EnumNameBuiltinOperator(BuiltinOperator e)
47   {
48     return onert_tflite::EnumNameBuiltinOperator(e);
49   }
50   static const char *EnumNameActivationFunctionType(ActivationFunctionType e)
51   {
52     return onert_tflite::EnumNameActivationFunctionType(e);
53   }
54   static const char *EnumNameTensorType(TensorType e)
55   {
56     return onert_tflite::EnumNameTensorType(e);
57   }
58   static const Model *GetModel(const void *buf) { return onert_tflite::GetModel(buf); }
59   static bool VerifyModelBuffer(Verifier &verifier)
60   {
61     return onert_tflite::VerifyModelBuffer(verifier);
62   }
63 };
64
65 class TFLiteLoader final : public base_loader::BaseLoader<LoaderDomain, TFLiteLoader>
66 {
67 public:
68   using BaseLoader::BaseLoader;
69
70   bool allowOptionalInputTensor(BuiltinOperator op) override
71   {
72     switch (op)
73     {
74       case BuiltinOperator::BuiltinOperator_FULLY_CONNECTED:
75         return true;
76       default:
77         return false;
78     }
79   }
80
81   std::unique_ptr<ir::Graph> loadSubgraph(const onert_tflite::SubGraph *tflite_subg)
82   {
83     auto subg = std::make_unique<ir::Graph>();
84     // Load tensors
85     _tensor_to_operand.resize(tflite_subg->tensors()->size());
86     for (flatbuffers::uoffset_t i = 0; i < tflite_subg->tensors()->size(); ++i)
87     {
88       _tensor_to_operand[i] = loadOperand(tflite_subg->tensors()->Get(i), *subg);
89     }
90     // Set inputs
91     for (const std::int32_t input_ind : *tflite_subg->inputs())
92     {
93       subg->addInput(tensorIdxToOperandIdx(input_ind),
94                      _tensor_names.at(_tensor_to_operand[input_ind]));
95     }
96     // Set outputs
97     for (const std::int32_t output_ind : *tflite_subg->outputs())
98     {
99       subg->addOutput(tensorIdxToOperandIdx(output_ind),
100                       _tensor_names.at(_tensor_to_operand[output_ind]));
101     }
102     // Create operations
103     for (const auto *op : *tflite_subg->operators())
104     {
105       loadOperation(op, *subg);
106     }
107
108     subg->finishBuilding();
109
110     return subg;
111   }
112 };
113
114 } // namespace
115
116 std::unique_ptr<ir::Subgraphs> loadModel(const char *filename)
117 {
118   auto subgraphs = std::make_unique<ir::Subgraphs>();
119   TFLiteLoader loader(subgraphs);
120   loader.loadFromFile(filename);
121   return subgraphs;
122 }
123
124 } // namespace tflite_loader
125 } // namespace onert