Extract RandomGenerator from RandomTestRunner (#1445)
author이한종/동작제어Lab(SR)/Engineer/삼성전자 <hanjoung.lee@samsung.com>
Thu, 31 May 2018 04:20:47 +0000 (13:20 +0900)
committer오형석/동작제어Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Thu, 31 May 2018 04:20:47 +0000 (13:20 +0900)
In `nnfw_util` module, extract RandomGenerator from RandomTestRunner
so RandomGenerator can be used in `tflite_run` as well.

Signed-off-by: Hanjoung Lee <hanjoung.lee@samsung.com>
include/support/tflite/Diff.h
libs/support/tflite/src/Diff.cpp

index ebc3ac3..26d141a 100644 (file)
@@ -21,6 +21,7 @@
 
 #include "util/tensor/Index.h"
 #include "util/tensor/Diff.h"
+#include "util/tensor/Shape.h"
 
 #include "support/tflite/TensorView.h"
 
@@ -90,6 +91,24 @@ private:
 
 #include <random>
 
+template <typename T> class RandomGenerator
+{
+public:
+  RandomGenerator(int seed, T mean, T stddev) : _rand{seed}, _dist{mean, stddev}
+  {
+    // DO NOTHING
+  }
+
+  T operator()(const ::nnfw::util::tensor::Shape &, const ::nnfw::util::tensor::Index &)
+  {
+    return _dist(_rand);
+  }
+
+private:
+  std::minstd_rand _rand;
+  std::normal_distribution<T> _dist;
+};
+
 // For NNAPI testing
 struct RandomTestParam
 {
@@ -101,7 +120,7 @@ class RandomTestRunner
 {
 public:
   RandomTestRunner(int seed, const RandomTestParam &param)
-    : _rand{seed}, _param{param}
+    : _randgen{seed, 0.0f, 2.0f}, _param{param}
   {
     // DO NOTHING
   }
@@ -112,7 +131,7 @@ public:
   int run(const nnfw::support::tflite::interp::Builder &builder);
 
 private:
-  std::minstd_rand _rand;
+  RandomGenerator<float> _randgen;
   const RandomTestParam _param;
 };
 
index 2a07389..4b994eb 100644 (file)
@@ -212,12 +212,6 @@ int RandomTestRunner::run(const nnfw::support::tflite::interp::Builder &builder)
   assert(pure->inputs() == nnapi->inputs());
 
   // Fill IFM with random numbers
-  auto ifm_gen = [this](const nnfw::util::tensor::Shape &, const nnfw::util::tensor::Index &) {
-    // TODO Allow users to set min/max and distribution
-    std::normal_distribution<float> dist(0.0f, 2.0f);
-    return dist(_rand);
-  };
-
   for (const auto id : pure->inputs())
   {
     assert(pure->tensor(id)->type == nnapi->tensor(id)->type);
@@ -231,7 +225,7 @@ int RandomTestRunner::run(const nnfw::support::tflite::interp::Builder &builder)
 
     assert(pure_view.shape() == nnapi_view.shape());
 
-    const nnfw::util::tensor::Object<float> data(pure_view.shape(), ifm_gen);
+    const nnfw::util::tensor::Object<float> data(pure_view.shape(), _randgen);
 
     assert(pure_view.shape() == data.shape());