Imported Upstream version 1.12.0
[platform/core/ml/nnfw.git] / runtime / contrib / android_benchmark_app / cpp / ndk_main.cpp
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the License);
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an AS IS BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "ndk_main.h"
18
19 #include "tensorflow/lite/kernels/register.h"
20 #include "tensorflow/lite/model.h"
21
22 #include "tflite/Assert.h"
23 #include "tflite/Session.h"
24 #include "tflite/InterpreterSession.h"
25 #include "tflite/NNAPISession.h"
26 #include "tflite/ext/kernels/register.h"
27
28 #include "misc/benchmark.h"
29
30 #include <boost/accumulators/accumulators.hpp>
31 #include <boost/accumulators/statistics/stats.hpp>
32 #include <boost/accumulators/statistics/mean.hpp>
33 #include <boost/accumulators/statistics/min.hpp>
34 #include <boost/accumulators/statistics/max.hpp>
35
36 #include <cassert>
37 #include <chrono>
38 #include <sstream>
39
40 #include <android/log.h>
41
42 using namespace tflite;
43 using namespace tflite::ops::builtin;
44
45 static StderrReporter error_reporter;
46
47 static std::unique_ptr<FlatBufferModel> model;
48
49 inline void setText(JNIEnv *env, jobject thisObj, const std::string &message)
50 {
51   jclass thisClass = env->GetObjectClass(thisObj);
52   jmethodID setTextMethod = env->GetMethodID(thisClass, "setText", "(Ljava/lang/String;)V");
53
54   assert(setTextMethod != nullptr);
55
56   env->CallVoidMethod(thisObj, setTextMethod, env->NewStringUTF(message.c_str()));
57 }
58
59 inline void setTitle(JNIEnv *env, jobject thisObj, const std::string &message)
60 {
61   jclass thisClass = env->GetObjectClass(thisObj);
62   jmethodID setTextMethod = env->GetMethodID(thisClass, "setTitle", "(Ljava/lang/String;)V");
63
64   assert(setTextMethod != nullptr);
65
66   env->CallVoidMethod(thisObj, setTextMethod, env->NewStringUTF(message.c_str()));
67
68   // Clear message
69   setText(env, thisObj, "");
70 }
71
72 inline void setText(JNIEnv *env, jobject thisObj, const std::stringstream &ss)
73 {
74   setText(env, thisObj, ss.str());
75 }
76
77 inline std::unique_ptr<FlatBufferModel> loadModel(JNIEnv *env, jobject thisObj,
78                                                   jobject model_buffer)
79 {
80   const char *model_base = static_cast<char *>(env->GetDirectBufferAddress(model_buffer));
81   jlong model_size = env->GetDirectBufferCapacity(model_buffer);
82
83   return FlatBufferModel::BuildFromBuffer(model_base, static_cast<size_t>(model_size),
84                                           &error_reporter);
85 }
86
87 struct Activity
88 {
89   virtual ~Activity() = default;
90
91   virtual void prepare(void) const = 0;
92   virtual void run(void) const = 0;
93   virtual void teardown(void) const = 0;
94 };
95
96 struct LiteActivity final : public Activity
97 {
98 public:
99   LiteActivity(nnfw::tflite::Session &sess) : _sess(sess)
100   {
101     // DO NOTHING
102   }
103
104 public:
105   void prepare(void) const override { _sess.prepare(); }
106   void run(void) const override { _sess.run(); }
107   void teardown(void) const override { _sess.teardown(); }
108
109 private:
110   nnfw::tflite::Session &_sess;
111 };
112
113 struct SimpleActivity final : public Activity
114 {
115 public:
116   SimpleActivity(const std::function<void(void)> &fn) : _fn{fn}
117   {
118     // DO NOTHING
119   }
120
121 public:
122   void prepare(void) const override {}
123   void run(void) const override { _fn(); }
124   void teardown(void) const override {}
125
126 private:
127   std::function<void(void)> _fn;
128 };
129
130 inline void runBenchmark(JNIEnv *env, jobject thisObj, Activity &act)
131 {
132   auto runTrial = [&](void) {
133     std::chrono::milliseconds elapsed(0);
134
135     act.prepare();
136     nnfw::misc::benchmark::measure(elapsed) << [&](void) { act.run(); };
137     act.teardown();
138
139     return elapsed;
140   };
141
142   // Warm-up
143   for (uint32_t n = 0; n < 3; ++n)
144   {
145     auto elapsed = runTrial();
146
147     std::stringstream ss;
148     ss << "Warm-up #" << n << "  takes " << elapsed.count() << "ms" << std::endl;
149     setText(env, thisObj, ss);
150   }
151
152   // Measure
153   using namespace boost::accumulators;
154
155   accumulator_set<double, stats<tag::mean, tag::min, tag::max>> acc;
156
157   for (uint32_t n = 0; n < 100; ++n)
158   {
159     auto elapsed = runTrial();
160
161     std::stringstream ss;
162     ss << "Iteration #" << n << " takes " << elapsed.count() << "ms" << std::endl;
163     setText(env, thisObj, ss);
164
165     acc(elapsed.count());
166   }
167
168   std::stringstream ss;
169   ss << "Average is " << mean(acc) << "ms" << std::endl;
170   ss << "Min is " << min(acc) << "ms" << std::endl;
171   ss << "Max is " << max(acc) << "ms" << std::endl;
172   setText(env, thisObj, ss);
173 }
174
175 JNIEXPORT void JNICALL Java_com_ndk_tflbench_MainActivity_runInterpreterBenchmark(
176   JNIEnv *env, jobject thisObj, jobject model_buffer)
177 {
178   setTitle(env, thisObj, "Running Interpreter Benchmark");
179
180   auto model = loadModel(env, thisObj, model_buffer);
181   assert(model != nullptr);
182
183   nnfw::tflite::BuiltinOpResolver resolver;
184   InterpreterBuilder builder(*model, resolver);
185
186   std::unique_ptr<Interpreter> interpreter;
187
188   TFLITE_ENSURE(builder(&interpreter));
189
190   interpreter->SetNumThreads(-1);
191
192   nnfw::tflite::InterpreterSession sess(interpreter.get());
193   LiteActivity act{sess};
194   runBenchmark(env, thisObj, act);
195 }
196
197 static void runNNAPIBenchmark(JNIEnv *env, jobject thisObj, jobject model_buffer)
198 {
199   auto model = loadModel(env, thisObj, model_buffer);
200   assert(model != nullptr);
201
202   nnfw::tflite::BuiltinOpResolver resolver;
203   InterpreterBuilder builder(*model, resolver);
204
205   std::unique_ptr<Interpreter> interpreter;
206
207   TFLITE_ENSURE(builder(&interpreter));
208
209   nnfw::tflite::NNAPISession sess(interpreter.get());
210   LiteActivity act{sess};
211   runBenchmark(env, thisObj, act);
212 }
213
214 JNIEXPORT void JNICALL Java_com_ndk_tflbench_MainActivity_runNNAPIBenchmark(JNIEnv *env,
215                                                                             jobject thisObj,
216                                                                             jobject model_buffer)
217 {
218   setTitle(env, thisObj, "Running NNAPI Benchmark");
219
220   try
221   {
222     runNNAPIBenchmark(env, thisObj, model_buffer);
223   }
224   catch (const std::exception &ex)
225   {
226     std::stringstream ss;
227     ss << "Caught an exception " << ex.what();
228     setText(env, thisObj, ss);
229   }
230 }
231
232 JNIEXPORT jstring JNICALL Java_com_ndk_tflbench_MainActivity_getModelName(JNIEnv *env,
233                                                                           jobject thisObj)
234 {
235   return env->NewStringUTF(MODEL_NAME);
236 }
237
238 #define TF_ENSURE(e)                               \
239   {                                                \
240     if (!(e).ok())                                 \
241     {                                              \
242       throw std::runtime_error{"'" #e "' FAILED"}; \
243     }                                              \
244   }