Adjust out of range values generated by RandomGenerator (#9418)
author장지섭/On-Device Lab(SR)/Engineer/삼성전자 <jiseob.jang@samsung.com>
Fri, 6 Dec 2019 09:48:04 +0000 (18:48 +0900)
committer이한종/On-Device Lab(SR)/Engineer/삼성전자 <hanjoung.lee@samsung.com>
Fri, 6 Dec 2019 09:48:04 +0000 (18:48 +0900)
This commit adjusts out of range values generated by RandomGenerator to be mapped to end points of the range.

Signed-off-by: jiseob.jang <jiseob.jang@samsung.com>
runtime/libs/tflite/src/Diff.cpp

index 1a3ac85..8ddad6c 100644 (file)
@@ -270,7 +270,24 @@ template <> uint8_t RandomGenerator::generate<uint8_t>(void)
   // Most _dist values range from -5.0 to 5.0.
   float min_range = -5.0f;
   float max_range = 5.0f;
-  return static_cast<uint8_t>((_dist(_rand) - min_range) * type_range / (max_range - min_range));
+  // NOTE shifted_relative_val has Gaussian distribution that origin mean was 0 and standard
+  // deviation was 2. And then its values are distributed and shift to that mean is 127.5 and range
+  // is about [0, 255].
+  float shifted_relative_val = (_dist(_rand) - min_range) * type_range / (max_range - min_range);
+
+  // shifted_relative_val is adjusted to be mapped to end points of the range, if it is out of range
+  // values.
+  if (shifted_relative_val < 0.0f)
+  {
+    return 0;
+  }
+  else if (shifted_relative_val > type_range)
+  {
+    return 255;
+  }
+
+  // Convert shifted_relative_val from float to uint8
+  return static_cast<uint8_t>(shifted_relative_val);
 }
 
 #include "tflite/TensorLogger.h"