/*
 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 #include "circle_loader.h"
18 #include "base_loader.h"
19 #include "circle_schema_generated.h"
23 namespace circle_loader
29 ir::Layout convertDataFormat(circle::DataFormat data_format)
33 case circle::DataFormat::DataFormat_CHANNELS_FIRST:
34 return ir::Layout::NCHW;
35 case circle::DataFormat::DataFormat_CHANNELS_LAST:
36 return ir::Layout::NHWC;
38 throw std::runtime_error("Unsupported DataFormat");
44 using Verifier = flatbuffers::Verifier;
45 using ActivationFunctionType = circle::ActivationFunctionType;
46 using Buffer = circle::Buffer;
47 using BuiltinOperator = circle::BuiltinOperator;
48 using CustomOptionsFormat = circle::CustomOptionsFormat;
49 using Model = circle::Model;
50 using Operator = circle::Operator;
51 using Padding = circle::Padding;
52 using Pool2DOptions = circle::Pool2DOptions;
53 using Tensor = circle::Tensor;
54 using TensorType = circle::TensorType;
55 using SubGraph = circle::SubGraph;
56 using DimensionType = circle::DimensionType;
57 using SparseIndexVector = circle::SparseIndexVector;
59 static const char *EnumNameBuiltinOperator(BuiltinOperator e)
61 return circle::EnumNameBuiltinOperator(e);
63 static const char *EnumNameActivationFunctionType(ActivationFunctionType e)
65 return circle::EnumNameActivationFunctionType(e);
67 static const char *EnumNameTensorType(TensorType e) { return circle::EnumNameTensorType(e); }
68 static const Model *GetModel(const void *buf) { return circle::GetModel(buf); }
69 static bool VerifyModelBuffer(Verifier &verifier) { return circle::VerifyModelBuffer(verifier); }
72 class CircleLoader final : public base_loader::BaseLoader<LoaderDomain, CircleLoader>
75 void loadInstanceNorm(const Operator *op, ir::Graph &subg);
76 void loadBCQFullyConnected(const Operator *op, ir::Graph &subg);
77 void loadBCQGather(const Operator *op, ir::Graph &subg);
80 using BaseLoader::BaseLoader;
82 bool allowOptionalInputTensor(BuiltinOperator op) override
86 case BuiltinOperator::BuiltinOperator_FULLY_CONNECTED:
87 case BuiltinOperator::BuiltinOperator_BCQ_FULLY_CONNECTED:
94 std::unique_ptr<ir::Graph> loadSubgraph(const circle::SubGraph *circle_subg)
96 auto subg = std::make_unique<ir::Graph>();
98 _tensor_to_operand.resize(circle_subg->tensors()->size());
99 for (flatbuffers::uoffset_t i = 0; i < circle_subg->tensors()->size(); ++i)
101 _tensor_to_operand[i] = loadOperand(circle_subg->tensors()->Get(i), *subg);
104 for (const std::int32_t input_ind : *circle_subg->inputs())
106 subg->addInput(tensorIdxToOperandIdx(input_ind),
107 _tensor_names.at(_tensor_to_operand[input_ind]));
110 for (const std::int32_t output_ind : *circle_subg->outputs())
112 subg->addOutput(tensorIdxToOperandIdx(output_ind),
113 _tensor_names.at(_tensor_to_operand[output_ind]));
116 for (const auto *op : *circle_subg->operators())
118 CircleLoader::loadOperation(op, *subg);
121 subg->setLayout(convertDataFormat(circle_subg->data_format()));
123 subg->finishBuilding();
128 void loadOperation(const circle::Operator *op, ir::Graph &subg)
130 const auto builtin_op = _model->operator_codes()->Get(op->opcode_index())->builtin_code();
134 case circle::BuiltinOperator::BuiltinOperator_INSTANCE_NORM:
135 loadInstanceNorm(op, subg);
137 case circle::BuiltinOperator::BuiltinOperator_BCQ_FULLY_CONNECTED:
138 loadBCQFullyConnected(op, subg);
140 case circle::BuiltinOperator::BuiltinOperator_BCQ_GATHER:
141 loadBCQGather(op, subg);
144 BaseLoader::loadOperation(op, subg);
150 void CircleLoader::loadInstanceNorm(const Operator *op, ir::Graph &subg)
152 ir::OperandIndexSequence inputs;
153 ir::OperandIndexSequence outputs;
155 loadOperationIO(op, inputs, outputs);
157 ir::operation::InstanceNorm::Param param;
158 const auto *options = op->builtin_options_as_InstanceNormOptions();
160 param.activation = convertActivation(options->fused_activation_function());
161 // Use default value 1e-5 if value of epsilon is zero
162 param.epsilon = options->epsilon() == 0.f ? 1e-5 : options->epsilon();
164 std::unique_ptr<ir::Operation> new_op(new ir::operation::InstanceNorm(inputs, outputs, param));
165 subg.addOperation(std::move(new_op));
168 void CircleLoader::loadBCQGather(const Operator *op, ir::Graph &subg)
170 ir::OperandIndexSequence inputs;
171 ir::OperandIndexSequence outputs;
173 loadOperationIO(op, inputs, outputs);
175 ir::operation::BCQGather::Param param;
176 const auto *options = op->builtin_options_as_BCQGatherOptions();
177 param.input_hidden_size = options->input_hidden_size();
178 param.axis = options->axis();
180 std::unique_ptr<ir::Operation> new_op(new ir::operation::BCQGather(inputs, outputs, param));
181 subg.addOperation(std::move(new_op));
184 void CircleLoader::loadBCQFullyConnected(const Operator *op, ir::Graph &subg)
186 ir::OperandIndexSequence inputs;
187 ir::OperandIndexSequence outputs;
189 loadOperationIO(op, inputs, outputs);
191 ir::operation::BCQFullyConnected::Param param;
192 const auto *options = op->builtin_options_as_BCQFullyConnectedOptions();
193 param.weights_hidden_size = options->weights_hidden_size();
194 param.activation = convertActivation(options->fused_activation_function());
196 std::unique_ptr<ir::Operation> new_op(
197 new ir::operation::BCQFullyConnected(inputs, outputs, param));
198 subg.addOperation(std::move(new_op));
203 std::unique_ptr<ir::Subgraphs> loadModel(const char *filename)
205 auto subgraphs = std::make_unique<ir::Subgraphs>();
206 CircleLoader loader(subgraphs);
207 loader.loadFromFile(filename);
211 std::unique_ptr<ir::Subgraphs> loadModel(uint8_t *buffer, size_t size)
213 auto subgraphs = std::make_unique<ir::Subgraphs>();
214 CircleLoader loader(subgraphs);
215 loader.loadFromBuffer(buffer, size);
219 } // namespace circle_loader