Imported Upstream version 1.7.0
[platform/core/ml/nnfw.git] / runtime / onert / frontend / circle / src / circle_loader.cc
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "circle_loader.h"
18 #include "base_loader.h"
19 #include "circle_schema_generated.h"
20
21 namespace onert
22 {
23 namespace circle_loader
24 {
25
26 namespace
27 {
28
29 ir::Layout convertDataFormat(circle::DataFormat data_format)
30 {
31   switch (data_format)
32   {
33     case circle::DataFormat::DataFormat_CHANNELS_FIRST:
34       return ir::Layout::NCHW;
35     case circle::DataFormat::DataFormat_CHANNELS_LAST:
36       return ir::Layout::NHWC;
37     default:
38       throw std::runtime_error("Unsupported DataFormat");
39   }
40 }
41
42 struct LoaderDomain
43 {
44   using Verifier = flatbuffers::Verifier;
45   using ActivationFunctionType = circle::ActivationFunctionType;
46   using Buffer = circle::Buffer;
47   using BuiltinOperator = circle::BuiltinOperator;
48   using CustomOptionsFormat = circle::CustomOptionsFormat;
49   using Model = circle::Model;
50   using Operator = circle::Operator;
51   using Padding = circle::Padding;
52   using Pool2DOptions = circle::Pool2DOptions;
53   using Tensor = circle::Tensor;
54   using TensorType = circle::TensorType;
55   using SubGraph = circle::SubGraph;
56
57   static const char *EnumNameBuiltinOperator(BuiltinOperator e)
58   {
59     return circle::EnumNameBuiltinOperator(e);
60   }
61   static const char *EnumNameActivationFunctionType(ActivationFunctionType e)
62   {
63     return circle::EnumNameActivationFunctionType(e);
64   }
65   static const char *EnumNameTensorType(TensorType e) { return circle::EnumNameTensorType(e); }
66   static const Model *GetModel(const void *buf) { return circle::GetModel(buf); }
67   static bool VerifyModelBuffer(Verifier &verifier) { return circle::VerifyModelBuffer(verifier); }
68 };
69
70 class CircleLoader final : public base_loader::BaseLoader<LoaderDomain, CircleLoader>
71 {
72 public:
73   using BaseLoader::BaseLoader;
74
75   bool allowOptionalInputTensor(BuiltinOperator op) override
76   {
77     switch (op)
78     {
79       case BuiltinOperator::BuiltinOperator_FULLY_CONNECTED:
80       case BuiltinOperator::BuiltinOperator_BCQ_FULLY_CONNECTED:
81         return true;
82       default:
83         return false;
84     }
85   }
86
87   std::unique_ptr<ir::Graph> loadSubgraph(const circle::SubGraph *circle_subg)
88   {
89     auto subg = std::make_unique<ir::Graph>();
90     // Load tensors
91     _tensor_to_operand.resize(circle_subg->tensors()->size());
92     for (flatbuffers::uoffset_t i = 0; i < circle_subg->tensors()->size(); ++i)
93     {
94       _tensor_to_operand[i] = loadOperand(circle_subg->tensors()->Get(i), *subg);
95     }
96     // Set inputs
97     for (const std::int32_t input_ind : *circle_subg->inputs())
98     {
99       subg->addInput(tensorIdxToOperandIdx(input_ind));
100     }
101     // Set outputs
102     for (const std::int32_t output_ind : *circle_subg->outputs())
103     {
104       subg->addOutput(tensorIdxToOperandIdx(output_ind));
105     }
106     // Create operations
107     for (const auto *op : *circle_subg->operators())
108     {
109       CircleLoader::loadOperation(op, *subg);
110     }
111
112     subg->setLayout(convertDataFormat(circle_subg->data_format()));
113
114     subg->finishBuilding();
115
116     return subg;
117   }
118
119   void loadOperation(const circle::Operator *op, ir::Graph &subg)
120   {
121     const auto builtin_op = _model->operator_codes()->Get(op->opcode_index())->builtin_code();
122
123     switch (builtin_op)
124     {
125       case circle::BuiltinOperator::BuiltinOperator_INSTANCE_NORM:
126         loadInstanceNorm(op, subg);
127         return;
128       case circle::BuiltinOperator::BuiltinOperator_BCQ_FULLY_CONNECTED:
129         loadBCQFullyConnected(op, subg);
130         return;
131       case circle::BuiltinOperator::BuiltinOperator_BCQ_GATHER:
132         loadBCQGather(op, subg);
133         return;
134       default:
135         BaseLoader::loadOperation(op, subg);
136         return;
137     }
138   }
139 };
140
141 } // namespace
142
143 std::unique_ptr<ir::Subgraphs> loadModel(const char *filename)
144 {
145   auto subgraphs = std::make_unique<ir::Subgraphs>();
146   CircleLoader loader(subgraphs);
147   loader.loadFromFile(filename);
148   return subgraphs;
149 }
150
151 } // namespace circle_loader
152 } // namespace onert