Add bfloat16 random_op for CPU.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 1 Mar 2018 02:59:41 +0000 (18:59 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 1 Mar 2018 03:04:28 +0000 (19:04 -0800)
PiperOrigin-RevId: 187418131

tensorflow/core/kernels/random_op.cc
tensorflow/core/lib/random/random_distributions.h
tensorflow/core/lib/random/random_distributions_test.cc

index 78ff794..e372325 100644 (file)
@@ -495,6 +495,7 @@ class RandomGammaOp : public OpKernel {
                           RandomUniformIntOp<CPUDevice, IntType>);
 
 TF_CALL_half(REGISTER);
+TF_CALL_bfloat16(REGISTER);
 TF_CALL_float(REGISTER);
 TF_CALL_double(REGISTER);
 TF_CALL_int32(REGISTER_INT);
index 3fe1f9b..2ebe608 100644 (file)
@@ -32,6 +32,8 @@ namespace random {
 
 // Helper function to convert a 16-bit integer to a half between [0..1).
 PHILOX_DEVICE_INLINE Eigen::half Uint16ToHalf(uint16 x);
+// Helper function to convert a 16-bit integer to a bfloat16 between [0..1).
+PHILOX_DEVICE_INLINE bfloat16 Uint16ToGfloat16(uint16 x);
 // Helper function to convert a 32-bit integer to a float between [0..1).
 PHILOX_DEVICE_INLINE float Uint32ToFloat(uint32 x);
 // Helper function to convert two 32-bit integers to a double between [0..1).
@@ -76,6 +78,30 @@ class UniformDistribution<Generator, Eigen::half> {
 };
 
 template <class Generator>
+class UniformDistribution<Generator, bfloat16> {
+ public:
+  // The number of elements that will be returned.
+  static const int kResultElementCount = Generator::kResultElementCount;
+  // Cost of generation of a single element (in cycles).
+  static const int kElementCost = 3;
+  // Indicate that this distribution may take variable number of samples
+  // during the runtime.
+  static const bool kVariableSamplesPerOutput = false;
+  typedef Array<bfloat16, kResultElementCount> ResultType;
+  typedef bfloat16 ResultElementType;
+
+  PHILOX_DEVICE_INLINE
+  ResultType operator()(Generator* gen) {
+    typename Generator::ResultType sample = (*gen)();
+    ResultType result;
+    for (int i = 0; i < kResultElementCount; ++i) {
+      result[i] = Uint16ToGfloat16(sample[i]);
+    }
+    return result;
+  }
+};
+
+template <class Generator>
 class UniformDistribution<Generator, float> {
  public:
   // The number of elements that will be returned.
@@ -306,6 +332,36 @@ class NormalDistribution<Generator, Eigen::half> {
 };
 
 template <class Generator>
+class NormalDistribution<Generator, bfloat16> {
+ public:
+  // The number of elements that will be returned.
+  static const int kResultElementCount = Generator::kResultElementCount;
+  // Cost of generation of a single element (in cycles).
+  static const int kElementCost = 70;
+  // Indicate that this distribution may take variable number of samples
+  // during the runtime.
+  static const bool kVariableSamplesPerOutput = false;
+  typedef Array<bfloat16, kResultElementCount> ResultType;
+  typedef bfloat16 ResultElementType;
+
+  PHILOX_DEVICE_INLINE
+  ResultType operator()(Generator* gen) {
+    typename Generator::ResultType sample = (*gen)();
+    ResultType result;
+    static_assert(kResultElementCount % 2 == 0,
+                  "kResultElementCount should be an even number");
+    for (int i = 0; i < kResultElementCount; i += 2) {
+      float f[2];
+      // Box-Muller transform requires processing 2 elements at a time.
+      BoxMullerFloat(sample[i], sample[i + 1], &f[0], &f[1]);
+      result[i] = bfloat16(f[0]);
+      result[i + 1] = bfloat16(f[1]);
+    }
+    return result;
+  }
+};
+
+template <class Generator>
 class NormalDistribution<Generator, float> {
  public:
   // The number of elements that will be returned.
@@ -414,6 +470,48 @@ class TruncatedNormalDistribution<SingleSampleGenerator, Eigen::half> {
   }
 };
 
+template <class SingleSampleGenerator>
+class TruncatedNormalDistribution<SingleSampleGenerator, bfloat16> {
+ public:
+  // The number of elements that will be returned.
+  static const int kResultElementCount =
+      SingleSampleGenerator::kNativeElementCount;
+  // Cost of generation of a single element (in cycles).
+  static const int kElementCost = 90;
+  // Indicate that this distribution may take variable number of samples
+  // during the runtime.
+  static const bool kVariableSamplesPerOutput = true;
+  // The threshold where the normal distribution is truncated.
+  const float kTruncateValue = 2.0f;
+
+  typedef Array<bfloat16, kResultElementCount> ResultType;
+  typedef bfloat16 ResultElementType;
+
+  PHILOX_DEVICE_INLINE
+  ResultType operator()(SingleSampleGenerator* gen) {
+    ResultType results;
+    int index = 0;
+    while (true) {
+      // Repeatedly take samples from the normal distribution, until we have
+      // the desired number of elements that fall within the pre-defined cutoff
+      // threshold.
+      const uint32 x0 = (*gen)();
+      const uint32 x1 = (*gen)();
+      float f[2];
+      BoxMullerFloat(x0, x1, &f[0], &f[1]);
+
+      for (int i = 0; i < 2; ++i) {
+        if (Eigen::numext::abs(f[i]) < kTruncateValue) {
+          results[index++] = bfloat16(f[i]);
+          if (index >= kResultElementCount) {
+            return results;
+          }
+        }
+      }
+    }
+  }
+};
+
 // Partial specialization for float.
 template <class SingleSampleGenerator>
 class TruncatedNormalDistribution<SingleSampleGenerator, float> {
@@ -567,6 +665,27 @@ PHILOX_DEVICE_INLINE Eigen::half Uint16ToHalf(uint16 x) {
   return result - Eigen::half(1.0);
 }
 
+// Helper function to convert an 16-bit integer to a bfloat16 between [0..1).
+// This can create a uniform distribution of values between [0..1).
+PHILOX_DEVICE_INLINE bfloat16 Uint16ToGfloat16(uint16 x) {
+  // bfloat are formatted as follows (MSB first):
+  //    sign(1) exponent(8) mantissa(7)
+  // Conceptually construct the following:
+  //    sign == 0
+  //    exponent == 127  -- an excess 127 representation of a zero exponent
+  //    mantissa == 7 random bits
+  const uint16 man = x & 0x7fu;  // 7 bit mantissa
+  const uint16 exp = static_cast<uint16>(127);
+  const uint16 val = (exp << 7) | man;
+
+  bfloat16 result;
+  memcpy(&result, &val, sizeof(val));
+  // The mantissa has an implicit leading 1, so the above code creates a value
+  // in [1, 2). The minus will not cause a rounding that makes the result 1.
+  // Instead it will just be close to 1.
+  return result - bfloat16(1.0);
+}
+
 // Helper function to convert an 32-bit integer to a float between [0..1).
 PHILOX_DEVICE_INLINE float Uint32ToFloat(uint32 x) {
   // IEEE754 floats are formatted as follows (MSB first):
index 85d68f4..8868672 100644 (file)
@@ -37,6 +37,10 @@ namespace {
 // unit normal distribution, it should almost definitely never exceed 6.
 static constexpr float kZLimit = 6.0;
 
+// As bfloat16 has much less precision, the largest z-value will should be
+// larger than float32.
+static constexpr float kZLimitBfloat16 = 20.0;
+
 // A utility function to fill the given array with samples from the given
 // distribution, using the single adapter of the underlying generator
 template <class Distribution>
@@ -93,7 +97,7 @@ bool CheckSamplesMoments(const std::vector<T>& samples,
       // mode, given the large number of samples.
       moments_data[i] += moment;
       ++moments_sample_count_data[i];
-      moment *= samples_data[index];
+      moment *= static_cast<double>(samples_data[index]);
     }
   }
 
@@ -125,7 +129,7 @@ bool CheckSamplesMoments(const std::vector<T>& samples,
     const double z_test =
         fabs((moments[i] - moments_i_mean) / sqrt(total_variance));
 
-    if (z_test > z_limit) {
+    if (z_test > static_cast<double>(z_limit)) {
       LOG(ERROR) << "failing z_test:"
                  << " moment: " << i << " stride: " << stride
                  << " z_test: " << z_test << " z_limit: " << z_limit
@@ -252,6 +256,22 @@ void RandomParametersMomentsTest(int count, int max_moments,
   }
 }
 
+TEST(PhiloxRandomTest, UniformBfloat16MomentsTest) {
+  const std::vector<int> strides = {0, 1, 4, 17};
+  UniformMomentsTest<bfloat16>(1 << 20, 40, strides, bfloat16(kZLimitBfloat16));
+}
+
+TEST(PhiloxRandomTest, NormalBfloat16MomentsTest) {
+  const std::vector<int> strides = {0, 1, 4, 17};
+  NormalMomentsTest<bfloat16>(8 << 20, 25, strides, bfloat16(kZLimitBfloat16));
+}
+
+TEST(PhiloxRandomTest, RandomParametersBfloat16MomentsTest) {
+  const std::vector<int> strides = {0, 1, 4, 17};
+  RandomParametersMomentsTest<bfloat16>(1 << 20, 40, strides,
+                                        bfloat16(kZLimitBfloat16));
+}
+
 TEST(PhiloxRandomTest, UniformFloatMomentsTest) {
   const std::vector<int> strides = {0, 1, 4, 17};
   UniformMomentsTest<float>(1 << 20, 40, strides, kZLimit);