tflite_benchmark_model is updated to v1.12.0. (#3660)
author이상규/동작제어Lab(SR)/Principal Engineer/삼성전자 <sg5.lee@samsung.com>
Mon, 26 Nov 2018 05:56:12 +0000 (14:56 +0900)
committer오형석/동작제어Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Mon, 26 Nov 2018 05:56:12 +0000 (14:56 +0900)
Most files are not changed from v1.12.0.
My modification is to support multiple kernel expansion operators.
You can find the changes from stats_calculator.cc and profile_summarizer.cc.

Signed-off-by: Sanggyu Lee <sg5.lee@samsung.com>
19 files changed:
externals/CMakeLists.txt
include/util/profiling/time.h
libs/util/src/profiling/.FORMATDENY [new file with mode: 0644]
libs/util/src/profiling/time.cc
tools/tflite_benchmark_model/CMakeLists.txt
tools/tflite_benchmark_model/README.md
tools/tflite_benchmark_model/benchmark_main.cc [deleted file]
tools/tflite_benchmark_model/benchmark_model.cc [deleted file]
tools/tflite_benchmark_model/benchmark_model.h [deleted file]
tools/tflite_benchmark_model/benchmark_params.cc [deleted file]
tools/tflite_benchmark_model/benchmark_params.h [deleted file]
tools/tflite_benchmark_model/benchmark_tflite_model.cc
tools/tflite_benchmark_model/benchmark_tflite_model.h [deleted file]
tools/tflite_benchmark_model/command_line_flags.cc [deleted file]
tools/tflite_benchmark_model/command_line_flags.h [deleted file]
tools/tflite_benchmark_model/logging.h [deleted file]
tools/tflite_benchmark_model/profile_summarizer.cc
tools/tflite_benchmark_model/profile_summarizer.h [deleted file]
tools/tflite_benchmark_model/stats_calculator.cc [new file with mode: 0644]

index c6f5b5b..65c9cd9 100644 (file)
@@ -37,15 +37,6 @@ list(APPEND TFLITE_SRCS ${TFLITE_API_SRCS})
 
 list(APPEND TFLITE_SRCS "${TFLITE_DEPEND_DIR}/farmhash/src/farmhash.cc")
 
-# Profiling
-if(BUILD_TFLITE_BENCHMARK_MODEL)
-  file(GLOB TFLITE_PROFILING_SRCS "${TENSORFLOW_LITE_BASE}/profiling/*.cc")
-  file(GLOB TFLITE_PROFILING_TESTS "${TENSORFLOW_LITE_BASE}/profiling/*test*.cc")
-  list(REMOVE_ITEM TFLITE_PROFILING_SRCS ${TFLITE_PROFILING_TESTS})
-  list(APPEND TFLITE_PROFILING_SRCS "${TENSORFLOW_BASE}/tensorflow/core/util/stats_calculator.cc")
-  list(APPEND TFLITE_SRCS ${TFLITE_PROFILING_SRCS})
-endif()
-
 list(APPEND TFLITE_INCLUDES "${CMAKE_CURRENT_SOURCE_DIR}/tensorflow")
 list(APPEND TFLITE_INCLUDES "${TFLITE_DEPEND_DIR}/absl")
 list(APPEND TFLITE_INCLUDES "${TFLITE_DEPEND_DIR}/gemmlowp")
index da32893..cc2ec31 100644 (file)
@@ -1,19 +1,3 @@
-/*
- * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/libs/util/src/profiling/.FORMATDENY b/libs/util/src/profiling/.FORMATDENY
new file mode 100644 (file)
index 0000000..e69de29
index 6fe1b54..6841c71 100644 (file)
@@ -1,19 +1,3 @@
-/*
- * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -30,20 +14,34 @@ limitations under the License.
 ==============================================================================*/
 #include "util/profiling/time.h"
 
+#if defined(_MSC_VER)
+#include <chrono>  // NOLINT(build/c++11)
+#else
 #include <sys/time.h>
+#endif
+
+namespace tflite {
+namespace profiling {
+namespace time {
+
+#if defined(_MSC_VER)
+
+uint64_t NowMicros() {
+  return std::chrono::duration_cast<std::chrono::microseconds>(
+             std::chrono::system_clock::now().time_since_epoch())
+      .count();
+}
+
+#else
 
-namespace tflite
-{
-namespace profiling
-{
-namespace time
-{
-uint64_t NowMicros()
-{
+uint64_t NowMicros() {
   struct timeval tv;
   gettimeofday(&tv, nullptr);
   return static_cast<uint64_t>(tv.tv_sec) * 1000000 + tv.tv_usec;
 }
-} // namespace time
-} // namespace profiling
-} // namespace tflite
+
+#endif  // defined(_MSC_VER)
+
+}  // namespace time
+}  // namespace profiling
+}  // namespace tflite
index d526904..dd54dc5 100644 (file)
@@ -1,5 +1,12 @@
 file(GLOB_RECURSE SOURCES "*.cc")
 
+nnfw_find_package(TensorFlowSource REQUIRED)
+set(TENSORFLOW_LITE_BASE "${TensorFlowSource_DIR}/tensorflow/contrib/lite")
+list(APPEND SOURCES "${TENSORFLOW_LITE_BASE}/tools/benchmark/benchmark_main.cc"
+                    "${TENSORFLOW_LITE_BASE}/tools/benchmark/benchmark_model.cc"
+                    "${TENSORFLOW_LITE_BASE}/tools/benchmark/benchmark_params.cc"
+                    "${TENSORFLOW_LITE_BASE}/tools/benchmark/command_line_flags.cc")
+
 add_executable(tflite_benchmark_model ${SOURCES})
 target_compile_definitions(tflite_benchmark_model PUBLIC "TFLITE_PROFILING_ENABLED")
 target_link_libraries(tflite_benchmark_model tensorflow-lite ${LIB_PTHREAD} dl nnfw_util nnfw_support_tflite)
index 9376930..8d99763 100644 (file)
@@ -9,7 +9,7 @@ of runs. Aggregrate latency statistics are reported after running the benchmark.
 
 The instructions below are for running the binary on Desktop and Android,
 for iOS please use the
-[iOS benchmark app] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios).
+[iOS benchmark app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios).
 
 ## Parameters
 
@@ -17,11 +17,6 @@ The binary takes the following required parameters:
 
 *   `graph`: `string` \
     The path to the TFLite model file.
-*   `input_layer`: `string` \
-    The name of the input layer, this is typically the first layer of the model.
-*   `input_layer_shape`: `string` \
-    The shape of the input layer. This is a comma separated string of the shape
-    of tensor of input layer.
 
 and the following optional parameters:
 
@@ -29,11 +24,13 @@ and the following optional parameters:
     The number of threads to use for running TFLite interpreter.
 *   `warmup_runs`: `int` (default=1) \
     The number of warmup runs to do before starting the benchmark.
+*   `num_runs`: `int` (default=50) \
+    The number of runs. Increase this to reduce variance.
 *   `run_delay`: `float` (default=-1.0) \
     The delay in seconds between subsequent benchmark runs. Non-positive values
     mean use no delay.
 *   `use_nnapi`: `bool` (default=false) \
-    Whether to use [Android NNAPI] (https://developer.android.com/ndk/guides/neuralnetworks/).
+    Whether to use [Android NNAPI](https://developer.android.com/ndk/guides/neuralnetworks/).
     This API is available on recent Android devices.
 
 ## To build/install/run
@@ -75,8 +72,6 @@ adb push mobilenet_quant_v1_224.tflite /data/local/tmp
 ```
 adb shell /data/local/tmp/benchmark_model \
   --graph=/data/local/tmp/mobilenet_quant_v1_224.tflite \
-  --input_layer="input" \
-  --input_layer_shape="1,224,224,3" \
   --num_threads=4
 ```
 
@@ -93,13 +88,10 @@ For example:
 ```
 bazel-bin/tensorflow/contrib/lite/tools/benchmark/benchmark_model \
   --graph=mobilenet_quant_v1_224.tflite \
-  --input_layer="Placeholder" \
-  --input_layer_shape="1,224,224,3" \
   --num_threads=4
 ```
 
-The MobileNet graph used as an example here may be downloaded from
-https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip
+The MobileNet graph used as an example here may be downloaded from [here](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip).
 
 
 ## Reducing variance between runs on Android.
@@ -115,10 +107,8 @@ E.g. for running the benchmark on big cores on Pixel 2 with a single thread one
 can use the following command:
 
 ```
-adb shell tasket f0 /data/local/tmp/benchmark_model \
+adb shell taskset f0 /data/local/tmp/benchmark_model \
   --graph=/data/local/tmp/mobilenet_quant_v1_224.tflite \
-  --input_layer="input" \
-  --input_layer_shape="1,224,224,3" \
   --num_threads=1
 ```
 
@@ -205,5 +195,3 @@ Memory (bytes): count=0
 
 Average inference timings in us: Warmup: 83235, Init: 38467, no stats: 79760.9
 ```
-
-
diff --git a/tools/tflite_benchmark_model/benchmark_main.cc b/tools/tflite_benchmark_model/benchmark_main.cc
deleted file mode 100644 (file)
index 7e4231c..0000000
+++ /dev/null
@@ -1,53 +0,0 @@
-/*
- * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "benchmark_tflite_model.h"
-#include "logging.h"
-
-namespace nnfw {
-namespace benchmark {
-
-int Main(int argc, char** argv) {
-#ifdef TFLITE_CUSTOM_OPS_HEADER
-  TFLITE_LOG(INFO) << "STARTING with custom ops!";
-#else
-  TFLITE_LOG(INFO) << "STARTING!";
-#endif
-  BenchmarkTfLiteModel benchmark;
-  BenchmarkLoggingListener listener;
-  benchmark.AddListener(&listener);
-  benchmark.Run(argc, argv);
-  return 0;
-}
-}  // namespace benchmark
-}  // namespace nnfw
-
-int main(int argc, char** argv) { return nnfw::benchmark::Main(argc, argv); }
diff --git a/tools/tflite_benchmark_model/benchmark_model.cc b/tools/tflite_benchmark_model/benchmark_model.cc
deleted file mode 100644 (file)
index 7869180..0000000
+++ /dev/null
@@ -1,175 +0,0 @@
-/*
- * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "benchmark_model.h"
-
-#include <time.h>
-
-#include <iostream>
-#include <sstream>
-
-#include "tensorflow/contrib/lite/profiling/time.h"
-#include "logging.h"
-
-namespace {
-void SleepForSeconds(double sleep_seconds) {
-  if (sleep_seconds <= 0.0) {
-    return;
-  }
-  // Convert the run_delay string into a timespec.
-  timespec req;
-  req.tv_sec = static_cast<time_t>(sleep_seconds);
-  req.tv_nsec = (sleep_seconds - req.tv_sec) * 1000000000;
-  // If requested, sleep between runs for an arbitrary amount of time.
-  // This can be helpful to determine the effect of mobile processor
-  // scaling and thermal throttling.
-#ifdef PLATFORM_WINDOWS
-  Sleep(sleep_seconds * 1000);
-#else
-  nanosleep(&req, nullptr);
-#endif
-}
-
-}  // namespace
-
-namespace nnfw {
-namespace benchmark {
-using tensorflow::Stat;
-
-BenchmarkParams BenchmarkModel::DefaultParams() {
-  BenchmarkParams params;
-  params.AddParam("num_runs", BenchmarkParam::Create<int32_t>(50));
-  params.AddParam("run_delay", BenchmarkParam::Create<float>(-1.0f));
-  params.AddParam("num_threads", BenchmarkParam::Create<int32_t>(1));
-  params.AddParam("benchmark_name", BenchmarkParam::Create<std::string>(""));
-  params.AddParam("output_prefix", BenchmarkParam::Create<std::string>(""));
-  params.AddParam("warmup_runs", BenchmarkParam::Create<int32_t>(1));
-  return params;
-}
-
-BenchmarkModel::BenchmarkModel() : params_(DefaultParams()) {}
-
-void BenchmarkLoggingListener::OnBenchmarkEnd(const BenchmarkResults &results) {
-  auto inference_us = results.inference_time_us();
-  auto init_us = results.startup_latency_us();
-  auto warmup_us = results.warmup_time_us();
-  TFLITE_LOG(INFO) << "Average inference timings in us: "
-                   << "Warmup: " << warmup_us.avg() << ", "
-                   << "Init: " << init_us << ", "
-                   << "no stats: " << inference_us.avg();
-}
-
-std::vector<Flag> BenchmarkModel::GetFlags() {
-  return {
-      CreateFlag<int32_t>("num_runs", &params_, "number of runs"),
-      CreateFlag<float>("run_delay", &params_, "delay between runs in seconds"),
-      CreateFlag<int32_t>("num_threads", &params_, "number of threads"),
-      CreateFlag<std::string>("benchmark_name", &params_, "benchmark name"),
-      CreateFlag<std::string>("output_prefix", &params_,
-                              "benchmark output prefix"),
-      CreateFlag<int32_t>("warmup_runs", &params_,
-                          "how many runs to initialize model"),
-  };
-}
-
-void BenchmarkModel::LogFlags() {
-  TFLITE_LOG(INFO) << "Num runs: [" << params_.Get<int32_t>("num_runs") << "]";
-  TFLITE_LOG(INFO) << "Inter-run delay (seconds): ["
-                   << params_.Get<float>("run_delay") << "]";
-  TFLITE_LOG(INFO) << "Num threads: [" << params_.Get<int32_t>("num_threads")
-                   << "]";
-  TFLITE_LOG(INFO) << "Benchmark name: ["
-                   << params_.Get<std::string>("benchmark_name") << "]";
-  TFLITE_LOG(INFO) << "Output prefix: ["
-                   << params_.Get<std::string>("output_prefix") << "]";
-  TFLITE_LOG(INFO) << "Warmup runs: [" << params_.Get<int32_t>("warmup_runs")
-                   << "]";
-}
-
-Stat<int64_t> BenchmarkModel::Run(int num_times, RunType run_type) {
-  Stat<int64_t> run_stats;
-  TFLITE_LOG(INFO) << "Running benchmark for " << num_times << " iterations ";
-  for (int run = 0; run < num_times; run++) {
-    listeners_.OnSingleRunStart(run_type);
-    int64_t start_us = tflite::profiling::time::NowMicros();
-    RunImpl();
-    int64_t end_us = tflite::profiling::time::NowMicros();
-    listeners_.OnSingleRunEnd();
-
-    run_stats.UpdateStat(end_us - start_us);
-    SleepForSeconds(params_.Get<float>("run_delay"));
-  }
-
-  std::stringstream stream;
-  run_stats.OutputToStream(&stream);
-  TFLITE_LOG(INFO) << stream.str() << std::endl;
-
-  return run_stats;
-}
-
-void BenchmarkModel::Run(int argc, char **argv) {
-  if (!ParseFlags(argc, argv)) {
-    return;
-  }
-
-  LogFlags();
-
-  listeners_.OnBenchmarkStart(params_);
-  int64_t initialization_start_us = tflite::profiling::time::NowMicros();
-  Init();
-  int64_t initialization_end_us = tflite::profiling::time::NowMicros();
-  int64_t startup_latency_us = initialization_end_us - initialization_start_us;
-  TFLITE_LOG(INFO) << "Initialized session in " << startup_latency_us / 1e3
-                   << "ms";
-
-  uint64_t input_bytes = ComputeInputBytes();
-  Stat<int64_t> warmup_time_us =
-      Run(params_.Get<int32_t>("warmup_runs"), WARMUP);
-  Stat<int64_t> inference_time_us =
-      Run(params_.Get<int32_t>("num_runs"), REGULAR);
-  listeners_.OnBenchmarkEnd(
-      {startup_latency_us, input_bytes, warmup_time_us, inference_time_us});
-}
-
-bool BenchmarkModel::ParseFlags(int argc, char **argv) {
-  auto flag_list = GetFlags();
-  const bool parse_result =
-      Flags::Parse(&argc, const_cast<const char **>(argv), flag_list);
-  if (!parse_result) {
-    std::string usage = Flags::Usage(argv[0], flag_list);
-    TFLITE_LOG(ERROR) << usage;
-    return false;
-  }
-  return ValidateFlags();
-}
-
-}  // namespace benchmark
-}  // namespace nnfw
diff --git a/tools/tflite_benchmark_model/benchmark_model.h b/tools/tflite_benchmark_model/benchmark_model.h
deleted file mode 100644 (file)
index 5645e29..0000000
+++ /dev/null
@@ -1,177 +0,0 @@
-/*
- * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef __TFLITE_BENCHMARK_MODEL_BENCHMARK_MODEL_H__
-#define __TFLITE_BENCHMARK_MODEL_BENCHMARK_MODEL_H__
-
-#include <cmath>
-#include <limits>
-#include <ostream>
-#include <string>
-#include <unordered_set>
-#include <vector>
-
-#include "benchmark_params.h"
-#include "command_line_flags.h"
-#include "tensorflow/core/util/stats_calculator.h"
-
-namespace nnfw {
-namespace benchmark {
-
-enum RunType {
-  WARMUP,
-  REGULAR,
-};
-
-class BenchmarkResults {
- public:
-  BenchmarkResults(int64_t startup_latency_us, uint64_t input_bytes,
-                   tensorflow::Stat<int64_t> warmup_time_us,
-                   tensorflow::Stat<int64_t> inference_time_us)
-      : startup_latency_us_(startup_latency_us),
-        input_bytes_(input_bytes),
-        warmup_time_us_(warmup_time_us),
-        inference_time_us_(inference_time_us) {}
-
-  tensorflow::Stat<int64_t> inference_time_us() const {
-    return inference_time_us_;
-  }
-  tensorflow::Stat<int64_t> warmup_time_us() const { return warmup_time_us_; }
-  int64_t startup_latency_us() const { return startup_latency_us_; }
-  uint64_t input_bytes() const { return input_bytes_; }
-  double throughput_MB_per_second() const {
-    double bytes_per_sec = (input_bytes_ * inference_time_us_.count() * 1e6) /
-                           inference_time_us_.sum();
-    return bytes_per_sec / (1024.0 * 1024.0);
-  }
-
- private:
-  int64_t startup_latency_us_;
-  uint64_t input_bytes_;
-  tensorflow::Stat<int64_t> warmup_time_us_;
-  tensorflow::Stat<int64_t> inference_time_us_;
-};
-
-class BenchmarkListener {
- public:
-  virtual void OnBenchmarkStart(const BenchmarkParams& params) {}
-  virtual void OnSingleRunStart(RunType runType) {}
-  virtual void OnSingleRunEnd() {}
-  virtual void OnBenchmarkEnd(const BenchmarkResults& results) {}
-  virtual ~BenchmarkListener() {}
-};
-
-// A listener that forwards its method calls to a collection of listeners.
-class BenchmarkListeners : public BenchmarkListener {
- public:
-  // Added a listener to the listener collection.
-  // |listener| is not owned by the instance of |BenchmarkListeners|.
-  // |listener| should not be null and should outlast the instance of
-  // |BenchmarkListeners|.
-  void AddListener(BenchmarkListener* listener) {
-    listeners_.push_back(listener);
-  }
-
-  void OnBenchmarkStart(const BenchmarkParams& params) override {
-    for (auto listener : listeners_) {
-      listener->OnBenchmarkStart(params);
-    }
-  }
-
-  void OnSingleRunStart(RunType runType) override {
-    for (auto listener : listeners_) {
-      listener->OnSingleRunStart(runType);
-    }
-  }
-
-  void OnSingleRunEnd() override {
-    for (auto listener : listeners_) {
-      listener->OnSingleRunEnd();
-    }
-  }
-
-  void OnBenchmarkEnd(const BenchmarkResults& results) override {
-    for (auto listener : listeners_) {
-      listener->OnBenchmarkEnd(results);
-    }
-  }
-
-  ~BenchmarkListeners() {}
-
- private:
-  // Use vector so listeners are invoked in the order they are added.
-  std::vector<BenchmarkListener*> listeners_;
-};
-
-// Benchmark listener that just logs the results of benchmark run.
-class BenchmarkLoggingListener : public BenchmarkListener {
-  void OnBenchmarkEnd(const BenchmarkResults& results) override;
-};
-
-template <typename T>
-Flag CreateFlag(const char* name, BenchmarkParams* params,
-                const std::string& usage) {
-  return Flag(name, [params, name](const T& val) { params->Set<T>(name, val); },
-              params->Get<T>(name), usage);
-}
-
-// Benchmarks a model.
-//
-// Subclasses need to implement initialization and running of the model.
-// The results can be collected by adding BenchmarkListener(s).
-class BenchmarkModel {
- public:
-  static BenchmarkParams DefaultParams();
-  BenchmarkModel();
-  BenchmarkModel(BenchmarkParams params) : params_(std::move(params)) {}
-  virtual ~BenchmarkModel() {}
-  bool ParseFlags(int argc, char** argv);
-  virtual void Init() = 0;
-  void Run(int argc, char** argv);
-  void AddListener(BenchmarkListener* listener) {
-    listeners_.AddListener(listener);
-  }
-
- protected:
-  virtual void LogFlags();
-  virtual bool ValidateFlags() { return true; }
-  virtual std::vector<Flag> GetFlags();
-  virtual uint64_t ComputeInputBytes() = 0;
-  virtual tensorflow::Stat<int64_t> Run(int num_times, RunType run_type);
-  virtual void RunImpl() = 0;
-  BenchmarkParams params_;
-  BenchmarkListeners listeners_;
-};
-
-}  // namespace benchmark
-}  // namespace nnfw
-
-#endif  //__TFLITE_BENCHMARK_MODEL_BENCHMARK_MODEL_H__
diff --git a/tools/tflite_benchmark_model/benchmark_params.cc b/tools/tflite_benchmark_model/benchmark_params.cc
deleted file mode 100644 (file)
index 7b667a4..0000000
+++ /dev/null
@@ -1,73 +0,0 @@
-/*
- * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "benchmark_params.h"
-
-#include <string>
-#include <unordered_map>
-#include <vector>
-
-#include "logging.h"
-
-namespace nnfw {
-namespace benchmark {
-
-void BenchmarkParam::AssertHasSameType(BenchmarkParam::ParamType a,
-                                       BenchmarkParam::ParamType b) {
-  TFLITE_BENCHMARK_CHECK(a == b) << "Type mismatch while accessing parameter.";
-}
-
-template <>
-BenchmarkParam::ParamType BenchmarkParam::GetValueType<int32_t>() {
-  return BenchmarkParam::ParamType::TYPE_INT32;
-}
-
-template <>
-BenchmarkParam::ParamType BenchmarkParam::GetValueType<bool>() {
-  return BenchmarkParam::ParamType::TYPE_BOOL;
-}
-
-template <>
-BenchmarkParam::ParamType BenchmarkParam::GetValueType<float>() {
-  return BenchmarkParam::ParamType::TYPE_FLOAT;
-}
-
-template <>
-BenchmarkParam::ParamType BenchmarkParam::GetValueType<std::string>() {
-  return BenchmarkParam::ParamType::TYPE_STRING;
-}
-
-void BenchmarkParams::AssertParamExists(const std::string& name) const {
-  TFLITE_BENCHMARK_CHECK(HasParam(name)) << name << " was not found.";
-}
-
-}  // namespace benchmark
-}  // namespace nnfw
diff --git a/tools/tflite_benchmark_model/benchmark_params.h b/tools/tflite_benchmark_model/benchmark_params.h
deleted file mode 100644 (file)
index 1ac3f4a..0000000
+++ /dev/null
@@ -1,118 +0,0 @@
-/*
- * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef __TFLITE_BENCHMARK_MODEL_BENCHMARK_PARAMS_H__
-#define __TFLITE_BENCHMARK_MODEL_BENCHMARK_PARAMS_H__
-#include <memory>
-#include <string>
-#include <unordered_map>
-#include <vector>
-
-#include "logging.h"
-
-namespace nnfw {
-namespace benchmark {
-
-template <typename T>
-class TypedBenchmarkParam;
-
-class BenchmarkParam {
- protected:
-  enum class ParamType { TYPE_INT32, TYPE_FLOAT, TYPE_BOOL, TYPE_STRING };
-
- public:
-  template <typename T>
-  static std::unique_ptr<BenchmarkParam> Create(const T& default_value) {
-    return std::unique_ptr<BenchmarkParam>(
-        new TypedBenchmarkParam<T>(default_value));
-  }
-
-  template <typename T>
-  TypedBenchmarkParam<T>* AsTyped() {
-    AssertHasSameType(GetValueType<T>(), type_);
-    return static_cast<TypedBenchmarkParam<T>*>(this);
-  }
-  virtual ~BenchmarkParam() {}
-  BenchmarkParam(ParamType type) : type_(type) {}
-
- private:
-  static void AssertHasSameType(ParamType a, ParamType b);
- protected:
-  template <typename T>
-  static ParamType GetValueType();
-
-  const ParamType type_;
-};
-
-template <typename T>
-class TypedBenchmarkParam : public BenchmarkParam {
- public:
-  TypedBenchmarkParam(const T& value)
-      : BenchmarkParam(GetValueType<T>()), value_(value) {}
-  void Set(const T& value) { value_ = value; }
-
-  T Get() { return value_; }
-
- private:
-  T value_;
-};
-
-class BenchmarkParams {
- public:
-  void AddParam(const std::string& name,
-                std::unique_ptr<BenchmarkParam> value) {
-    params_[name] = std::move(value);
-  }
-
-  bool HasParam(const std::string& name) const {
-    return params_.find(name) != params_.end();
-  }
-
-  template <typename T>
-  void Set(const std::string& name, const T& value) {
-    AssertParamExists(name);
-    params_.at(name)->AsTyped<T>()->Set(value);
-  }
-
-  template <typename T>
-  T Get(const std::string& name) const {
-    AssertParamExists(name);
-    return params_.at(name)->AsTyped<T>()->Get();
-  }
-
- private:
-  void AssertParamExists(const std::string& name) const;
-  std::unordered_map<std::string, std::unique_ptr<BenchmarkParam>> params_;
-};
-
-}  // namespace benchmark
-}  // namespace nnfw
-#endif  // __TFLITE_BENCHMARK_MODEL_BENCHMARK_PARAMS_H__
index d277795..611bd6a 100644 (file)
@@ -29,7 +29,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "benchmark_tflite_model.h"
+#include "tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h"
 
 #include <cstdarg>
 #include <cstdlib>
@@ -39,11 +39,16 @@ limitations under the License.
 #include <unordered_set>
 #include <vector>
 
+#ifdef TFLITE_FLEX
+#include "tensorflow/contrib/lite/delegates/flex/delegate.h"
+#endif  // TFLITE_FLEX
 #include "support/tflite/kernels/register.h"
 #include "tensorflow/contrib/lite/model.h"
 #include "tensorflow/contrib/lite/op_resolver.h"
 #include "tensorflow/contrib/lite/string_util.h"
-#include "logging.h"
+#include "tensorflow/contrib/lite/tools/benchmark/logging.h"
+
+// For profiling nnapi_delegate
 #include "util/profiling/profiling.h"
 #include "support/tflite/nnapi_delegate.h"
 
@@ -51,7 +56,7 @@ limitations under the License.
 void RegisterSelectedOps(::tflite::MutableOpResolver* resolver);
 #endif
 
-namespace nnfw {
+namespace tflite {
 namespace benchmark {
 
 void ProfilingListener::SetInterpreter(tflite::Interpreter* interpreter) {
@@ -130,7 +135,7 @@ void FillRandomValue(T* ptr, const std::vector<int>& sizes,
 
 void FillRandomString(tflite::DynamicBuffer* buffer,
                       const std::vector<int>& sizes,
-                      const std::function<std::string()>& random_func) {
+                      const std::function<string()>& random_func) {
   int num_elements = 1;
   for (int dim : sizes) {
     num_elements *= dim;
@@ -142,7 +147,7 @@ void FillRandomString(tflite::DynamicBuffer* buffer,
 }
 
 bool PopulateInputLayerInfo(
-    const std::string& names_string, const std::string& shapes_string,
+    const string& names_string, const string& shapes_string,
     std::vector<BenchmarkTfLiteModel::InputLayerInfo>* info) {
   std::vector<std::string> names = Split(names_string, ',');
   std::vector<std::string> shapes = Split(shapes_string, ':');
@@ -216,8 +221,8 @@ std::vector<Flag> BenchmarkTfLiteModel::GetFlags() {
   return flags;
 }
 
-void BenchmarkTfLiteModel::LogFlags() {
-  BenchmarkModel::LogFlags();
+void BenchmarkTfLiteModel::LogParams() {
+  BenchmarkModel::LogParams();
   TFLITE_LOG(INFO) << "Graph: [" << params_.Get<std::string>("graph") << "]";
   TFLITE_LOG(INFO) << "Input layers: ["
                    << params_.Get<std::string>("input_layer") << "]";
@@ -226,7 +231,7 @@ void BenchmarkTfLiteModel::LogFlags() {
   TFLITE_LOG(INFO) << "Use nnapi : [" << params_.Get<bool>("use_nnapi") << "]";
 }
 
-bool BenchmarkTfLiteModel::ValidateFlags() {
+bool BenchmarkTfLiteModel::ValidateParams() {
   if (params_.Get<std::string>("graph").empty()) {
     TFLITE_LOG(ERROR)
         << "Please specify the name of your TF Lite input file with --graph";
@@ -247,6 +252,46 @@ uint64_t BenchmarkTfLiteModel::ComputeInputBytes() {
   return total_input_bytes;
 }
 
+void BenchmarkTfLiteModel::PrepareInputsAndOutputs() {
+  auto interpreter_inputs = interpreter->inputs();
+  // Set the values of the input tensors.
+  for (int j = 0; j < inputs.size(); ++j) {
+    const InputLayerInfo& input = inputs[j];
+    int i = interpreter_inputs[j];
+    TfLiteTensor* t = interpreter->tensor(i);
+    std::vector<int> sizes = input.shape;
+
+    // TODO(ahentz): below we ignore the O-th dimension (number of batches).
+    if (t->type == kTfLiteFloat32) {
+      FillRandomValue<float>(
+          interpreter->typed_tensor<float>(i),
+          std::vector<int>(sizes.begin() + 1, sizes.end()),
+          []() { return static_cast<float>(rand()) / RAND_MAX - 0.5f; });
+    } else if (t->type == kTfLiteInt32) {
+      // TODO(yunluli): This is currently only used for handling embedding input
+      // for speech models. Generalize if necessary.
+      FillRandomValue<int32_t>(
+          interpreter->typed_tensor<int32_t>(i),
+          std::vector<int32_t>(sizes.begin() + 1, sizes.end()),
+          []() { return static_cast<int32_t>(rand()) % 100; });
+    } else if (t->type == kTfLiteUInt8) {
+      FillRandomValue<uint8_t>(
+          interpreter->typed_tensor<uint8_t>(i),
+          std::vector<int>(sizes.begin() + 1, sizes.end()),
+          []() { return static_cast<uint8_t>(rand()) % 255; });
+    } else if (t->type == kTfLiteString) {
+      tflite::DynamicBuffer buffer;
+      FillRandomString(&buffer, sizes, []() {
+        return "we're have some friends over saturday to hang out in the yard";
+      });
+      buffer.WriteToTensor(interpreter->tensor(i));
+    } else {
+      TFLITE_LOG(FATAL) << "Don't know how to populate tensor " << t->name
+                        << " of type " << t->type;
+    }
+  }
+}
+
 void BenchmarkTfLiteModel::Init() {
   std::string graph = params_.Get<std::string>("graph");
   model = tflite::FlatBufferModel::BuildFromFile(graph.c_str());
@@ -269,7 +314,7 @@ void BenchmarkTfLiteModel::Init() {
     TFLITE_LOG(FATAL) << "Failed to construct interpreter";
   }
   profiling_listener_.SetInterpreter(interpreter.get());
-  profiling::Context::get().setProfiler(interpreter->GetProfiler());
+  ::profiling::Context::get().setProfiler(interpreter->GetProfiler());
 
   const int32_t num_threads = params_.Get<int32_t>("num_threads");
 
@@ -280,6 +325,16 @@ void BenchmarkTfLiteModel::Init() {
   bool use_nnapi = params_.Get<bool>("use_nnapi");
 
   interpreter->UseNNAPI(use_nnapi);
+
+#ifdef TFLITE_FLEX
+  TFLITE_LOG(INFO) << "Instantiating Flex Delegate";
+  delegate_ = FlexDelegate::Create();
+  if (delegate_) {
+    interpreter->ModifyGraphWithDelegate(delegate_.get(),
+                                         /*allow_dynamic_tensors=*/true);
+  }
+#endif  // TFLITE_FLEX
+
   auto interpreter_inputs = interpreter->inputs();
 
   if (!inputs.empty()) {
@@ -311,36 +366,6 @@ void BenchmarkTfLiteModel::Init() {
   if (interpreter->AllocateTensors() != kTfLiteOk) {
     TFLITE_LOG(FATAL) << "Failed to allocate tensors!";
   }
-
-  // Set the values of the input tensors.
-  for (int j = 0; j < inputs.size(); ++j) {
-    const InputLayerInfo& input = inputs[j];
-    int i = interpreter_inputs[j];
-    TfLiteTensor* t = interpreter->tensor(i);
-    std::vector<int> sizes = input.shape;
-
-    // TODO(ahentz): below we ignore the O-th dimension (number of batches).
-    if (t->type == kTfLiteFloat32) {
-      FillRandomValue<float>(
-          interpreter->typed_tensor<float>(i),
-          std::vector<int>(sizes.begin() + 1, sizes.end()),
-          []() { return static_cast<float>(rand()) / RAND_MAX - 0.5f; });
-    } else if (t->type == kTfLiteUInt8) {
-      FillRandomValue<uint8_t>(
-          interpreter->typed_tensor<uint8_t>(i),
-          std::vector<int>(sizes.begin() + 1, sizes.end()),
-          []() { return static_cast<uint8_t>(rand()) % 255; });
-    } else if (t->type == kTfLiteString) {
-      tflite::DynamicBuffer buffer;
-      FillRandomString(&buffer, sizes, []() {
-        return "we're have some friends over saturday to hang out in the yard";
-      });
-      buffer.WriteToTensor(interpreter->tensor(i));
-    } else {
-      TFLITE_LOG(FATAL) << "Don't know how to populate tensor " << t->name
-                        << " of type " << t->type;
-    }
-  }
 }
 
 void BenchmarkTfLiteModel::RunImpl() {
@@ -357,4 +382,4 @@ void BenchmarkTfLiteModel::RunImpl() {
 }
 
 }  // namespace benchmark
-}  // namespace nnfw
+}  // namespace tflite
diff --git a/tools/tflite_benchmark_model/benchmark_tflite_model.h b/tools/tflite_benchmark_model/benchmark_tflite_model.h
deleted file mode 100644 (file)
index 7892de1..0000000
+++ /dev/null
@@ -1,95 +0,0 @@
-/*
- * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef __TFLITE_BENCHMARK_MODEL_BENCHMARK_TFLITE_MODEL_H__
-#define __TFLITE_BENCHMARK_MODEL_BENCHMARK_TFLITE_MODEL_H__
-
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "tensorflow/contrib/lite/model.h"
-#include "tensorflow/contrib/lite/profiling/profile_summarizer.h"
-#include "benchmark_model.h"
-
-namespace nnfw {
-namespace benchmark {
-
-// Dumps profiling events if profiling is enabled
-class ProfilingListener : public BenchmarkListener {
- public:
-  explicit ProfilingListener() : interpreter_(nullptr), has_profiles_(false) {}
-
-  void SetInterpreter(tflite::Interpreter* interpreter);
-
-  void OnSingleRunStart(RunType run_type) override;
-
-  void OnSingleRunEnd() override;
-
-  void OnBenchmarkEnd(const BenchmarkResults& results) override;
-
- private:
-  tflite::Interpreter* interpreter_;
-  tflite::profiling::Profiler profiler_;
-  tflite::profiling::ProfileSummarizer summarizer_;
-  bool has_profiles_;
-};
-
-// Benchmarks a TFLite model by running tflite interpreter.
-class BenchmarkTfLiteModel : public BenchmarkModel {
- public:
-  BenchmarkTfLiteModel();
-  BenchmarkTfLiteModel(BenchmarkParams params);
-
-  std::vector<Flag> GetFlags() override;
-  void LogFlags() override;
-  bool ValidateFlags() override;
-  uint64_t ComputeInputBytes() override;
-  void Init() override;
-  void RunImpl() override;
-  virtual ~BenchmarkTfLiteModel() {}
-
-  struct InputLayerInfo {
-    std::string name;
-    std::vector<int> shape;
-  };
-
- private:
-  std::unique_ptr<tflite::FlatBufferModel> model;
-  std::unique_ptr<tflite::Interpreter> interpreter;
-  std::vector<InputLayerInfo> inputs;
-  ProfilingListener profiling_listener_;
-};
-
-}  // namespace benchmark
-}  // namespace nnfw
-
-#endif  //__TFLITE_BENCHMARK_MODEL_BENCHMARK_TFLITE_MODEL_H__
diff --git a/tools/tflite_benchmark_model/command_line_flags.cc b/tools/tflite_benchmark_model/command_line_flags.cc
deleted file mode 100644 (file)
index eacca9f..0000000
+++ /dev/null
@@ -1,214 +0,0 @@
-/*
- * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-    http://www.apache.org/licenses/LICENSE-2.0
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "command_line_flags.h"
-
-#include <cstring>
-#include <sstream>
-#include <string>
-#include <utility>
-#include <vector>
-
-namespace nnfw {
-namespace {
-
-template <typename T>
-std::string ToString(T val) {
-  std::ostringstream stream;
-  stream << val;
-  return stream.str();
-}
-
-bool ParseFlag(const std::string& arg, const std::string& flag,
-               const std::function<bool(const std::string&)>& parse_func,
-               bool* value_parsing_ok) {
-  *value_parsing_ok = true;
-  std::string flag_prefix = "--" + flag + "=";
-  if (arg.find(flag_prefix) != 0) {
-    return false;
-  }
-  bool has_value = arg.size() >= flag_prefix.size();
-  *value_parsing_ok = has_value;
-  if (has_value) {
-    *value_parsing_ok = parse_func(arg.substr(flag_prefix.size()));
-  }
-  return true;
-}
-
-template <typename T>
-bool ParseFlag(const std::string& flag_value,
-               const std::function<void(const T&)>& hook) {
-  std::istringstream stream(flag_value);
-  T read_value;
-  stream >> read_value;
-  if (!stream.eof() && !stream.good()) {
-    return false;
-  }
-  hook(read_value);
-  return true;
-}
-
-bool ParseBoolFlag(const std::string& flag_value,
-                   const std::function<void(const bool&)>& hook) {
-  if (flag_value != "true" && flag_value != "false") {
-    return false;
-  }
-
-  hook(flag_value == "true");
-  return true;
-}
-}  // namespace
-
-Flag::Flag(const char* name, const std::function<void(const int32_t&)>& hook,
-           int32_t default_value, const std::string& usage_text)
-    : name_(name),
-      type_(TYPE_INT32),
-      value_hook_([hook](const std::string& flag_value) {
-        return ParseFlag<int32_t>(flag_value, hook);
-      }),
-      default_for_display_(ToString(default_value)),
-      usage_text_(usage_text) {}
-
-Flag::Flag(const char* name, const std::function<void(const int64_t&)>& hook,
-           int64_t default_value, const std::string& usage_text)
-    : name_(name),
-      type_(TYPE_INT64),
-      value_hook_([hook](const std::string& flag_value) {
-        return ParseFlag<int64_t>(flag_value, hook);
-      }),
-      default_for_display_(ToString(default_value)),
-      usage_text_(usage_text) {}
-
-Flag::Flag(const char* name, const std::function<void(const float&)>& hook,
-           float default_value, const std::string& usage_text)
-    : name_(name),
-      type_(TYPE_FLOAT),
-      value_hook_([hook](const std::string& flag_value) {
-        return ParseFlag<float>(flag_value, hook);
-      }),
-      default_for_display_(ToString(default_value)),
-      usage_text_(usage_text) {}
-
-Flag::Flag(const char* name, const std::function<void(const bool&)>& hook,
-           bool default_value, const std::string& usage_text)
-    : name_(name),
-      type_(TYPE_BOOL),
-      value_hook_([hook](const std::string& flag_value) {
-        return ParseBoolFlag(flag_value, hook);
-      }),
-      default_for_display_(default_value ? "true" : "false"),
-      usage_text_(usage_text) {}
-
-Flag::Flag(const char* name,
-           const std::function<void(const std::string&)>& hook,
-           const std::string& default_value, const std::string& usage_text)
-    : name_(name),
-      type_(TYPE_STRING),
-      value_hook_([hook](const std::string& flag_value) {
-        hook(flag_value);
-        return true;
-      }),
-      default_for_display_(default_value),
-      usage_text_(usage_text) {}
-
-bool Flag::Parse(const std::string& arg, bool* value_parsing_ok) const {
-  return ParseFlag(arg, name_, value_hook_, value_parsing_ok);
-}
-
-std::string Flag::GetTypeName() const {
-  switch (type_) {
-    case TYPE_INT32:
-      return "int32";
-    case TYPE_INT64:
-      return "int64";
-    case TYPE_FLOAT:
-      return "float";
-    case TYPE_BOOL:
-      return "bool";
-    case TYPE_STRING:
-      return "string";
-  }
-
-  return "unknown";
-}
-
-/*static*/ bool Flags::Parse(int* argc, const char** argv,
-                             const std::vector<Flag>& flag_list) {
-  bool result = true;
-  std::vector<const char*> unknown_flags;
-  for (int i = 1; i < *argc; ++i) {
-    if (std::string(argv[i]) == "--") {
-      while (i < *argc) {
-        unknown_flags.push_back(argv[i]);
-        ++i;
-      }
-      break;
-    }
-
-    bool was_found = false;
-    for (const Flag& flag : flag_list) {
-      bool value_parsing_ok;
-      was_found = flag.Parse(argv[i], &value_parsing_ok);
-      if (!value_parsing_ok) {
-        result = false;
-      }
-      if (was_found) {
-        break;
-      }
-    }
-    if (!was_found) {
-      unknown_flags.push_back(argv[i]);
-    }
-  }
-  int dst = 1;  // Skip argv[0]
-  for (auto f : unknown_flags) {
-    argv[dst++] = f;
-  }
-  argv[dst++] = nullptr;
-  *argc = unknown_flags.size() + 1;
-  return result && (*argc < 2 || std::strcmp(argv[1], "--help") != 0);
-}
-
-/*static*/ std::string Flags::Usage(const std::string& cmdline,
-                                    const std::vector<Flag>& flag_list) {
-  std::ostringstream usage_text;
-  usage_text << "usage: " << cmdline << "\n";
-  if (!flag_list.empty()) {
-    usage_text << "Flags:\n";
-  }
-
-  for (const Flag& flag : flag_list) {
-    auto type_name = flag.GetTypeName();
-    usage_text << "\t";
-    usage_text << "--" << flag.name_ << "=" << flag.default_for_display_;
-    usage_text << "\t" << type_name << "\t" << flag.usage_text_ << "\n";
-  }
-  return usage_text.str();
-}
-
-}  // namespace nnfw
diff --git a/tools/tflite_benchmark_model/command_line_flags.h b/tools/tflite_benchmark_model/command_line_flags.h
deleted file mode 100644 (file)
index 766417d..0000000
+++ /dev/null
@@ -1,141 +0,0 @@
-/*
- * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef __TFLITE_BENCHMARK_MODEL_COMMAND_LINE_FLAGS_H__
-#define __TFLITE_BENCHMARK_MODEL_COMMAND_LINE_FLAGS_H__
-
-#include <functional>
-#include <string>
-#include <vector>
-
-namespace nnfw {
-// A simple command-line argument parsing module.
-// Dependency free simplified port of core/util/command_line_flags.
-// This class is written for benchmarks and uses inefficient string
-// concatenation. This was written to avoid dependency on tensorflow/core/util
-// which transitively brings in a lot of other dependencies that are not
-// necessary for tflite benchmarking code.
-// The recommended way of using it is with local variables and an initializer
-// list of Flag objects, for example:
-//
-// int some_int = 10;
-// bool some_switch = false;
-// std::string some_name = "something";
-//
-// std::vector<tensorFlow::Flag> flag_list = {
-//   Flag::CreateFlag("some_int", &some_int, "an integer that affects X"),
-//   Flag::CreateFlag("some_switch", &some_switch, "a bool that affects Y"),
-//   Flag::CreateFlag("some_name", &some_name, "a string that affects Z")
-// };
-// // Get usage message before ParseFlags() to capture default values.
-// std::string usage = Flag::Usage(argv[0], flag_list);
-// bool parsed_values_ok = Flags::Parse(&argc, argv, flag_list);
-//
-// tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
-// if (argc != 1 || !parsed_values_ok) {
-//    ...output usage and error message...
-// }
-//
-// The argc and argv values are adjusted by the Parse function so all that
-// remains is the program name (at argv[0]) and any unknown arguments fill the
-// rest of the array. This means you can check for flags that weren't understood
-// by seeing if argv is greater than 1.
-// The result indicates if there were any errors parsing the values that were
-// passed to the command-line switches. For example, --some_int=foo would return
-// false because the argument is expected to be an integer.
-//
-// NOTE: Unlike gflags-style libraries, this library is intended to be
-// used in the `main()` function of your binary. It does not handle
-// flag definitions that are scattered around the source code.
-
-// A description of a single command line flag, holding its name, type, usage
-// text, and a pointer to the corresponding variable.
-class Flag {
- public:
-  template <typename T>
-  static Flag CreateFlag(const char* name, T* val, const char* usage) {
-    return Flag(name, [val](const T& v) { *val = v; }, *val, usage);
-  }
-
-  Flag(const char* name, const std::function<void(const int32_t&)>& hook,
-       int32_t default_value, const std::string& usage_text);
-  Flag(const char* name, const std::function<void(const int64_t&)>& hook,
-       int64_t default_value, const std::string& usage_text);
-  Flag(const char* name, const std::function<void(const float&)>& hook,
-       float default_value, const std::string& usage_text);
-  Flag(const char* name, const std::function<void(const bool&)>& hook,
-       bool default_value, const std::string& usage_text);
-  Flag(const char* name, const std::function<void(const std::string&)>& hook,
-       const std::string& default_value, const std::string& usage_text);
-
- private:
-  friend class Flags;
-
-  bool Parse(const std::string& arg, bool* value_parsing_ok) const;
-
-  std::string name_;
-  enum {
-    TYPE_INT32,
-    TYPE_INT64,
-    TYPE_BOOL,
-    TYPE_STRING,
-    TYPE_FLOAT,
-  } type_;
-
-  std::string GetTypeName() const;
-
-  std::function<bool(const std::string&)> value_hook_;
-  std::string default_for_display_;
-
-  std::string usage_text_;
-};
-
-class Flags {
- public:
-  // Parse the command line represented by argv[0, ..., (*argc)-1] to find flag
-  // instances matching flags in flaglist[].  Update the variables associated
-  // with matching flags, and remove the matching arguments from (*argc, argv).
-  // Return true iff all recognized flag values were parsed correctly, and the
-  // first remaining argument is not "--help".
-  static bool Parse(int* argc, const char** argv,
-                    const std::vector<Flag>& flag_list);
-
-  // Return a usage message with command line cmdline, and the
-  // usage_text strings in flag_list[].
-  static std::string Usage(const std::string& cmdline,
-                           const std::vector<Flag>& flag_list);
-};
-
-}  // namespace nnfw
-
-#endif  // __TFLITE_BENCHMARK_MODEL_COMMAND_LINE_FLAGS_H__
-
-
diff --git a/tools/tflite_benchmark_model/logging.h b/tools/tflite_benchmark_model/logging.h
deleted file mode 100644 (file)
index e694a09..0000000
+++ /dev/null
@@ -1,92 +0,0 @@
-/*
- * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef __TFLITE_BENCHMARK_MODEL_LOGGING_H_
-#define __TFLITE_BENCHMARK_MODEL_LOGGING_H_
-
-// LOG and CHECK macros for benchmarks.
-
-#include <cstdlib>
-#include <iostream>
-#include <sstream>
-
-namespace nnfw {
-namespace logging {
-// A wrapper that logs to stderr.
-//
-// Used for TFLITE_LOG and TFLITE_BENCHMARK_CHECK macros.
-class LoggingWrapper {
- public:
-  enum class LogSeverity : int {
-    INFO = 0,
-    WARN = 1,
-    ERROR = 2,
-    FATAL = 3,
-  };
-  LoggingWrapper(LogSeverity severity)
-      : severity_(severity), should_log_(true) {}
-  LoggingWrapper(LogSeverity severity, bool log)
-      : severity_(severity), should_log_(log) {}
-  std::stringstream& Stream() { return stream_; }
-  ~LoggingWrapper() {
-    if (should_log_) {
-      std::cerr << stream_.str() << std::endl;
-      if (severity_ == LogSeverity::FATAL) {
-        std::flush(std::cerr);
-        std::abort();
-      }
-    }
-  }
-
- private:
-  std::stringstream stream_;
-  LogSeverity severity_;
-  bool should_log_;
-};
-
-}  // namespace logging
-
-}  // namespace nnfw
-
-#define TFLITE_LOG(severity)                                  \
-  nnfw::logging::LoggingWrapper(                            \
-      nnfw::logging::LoggingWrapper::LogSeverity::severity) \
-      .Stream()
-
-#define TFLITE_BENCHMARK_CHECK(condition)                  \
-  nnfw::logging::LoggingWrapper(                         \
-      nnfw::logging::LoggingWrapper::LogSeverity::FATAL, \
-      (condition) ? false : true)                          \
-      .Stream()
-
-#define TFLITE_BENCHMARK_CHECK_EQ(a, b) TFLITE_BENCHMARK_CHECK(a == b)
-
-#endif  // __TFLITE_BENCHMARK_MODEL_BENCHMARK_LOGGING_H_
index 4d12b50..ce19b0c 100644 (file)
@@ -39,8 +39,6 @@ namespace tflite {
 namespace profiling {
 namespace {
 
-using Detail = tensorflow::StatsCalculator::Detail;
-
 struct OperatorDetails {
   std::string name;
   std::vector<std::string> inputs;
@@ -94,18 +92,30 @@ OperatorDetails GetOperatorDetails(const tflite::Interpreter& interpreter,
   } else {
     op_name = tflite::EnumNamesBuiltinOperator()[code];
   }
+  const char* profiling_string =
+      interpreter.OpProfilingString(node_reg->second, &node_reg->first);
   OperatorDetails details;
   details.name = op_name;
+  if (profiling_string) {
+    details.name += ":" + std::string(profiling_string);
+  }
   details.inputs = GetTensorNames(interpreter, inputs);
   details.outputs = GetTensorNames(interpreter, outputs);
   return details;
 }
 
+tensorflow::StatSummarizerOptions GetProfileSummarizerOptions() {
+  auto options = tensorflow::StatSummarizerOptions();
+  options.show_summary = true;
+  options.show_memory = false;
+  return options;
+}
+
 }  // namespace
 
 ProfileSummarizer::ProfileSummarizer()
-    : stats_calculator_(new ::tensorflow::StatsCalculator(
-          tensorflow::StatSummarizerOptions())) {}
+    : stats_calculator_(
+          new ::tensorflow::StatsCalculator(GetProfileSummarizerOptions())) {}
 
 void ProfileSummarizer::ProcessProfiles(
     const std::vector<const ProfileEvent*>& profile_stats,
@@ -129,35 +139,22 @@ void ProfileSummarizer::ProcessProfiles(
   int64_t base_start_us = events[0]->begin_timestamp_us;
   int node_num = 0;
   int64_t curr_total_us = 0;
-  std::map<std::string, Detail> details;
   int prev_op_idx = -1;
-  int seq_no = 1;
+  int child_op_no = 1;
   for (auto event : events) {
     auto op_details = GetOperatorDetails(interpreter, event->event_metadata);
-    bool is_continued = (prev_op_idx == event->event_metadata);
-    seq_no = is_continued ? seq_no + 1 : 1;
-    auto node_name = ToString(op_details.outputs) + "#" + std::to_string(seq_no);
-    auto result = details.emplace(node_name, Detail());
-    Detail* detail = &(result.first->second);
-    detail->start_us.UpdateStat(event->begin_timestamp_us - base_start_us);
+    bool from_same_op = (prev_op_idx == event->event_metadata);
+    child_op_no = from_same_op ? child_op_no + 1 : 1;
+    auto node_name = ToString(op_details.outputs) + "#" + std::to_string(child_op_no);
+    int64_t start_us = event->begin_timestamp_us - base_start_us;
     int64_t node_exec_time =
         event->end_timestamp_us - event->begin_timestamp_us;
-    detail->rel_end_us.UpdateStat(node_exec_time);
+    stats_calculator_->AddNodeStats(node_name, op_details.name, node_num,
+                                    start_us, node_exec_time, 0 /*memory */);
     curr_total_us += node_exec_time;
     ++node_num;
-
-    if (result.second) {
-      detail->name = node_name;
-      detail->type = op_details.name;
-      detail->run_order = node_num;
-      detail->times_called = 0;
-    }
-    if (!is_continued) {
-      ++detail->times_called;
-    }
     prev_op_idx = event->event_metadata;
   }
-  stats_calculator_->UpdateDetails(details);
   stats_calculator_->UpdateRunTotalUs(curr_total_us);
 }
 }  // namespace profiling
diff --git a/tools/tflite_benchmark_model/profile_summarizer.h b/tools/tflite_benchmark_model/profile_summarizer.h
deleted file mode 100644 (file)
index a529ff8..0000000
+++ /dev/null
@@ -1,55 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CONTRIB_LITE_PROFILING_PROFILE_SUMMARIZER_H_
-#define TENSORFLOW_CONTRIB_LITE_PROFILING_PROFILE_SUMMARIZER_H_
-
-#include <vector>
-
-#include "tensorflow/contrib/lite/interpreter.h"
-#include "tensorflow/contrib/lite/profiling/profiler.h"
-#include "tensorflow/core/util/stats_calculator.h"
-
-namespace tflite {
-namespace profiling {
-
-// Creates a summary of operator invocations in the interpreter.
-class ProfileSummarizer {
- public:
-  ProfileSummarizer();
-  virtual ~ProfileSummarizer() {}
-
-  // Process profile events to update statistics for operator invocations.
-  void ProcessProfiles(const std::vector<const ProfileEvent*>& profile_stats,
-                       const tflite::Interpreter& interpreter);
-
-  // Returns a string detailing the accumulated runtime stats in a tab-separated
-  // format which can be pasted into a spreadsheet for further analysis.
-  std::string GetOutputString() const {
-    return stats_calculator_->GetOutputString();
-  }
-
-  std::string GetShortSummary() const {
-    return stats_calculator_->GetShortSummary();
-  }
-
- private:
-  std::unique_ptr<tensorflow::StatsCalculator> stats_calculator_;
-};
-
-}  // namespace profiling
-}  // namespace tflite
-
-#endif  // TENSORFLOW_CONTRIB_LITE_PROFILING_PROFILE_SUMMARIZER_H_
diff --git a/tools/tflite_benchmark_model/stats_calculator.cc b/tools/tflite_benchmark_model/stats_calculator.cc
new file mode 100644 (file)
index 0000000..5786507
--- /dev/null
@@ -0,0 +1,317 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/util/stats_calculator.h"
+
+#include <iomanip>
+#include <map>
+#include <queue>
+#include <sstream>
+#include <string>
+#include <algorithm>
+
+namespace tensorflow {
+
+StatsCalculator::StatsCalculator(const StatSummarizerOptions& options)
+    : options_(options) {}
+
+std::string StatsCalculator::GetShortSummary() const {
+  std::stringstream stream;
+  stream << "Timings (microseconds): ";
+  run_total_us_.OutputToStream(&stream);
+  stream << std::endl;
+
+  stream << "Memory (bytes): ";
+  memory_.OutputToStream(&stream);
+  stream << std::endl;
+
+  stream << details_.size() << " nodes observed" << std::endl;
+  return stream.str();
+}
+
+std::ostream& InitField(std::ostream& stream, int width) {
+  stream << "\t" << std::right << std::setw(width) << std::fixed
+         << std::setprecision(3);
+  return stream;
+}
+
+std::string StatsCalculator::HeaderString(const std::string& title) const {
+  std::stringstream stream;
+
+  stream << "============================== " << title
+         << " ==============================" << std::endl;
+
+  InitField(stream, 24) << "[node type]";
+  InitField(stream, 9) << "[start]";
+  InitField(stream, 9) << "[first]";
+  InitField(stream, 9) << "[avg ms]";
+  InitField(stream, 8) << "[%]";
+  InitField(stream, 8) << "[cdf%]";
+  InitField(stream, 10) << "[mem KB]";
+  InitField(stream, 9) << "[times called]";
+  stream << "\t"
+         << "[Name]";
+  return stream.str();
+}
+
+std::string StatsCalculator::ColumnString(const Detail& detail,
+                                          const int64_t cumulative_stat_on_node,
+                                          const Stat<int64_t>& stat) const {
+  const double start_ms = detail.start_us.avg() / 1000.0;
+  const double first_time_ms = detail.rel_end_us.first() / 1000.0;
+  const double avg_time_ms = detail.rel_end_us.avg() / 1000.0;
+  const double percentage = detail.rel_end_us.sum() * 100.0 / stat.sum();
+  const double cdf_percentage = (cumulative_stat_on_node * 100.0f) / stat.sum();
+  const int64_t times_called = detail.times_called / num_runs();
+
+  std::stringstream stream;
+  InitField(stream, 24) << detail.type;
+  InitField(stream, 9) << start_ms;
+  InitField(stream, 9) << first_time_ms;
+  InitField(stream, 9) << avg_time_ms;
+  InitField(stream, 7) << percentage << "%";
+  InitField(stream, 7) << cdf_percentage << "%";
+  InitField(stream, 10) << detail.mem_used.newest() / 1000.0;
+  InitField(stream, 9) << times_called;
+  stream << "\t" << detail.name;
+
+  return stream.str();
+}
+
+void StatsCalculator::OrderNodesByMetric(
+    SortingMetric metric, std::vector<const Detail*>* details) const {
+  std::priority_queue<std::pair<std::string, const Detail*>> sorted_list;
+  const int num_nodes = details_.size();
+
+  for (const auto& det : details_) {
+    const Detail* detail = &(det.second);
+    std::stringstream stream;
+    stream << std::setw(20) << std::right << std::setprecision(10)
+           << std::fixed;
+
+    switch (metric) {
+      case BY_NAME:
+        stream << detail->name;
+        break;
+      case BY_RUN_ORDER:
+        stream << num_nodes - detail->run_order;
+        break;
+      case BY_TIME:
+        stream << detail->rel_end_us.avg();
+        break;
+      case BY_MEMORY:
+        stream << detail->mem_used.avg();
+        break;
+      case BY_TYPE:
+        stream << detail->type;
+        break;
+      default:
+        stream << "";
+        break;
+    }
+
+    sorted_list.emplace(stream.str(), detail);
+  }
+
+  while (!sorted_list.empty()) {
+    auto entry = sorted_list.top();
+    sorted_list.pop();
+    details->push_back(entry.second);
+  }
+}
+
+void StatsCalculator::ComputeStatsByType(
+    std::map<std::string, int64_t>* node_type_map_count,
+    std::map<std::string, int64_t>* node_type_map_time,
+    std::map<std::string, int64_t>* node_type_map_memory,
+    std::map<std::string, int64_t>* node_type_map_times_called,
+    int64_t* accumulated_us) const {
+  int64_t run_count = run_total_us_.count();
+
+  for (const auto& det : details_) {
+    const std::string node_name = det.first;
+    const Detail& detail = det.second;
+
+    int64_t curr_time_val =
+        static_cast<int64_t>(detail.rel_end_us.sum() / run_count);
+    *accumulated_us += curr_time_val;
+
+    int64_t curr_memory_val = detail.mem_used.newest();
+
+    const std::string& node_type = detail.type;
+
+    const std::string sharp1("#1");
+    bool first = std::mismatch(sharp1.rbegin(), sharp1.rend(), node_name.rbegin()).first == sharp1.rend();
+
+    if (first) {
+      (*node_type_map_count)[node_type] += 1;
+      (*node_type_map_times_called)[node_type] += detail.times_called / run_count;
+    }
+    (*node_type_map_time)[node_type] += curr_time_val;
+    (*node_type_map_memory)[node_type] += curr_memory_val;
+  }
+}
+
+std::string StatsCalculator::GetStatsByNodeType() const {
+  std::stringstream stream;
+
+  stream << "Number of nodes executed: " << details_.size() << std::endl;
+
+  stream << "============================== Summary by node type "
+            "=============================="
+         << std::endl;
+
+  std::map<std::string, int64_t> node_type_map_count;
+  std::map<std::string, int64_t> node_type_map_time;
+  std::map<std::string, int64_t> node_type_map_memory;
+  std::map<std::string, int64_t> node_type_map_times_called;
+  int64_t accumulated_us = 0;
+
+  ComputeStatsByType(&node_type_map_count, &node_type_map_time,
+                     &node_type_map_memory, &node_type_map_times_called,
+                     &accumulated_us);
+
+  // Sort them.
+  std::priority_queue<std::pair<int64_t, std::pair<std::string, int64_t>>>
+      timings;
+  for (const auto& node_type : node_type_map_time) {
+    const int64_t mem_used = node_type_map_memory[node_type.first];
+    timings.emplace(node_type.second,
+                    std::pair<std::string, int64_t>(node_type.first, mem_used));
+  }
+
+  InitField(stream, 24) << "[Node type]";
+  InitField(stream, 9) << "[count]";
+  InitField(stream, 10) << "[avg ms]";
+  InitField(stream, 11) << "[avg %]";
+  InitField(stream, 11) << "[cdf %]";
+  InitField(stream, 10) << "[mem KB]";
+  InitField(stream, 10) << "[times called]";
+  stream << std::endl;
+
+  float cdf = 0.0f;
+  while (!timings.empty()) {
+    auto entry = timings.top();
+    timings.pop();
+
+    const std::string node_type = entry.second.first;
+    const float memory = entry.second.second / 1000.0f;
+
+    const int64_t node_type_total_us = entry.first;
+    const float time_per_run_ms = node_type_total_us / 1000.0f;
+
+    const float percentage =
+        ((entry.first / static_cast<float>(accumulated_us)) * 100.0f);
+    cdf += percentage;
+
+    InitField(stream, 24) << node_type;
+    InitField(stream, 9) << node_type_map_count[node_type];
+    InitField(stream, 10) << time_per_run_ms;
+    InitField(stream, 10) << percentage << "%";
+    InitField(stream, 10) << cdf << "%";
+    InitField(stream, 10) << memory;
+    InitField(stream, 9) << node_type_map_times_called[node_type];
+    stream << std::endl;
+  }
+  stream << std::endl;
+  return stream.str();
+}
+
+std::string StatsCalculator::GetStatsByMetric(const std::string& title,
+                                              SortingMetric sorting_metric,
+                                              int num_stats) const {
+  std::vector<const Detail*> details;
+  OrderNodesByMetric(sorting_metric, &details);
+
+  double cumulative_stat_on_node = 0;
+
+  std::stringstream stream;
+  stream << HeaderString(title) << std::endl;
+  int stat_num = 0;
+  for (auto detail : details) {
+    ++stat_num;
+    if (num_stats > 0 && stat_num > num_stats) {
+      break;
+    }
+
+    // TODO(andrewharp): Make this keep track of the particular metric for cdf.
+    cumulative_stat_on_node += detail->rel_end_us.sum();
+    stream << ColumnString(*detail, cumulative_stat_on_node, run_total_us_)
+           << std::endl;
+  }
+  stream << std::endl;
+  return stream.str();
+}
+
+std::string StatsCalculator::GetOutputString() const {
+  std::stringstream stream;
+  if (options_.show_run_order) {
+    stream << GetStatsByMetric("Run Order", BY_RUN_ORDER,
+                               options_.run_order_limit);
+  }
+  if (options_.show_time) {
+    stream << GetStatsByMetric("Top by Computation Time", BY_TIME,
+                               options_.time_limit);
+  }
+  if (options_.show_memory) {
+    stream << GetStatsByMetric("Top by Memory Use", BY_MEMORY,
+                               options_.memory_limit);
+  }
+  if (options_.show_type) {
+    stream << GetStatsByNodeType();
+  }
+  if (options_.show_summary) {
+    stream << GetShortSummary() << std::endl;
+  }
+  return stream.str();
+}
+
+void StatsCalculator::AddNodeStats(const std::string& name,
+                                   const std::string& type, int64_t run_order,
+                                   int64_t start_us, int64_t rel_end_us,
+                                   int64_t mem_used) {
+  Detail* detail = nullptr;
+  if (details_.find(name) == details_.end()) {
+    details_.insert({name, {}});
+    detail = &details_.at(name);
+    detail->type = type;
+    detail->name = name;
+    detail->run_order = run_order;
+  } else {
+    detail = &details_.at(name);
+  }
+  detail->start_us.UpdateStat(start_us);
+  detail->rel_end_us.UpdateStat(rel_end_us);
+  detail->mem_used.UpdateStat(mem_used);
+  detail->times_called++;
+}
+
+}  // namespace tensorflow