2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the License);
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an AS IS BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"

#include "tflite/Assert.h"
#include "tflite/Session.h"
#include "tflite/InterpreterSession.h"
#include "tflite/NNAPISession.h"
#include "tflite/ext/kernels/register.h"

#include "misc/benchmark.h"

#include <boost/accumulators/accumulators.hpp>
#include <boost/accumulators/statistics/stats.hpp>
#include <boost/accumulators/statistics/mean.hpp>
#include <boost/accumulators/statistics/min.hpp>
#include <boost/accumulators/statistics/max.hpp>

#include <cassert>
#include <chrono>
#include <cstdint>
#include <exception>
#include <functional>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>

#include <android/log.h>
42 using namespace tflite;
43 using namespace tflite::ops::builtin;
45 static StderrReporter error_reporter;
47 static std::unique_ptr<FlatBufferModel> model;
49 inline void setText(JNIEnv *env, jobject thisObj, const std::string &message)
51 jclass thisClass = env->GetObjectClass(thisObj);
52 jmethodID setTextMethod = env->GetMethodID(thisClass, "setText", "(Ljava/lang/String;)V");
54 assert(setTextMethod != nullptr);
56 env->CallVoidMethod(thisObj, setTextMethod, env->NewStringUTF(message.c_str()));
59 inline void setTitle(JNIEnv *env, jobject thisObj, const std::string &message)
61 jclass thisClass = env->GetObjectClass(thisObj);
62 jmethodID setTextMethod = env->GetMethodID(thisClass, "setTitle", "(Ljava/lang/String;)V");
64 assert(setTextMethod != nullptr);
66 env->CallVoidMethod(thisObj, setTextMethod, env->NewStringUTF(message.c_str()));
69 setText(env, thisObj, "");
72 inline void setText(JNIEnv *env, jobject thisObj, const std::stringstream &ss)
74 setText(env, thisObj, ss.str());
77 inline std::unique_ptr<FlatBufferModel> loadModel(JNIEnv *env, jobject thisObj,
80 const char *model_base = static_cast<char *>(env->GetDirectBufferAddress(model_buffer));
81 jlong model_size = env->GetDirectBufferCapacity(model_buffer);
83 return FlatBufferModel::BuildFromBuffer(model_base, static_cast<size_t>(model_size),
/**
 * Abstract unit of benchmarked work with a three-phase lifecycle: prepare, run, teardown.
 *
 * All hooks are const. The destructor is virtual, so deleting a concrete activity through an
 * Activity pointer is well defined.
 */
struct Activity
{
  virtual ~Activity() = default;

  /// Setup hook for a trial.
  virtual void prepare() const = 0;

  /// The work being benchmarked.
  virtual void run() const = 0;

  /// Cleanup hook for a trial.
  virtual void teardown() const = 0;
};
96 struct LiteActivity final : public Activity
99 LiteActivity(nnfw::tflite::Session &sess) : _sess(sess)
105 void prepare(void) const override { _sess.prepare(); }
106 void run(void) const override { _sess.run(); }
107 void teardown(void) const override { _sess.teardown(); }
110 nnfw::tflite::Session &_sess;
113 struct SimpleActivity final : public Activity
116 SimpleActivity(const std::function<void(void)> &fn) : _fn{fn}
122 void prepare(void) const override {}
123 void run(void) const override { _fn(); }
124 void teardown(void) const override {}
127 std::function<void(void)> _fn;
130 inline void runBenchmark(JNIEnv *env, jobject thisObj, Activity &act)
132 auto runTrial = [&](void) {
133 std::chrono::milliseconds elapsed(0);
136 nnfw::misc::benchmark::measure(elapsed) << [&](void) { act.run(); };
143 for (uint32_t n = 0; n < 3; ++n)
145 auto elapsed = runTrial();
147 std::stringstream ss;
148 ss << "Warm-up #" << n << " takes " << elapsed.count() << "ms" << std::endl;
149 setText(env, thisObj, ss);
153 using namespace boost::accumulators;
155 accumulator_set<double, stats<tag::mean, tag::min, tag::max>> acc;
157 for (uint32_t n = 0; n < 100; ++n)
159 auto elapsed = runTrial();
161 std::stringstream ss;
162 ss << "Iteration #" << n << " takes " << elapsed.count() << "ms" << std::endl;
163 setText(env, thisObj, ss);
165 acc(elapsed.count());
168 std::stringstream ss;
169 ss << "Average is " << mean(acc) << "ms" << std::endl;
170 ss << "Min is " << min(acc) << "ms" << std::endl;
171 ss << "Max is " << max(acc) << "ms" << std::endl;
172 setText(env, thisObj, ss);
175 JNIEXPORT void JNICALL Java_com_ndk_tflbench_MainActivity_runInterpreterBenchmark(
176 JNIEnv *env, jobject thisObj, jobject model_buffer)
178 setTitle(env, thisObj, "Running Interpreter Benchmark");
180 auto model = loadModel(env, thisObj, model_buffer);
181 assert(model != nullptr);
183 nnfw::tflite::BuiltinOpResolver resolver;
184 InterpreterBuilder builder(*model, resolver);
186 std::unique_ptr<Interpreter> interpreter;
188 TFLITE_ENSURE(builder(&interpreter));
190 interpreter->SetNumThreads(-1);
192 nnfw::tflite::InterpreterSession sess(interpreter.get());
193 LiteActivity act{sess};
194 runBenchmark(env, thisObj, act);
197 static void runNNAPIBenchmark(JNIEnv *env, jobject thisObj, jobject model_buffer)
199 auto model = loadModel(env, thisObj, model_buffer);
200 assert(model != nullptr);
202 nnfw::tflite::BuiltinOpResolver resolver;
203 InterpreterBuilder builder(*model, resolver);
205 std::unique_ptr<Interpreter> interpreter;
207 TFLITE_ENSURE(builder(&interpreter));
209 nnfw::tflite::NNAPISession sess(interpreter.get());
210 LiteActivity act{sess};
211 runBenchmark(env, thisObj, act);
214 JNIEXPORT void JNICALL Java_com_ndk_tflbench_MainActivity_runNNAPIBenchmark(JNIEnv *env,
216 jobject model_buffer)
218 setTitle(env, thisObj, "Running NNAPI Benchmark");
222 runNNAPIBenchmark(env, thisObj, model_buffer);
224 catch (const std::exception &ex)
226 std::stringstream ss;
227 ss << "Caught an exception " << ex.what();
228 setText(env, thisObj, ss);
232 JNIEXPORT jstring JNICALL Java_com_ndk_tflbench_MainActivity_getModelName(JNIEnv *env,
235 return env->NewStringUTF(MODEL_NAME);
238 #define TF_ENSURE(e) \
242 throw std::runtime_error{"'" #e "' FAILED"}; \