Support other types for RandomGenerator (#1831)
author장지섭/동작제어Lab(SR)/Engineer/삼성전자 <jiseob.jang@samsung.com>
Wed, 4 Jul 2018 07:30:34 +0000 (16:30 +0900)
committer오형석/동작제어Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Wed, 4 Jul 2018 07:30:34 +0000 (16:30 +0900)
This commit supports other types(int32_t, uint8_t) for RandomGenerator.
- Add quantization info to RandomGenerator.
- Change class RandomGenerator to not be the template class.
- Change operator method to template method(generate<>()).

Signed-off-by: jiseob.jang <jiseob.jang@samsung.com>
include/support/tflite/Diff.h
libs/support/tflite/src/Diff.cpp
tools/tflite_benchmark/src/tflite_benchmark.cc
tools/tflite_run/src/tflite_run.cc

index 7ee1848..f76d08b 100644 (file)
@@ -55,32 +55,41 @@ private:
 };
 
 #include "support/tflite/interp/Builder.h"
+#include "support/tflite/Quantization.h"
 
 #include <random>
 
-template <typename T> class RandomGenerator
+class RandomGenerator
 {
 public:
-  RandomGenerator(int seed, T mean, T stddev) : _rand{seed}, _dist{mean, stddev}
+  RandomGenerator(int seed, float mean, float stddev,
+                  const TfLiteQuantizationParams quantization = make_default_quantization())
+      : _rand{seed}, _dist{mean, stddev}, _quantization{quantization}
   {
     // DO NOTHING
   }
 
-  T operator()(const ::nnfw::util::tensor::Shape &, const ::nnfw::util::tensor::Index &)
+public:
+  template <typename T>
+  T generate(const ::nnfw::util::tensor::Shape &, const ::nnfw::util::tensor::Index &)
   {
-    return (*this)();
+    return generate<T>();
   }
 
-  T operator()(void)
+  template <typename T> T generate(void)
   {
     return _dist(_rand);
   }
 
 private:
   std::minstd_rand _rand;
-  std::normal_distribution<T> _dist;
+  std::normal_distribution<float> _dist;
+  const TfLiteQuantizationParams _quantization;
 };
 
+template <>
+uint8_t RandomGenerator::generate<uint8_t>(void);
+
 // For NNAPI testing
 struct RandomTestParam
 {
@@ -91,8 +100,9 @@ struct RandomTestParam
 class RandomTestRunner
 {
 public:
-  RandomTestRunner(int seed, const RandomTestParam &param)
-    : _randgen{seed, 0.0f, 2.0f}, _param{param}
+  RandomTestRunner(int seed, const RandomTestParam &param,
+                   const TfLiteQuantizationParams quantization = make_default_quantization())
+      : _randgen{seed, 0.0f, 2.0f, quantization}, _param{param}
   {
     // DO NOTHING
   }
@@ -103,7 +113,7 @@ public:
   int run(const nnfw::support::tflite::interp::Builder &builder);
 
 private:
-  RandomGenerator<float> _randgen;
+  RandomGenerator _randgen;
   const RandomTestParam _param;
 
 public:
index 1018478..9f9b3d1 100644 (file)
@@ -234,6 +234,13 @@ bool TfLiteInterpMatchApp::run(::tflite::Interpreter &interp, ::tflite::Interpre
 
 #include "util/tensor/Object.h"
 
+using namespace std::placeholders;
+
+template <> uint8_t RandomGenerator::generate<uint8_t>(void)
+{
+  return static_cast<uint8_t>(_dist(_rand) / _quantization.scale + _quantization.zero_point);
+}
+
 //
 // Random Test Runner
 //
@@ -306,7 +313,11 @@ int RandomTestRunner::run(const nnfw::support::tflite::interp::Builder &builder)
 
     assert(tfl_interp_view.shape() == nnapi_view.shape());
 
-    const nnfw::util::tensor::Object<float> data(tfl_interp_view.shape(), _randgen);
+    auto fp = static_cast<float (RandomGenerator::*)(const ::nnfw::util::tensor::Shape &,
+                                                     const ::nnfw::util::tensor::Index &)>(
+        &RandomGenerator::generate<float>);
+    const nnfw::util::tensor::Object<float> data(tfl_interp_view.shape(),
+                                                 std::bind(fp, _randgen, _1, _2));
 
     assert(tfl_interp_view.shape() == data.shape());
 
index 5f25aea..ba2e628 100644 (file)
@@ -175,11 +175,11 @@ int main(const int argc, char **argv)
         assert(tensor->type == kTfLiteFloat32);
 
         const int seed = 1; /* TODO Add an option for seed value */
-        RandomGenerator<float> randgen{seed, 0.0f, 0.2f};
+        RandomGenerator randgen{seed, 0.0f, 0.2f};
         const float *end = reinterpret_cast<const float *>(tensor->data.raw_const + tensor->bytes);
         for (float *ptr = tensor->data.f; ptr < end; ptr++)
         {
-          *ptr = randgen();
+          *ptr = randgen.generate<float>();
         }
       }
     }
index e04519a..23a2380 100644 (file)
@@ -153,11 +153,11 @@ int main(const int argc, char **argv)
         assert(tensor->type == kTfLiteFloat32);
 
         const int seed = 1; /* TODO Add an option for seed value */
-        RandomGenerator<float> randgen{seed, 0.0f, 0.2f};
+        RandomGenerator randgen{seed, 0.0f, 0.2f};
         const float *end = reinterpret_cast<const float *>(tensor->data.raw_const + tensor->bytes);
         for (float *ptr = tensor->data.f; ptr < end; ptr++)
         {
-          *ptr = randgen();
+          *ptr = randgen.generate<float>();
         }
       }
     }