/*
 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "InferenceInterface.h"

#include <algorithm>
#include <cassert>
#include <stdexcept>
19 using namespace tflite;
20 using namespace tflite::ops::builtin;
22 InferenceInterface::InferenceInterface(const std::string &model_file, const bool use_nnapi)
23 : _interpreter(nullptr), _model(nullptr), _sess(nullptr)
26 StderrReporter error_reporter;
27 _model = FlatBufferModel::BuildFromFile(model_file.c_str(), &error_reporter);
28 BuiltinOpResolver resolver;
29 InterpreterBuilder builder(*_model, resolver);
30 builder(&_interpreter);
34 _sess = std::make_shared<nnfw::tflite::NNAPISession>(_interpreter.get());
38 _sess = std::make_shared<nnfw::tflite::InterpreterSession>(_interpreter.get());
44 InferenceInterface::~InferenceInterface() { _sess->teardown(); }
46 void InferenceInterface::feed(const std::string &input_name, const std::vector<float> &data,
47 const int batch, const int height, const int width, const int channel)
50 for (const auto &id : _interpreter->inputs())
52 if (_interpreter->tensor(id)->name == input_name)
54 assert(_interpreter->tensor(id)->type == kTfLiteFloat32);
55 float *p = _interpreter->tensor(id)->data.f;
57 // TODO consider batch
58 for (int y = 0; y < height; ++y)
60 for (int x = 0; x < width; ++x)
62 for (int c = 0; c < channel; ++c)
64 *p++ = data[y * width * channel + x * channel + c];
// Runs inference for the given output.
// NOTE(review): the body is not visible in this chunk — presumably it delegates
// to _sess->run() (the session is set up in the ctor and torn down in the dtor);
// confirm against the full source.
void InferenceInterface::run(const std::string &output_name)
78 void InferenceInterface::fetch(const std::string &output_name, std::vector<float> &outputs)
81 for (const auto &id : _interpreter->outputs())
83 if (_interpreter->tensor(id)->name == output_name)
85 assert(_interpreter->tensor(id)->type == kTfLiteFloat32);
86 assert(getTensorSize(output_name) == outputs.capacity());
87 float *p = _interpreter->tensor(id)->data.f;
90 for (int i = 0; i < outputs.capacity(); ++i)
92 outputs.push_back(p[i]);
98 int InferenceInterface::getTensorSize(const std::string &name)
100 for (const auto &id : _interpreter->outputs())
102 if (_interpreter->tensor(id)->name == name)
104 TfLiteTensor *t = _interpreter->tensor(id);
106 for (int i = 0; i < t->dims->size; ++i)
108 v *= t->dims->data[i];