Imported Upstream version 1.12.0
[platform/core/ml/nnfw.git] / runtime / contrib / tflite_classify / src / InferenceInterface.cc
1 /*
2  * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "InferenceInterface.h"
18
19 using namespace tflite;
20 using namespace tflite::ops::builtin;
21
22 InferenceInterface::InferenceInterface(const std::string &model_file, const bool use_nnapi)
23   : _interpreter(nullptr), _model(nullptr), _sess(nullptr)
24 {
25   // Load model
26   StderrReporter error_reporter;
27   _model = FlatBufferModel::BuildFromFile(model_file.c_str(), &error_reporter);
28   BuiltinOpResolver resolver;
29   InterpreterBuilder builder(*_model, resolver);
30   builder(&_interpreter);
31
32   if (use_nnapi)
33   {
34     _sess = std::make_shared<nnfw::tflite::NNAPISession>(_interpreter.get());
35   }
36   else
37   {
38     _sess = std::make_shared<nnfw::tflite::InterpreterSession>(_interpreter.get());
39   }
40
41   _sess->prepare();
42 }
43
44 InferenceInterface::~InferenceInterface() { _sess->teardown(); }
45
46 void InferenceInterface::feed(const std::string &input_name, const std::vector<float> &data,
47                               const int batch, const int height, const int width, const int channel)
48 {
49   // Set input tensor
50   for (const auto &id : _interpreter->inputs())
51   {
52     if (_interpreter->tensor(id)->name == input_name)
53     {
54       assert(_interpreter->tensor(id)->type == kTfLiteFloat32);
55       float *p = _interpreter->tensor(id)->data.f;
56
57       // TODO consider batch
58       for (int y = 0; y < height; ++y)
59       {
60         for (int x = 0; x < width; ++x)
61         {
62           for (int c = 0; c < channel; ++c)
63           {
64             *p++ = data[y * width * channel + x * channel + c];
65           }
66         }
67       }
68     }
69   }
70 }
71
72 void InferenceInterface::run(const std::string &output_name)
73 {
74   // Run model
75   _sess->run();
76 }
77
78 void InferenceInterface::fetch(const std::string &output_name, std::vector<float> &outputs)
79 {
80   // Get output tensor
81   for (const auto &id : _interpreter->outputs())
82   {
83     if (_interpreter->tensor(id)->name == output_name)
84     {
85       assert(_interpreter->tensor(id)->type == kTfLiteFloat32);
86       assert(getTensorSize(output_name) == outputs.capacity());
87       float *p = _interpreter->tensor(id)->data.f;
88
89       outputs.clear();
90       for (int i = 0; i < outputs.capacity(); ++i)
91       {
92         outputs.push_back(p[i]);
93       }
94     }
95   }
96 }
97
98 int InferenceInterface::getTensorSize(const std::string &name)
99 {
100   for (const auto &id : _interpreter->outputs())
101   {
102     if (_interpreter->tensor(id)->name == name)
103     {
104       TfLiteTensor *t = _interpreter->tensor(id);
105       int v = 1;
106       for (int i = 0; i < t->dims->size; ++i)
107       {
108         v *= t->dims->data[i];
109       }
110       return v;
111     }
112   }
113   return -1;
114 }