#include "operators.h"
#include <iostream>
+#include <chrono>
using namespace tflite;
using namespace tflite::ops::builtin;
+// Benchmark support
+namespace benchmark
+{
+
+class ElapsedTime;
+
+class Stopwatch
+{
+public:
+ Stopwatch(ElapsedTime *buffer) : _buffer(buffer)
+ {
+ // DO NOTHING
+ }
+
+public:
+ ElapsedTime *buffer(void) { return _buffer; }
+
+private:
+ ElapsedTime *_buffer;
+};
+
+class ElapsedTime
+{
+public:
+ double count(void) const { return _elapsed.count(); }
+
+public:
+ ElapsedTime &add(const std::chrono::duration<double> &elapsed)
+ {
+ _elapsed += elapsed;
+ return (*this);
+ }
+
+public:
+ Stopwatch measure(void) { return Stopwatch(this); }
+
+private:
+ std::chrono::duration<double> _elapsed;
+};
+
+template <typename Callable> Stopwatch &operator<<(Stopwatch &&sw, Callable cb)
+{
+ using namespace std::chrono;
+
+ auto begin = steady_clock::now();
+ cb();
+ auto end = steady_clock::now();
+
+ sw.buffer()->add(duration_cast<duration<double>>(end - begin));
+
+ return sw;
+}
+
+} // namespace benchmark
+
int main(int argc, char **argv)
{
const auto filename = argv[1];
auto model = FlatBufferModel::BuildFromFile(filename, &error_reporter);
- BuiltinOpResolver resolver;
+ std::unique_ptr<Interpreter> interpreter;
+
+ TfLiteStatus status = kTfLiteError;
+
+ benchmark::ElapsedTime t_prepare;
+ benchmark::ElapsedTime t_invoke;
+
+ t_prepare.measure() << [&](void)
+ {
+ BuiltinOpResolver resolver;
#define REGISTER(Name) { resolver.AddCustom(#Name, Register_##Name()); }
- REGISTER(CAST);
- REGISTER(Stack);
- REGISTER(ArgMax);
- REGISTER(TensorFlowMax);
+ REGISTER(CAST);
+ REGISTER(Stack);
+ REGISTER(ArgMax);
+ REGISTER(TensorFlowMax);
#undef REGISTER
- InterpreterBuilder builder(*model, resolver);
+ InterpreterBuilder builder(*model, resolver);
- std::unique_ptr<Interpreter> interpreter;
-
- TfLiteStatus status = kTfLiteError;
+ status = builder(&interpreter);
+ assert(status == kTfLiteOk);
- status = builder(&interpreter);
- assert(status == kTfLiteOk);
+ interpreter->UseNNAPI(use_nnapi);
+ interpreter->SetNumThreads(1);
- interpreter->UseNNAPI(use_nnapi);
- interpreter->SetNumThreads(1);
+ status = interpreter->AllocateTensors();
+ assert(status == kTfLiteOk);
+ };
- status = interpreter->AllocateTensors();
- assert(status == kTfLiteOk);
+ t_invoke.measure() << [&status, &interpreter](void)
+ {
+ status = interpreter->Invoke();
+ assert(status == kTfLiteOk);
+ };
- status = interpreter->Invoke();
- assert(status == kTfLiteOk);
+ std::cout << "Prepare takes " << t_prepare.count() << " seconds" << std::endl;
+ std::cout << "Invoke takes " << t_invoke.count() << " seconds" << std::endl;
return status;
}