// Helper function to convert a 16-bit integer to a half between [0..1).
PHILOX_DEVICE_INLINE Eigen::half Uint16ToHalf(uint16 x);
+// Helper function to convert a 16-bit integer to a bfloat16 between [0..1).
+PHILOX_DEVICE_INLINE bfloat16 Uint16ToGfloat16(uint16 x);  // NOTE(review): named "Gfloat16" but returns bfloat16.
// Helper function to convert a 32-bit integer to a float between [0..1).
PHILOX_DEVICE_INLINE float Uint32ToFloat(uint32 x);
// Helper function to convert two 32-bit integers to a double between [0..1).
};
// Partial specialization of UniformDistribution for bfloat16 output.
template <class Generator>
+class UniformDistribution<Generator, bfloat16> {
+ public:
+ // The number of elements that will be returned by one call to operator().
+ static const int kResultElementCount = Generator::kResultElementCount;
+ // Cost of generation of a single element (in cycles).
+ static const int kElementCost = 3;
+ // Indicate that this distribution may take variable number of samples
+ // during the runtime. False here: operator() invokes the generator once.
+ static const bool kVariableSamplesPerOutput = false;
+ typedef Array<bfloat16, kResultElementCount> ResultType;
+ typedef bfloat16 ResultElementType;
+ // Draws one batch from `gen` and maps every element through Uint16ToGfloat16.
+ PHILOX_DEVICE_INLINE
+ ResultType operator()(Generator* gen) {
+ typename Generator::ResultType sample = (*gen)();
+ ResultType result;
+ for (int i = 0; i < kResultElementCount; ++i) {
+ result[i] = Uint16ToGfloat16(sample[i]);  // callee takes uint16; presumably sample[i] is wider and gets narrowed -- confirm
+ }
+ return result;
+ }
+};
+
+template <class Generator>
class UniformDistribution<Generator, float> {
public:
// The number of elements that will be returned.
};
// Partial specialization of NormalDistribution for bfloat16 output.
template <class Generator>
+class NormalDistribution<Generator, bfloat16> {
+ public:
+ // The number of elements that will be returned by one call to operator().
+ static const int kResultElementCount = Generator::kResultElementCount;
+ // Cost of generation of a single element (in cycles).
+ static const int kElementCost = 70;
+ // Indicate that this distribution may take variable number of samples
+ // during the runtime. False here: operator() invokes the generator once.
+ static const bool kVariableSamplesPerOutput = false;
+ typedef Array<bfloat16, kResultElementCount> ResultType;
+ typedef bfloat16 ResultElementType;
+ // Draws one batch from `gen` and converts it, two samples at a time, via BoxMullerFloat.
+ PHILOX_DEVICE_INLINE
+ ResultType operator()(Generator* gen) {
+ typename Generator::ResultType sample = (*gen)();
+ ResultType result;
+ static_assert(kResultElementCount % 2 == 0,
+ "kResultElementCount should be an even number");  // keeps index i + 1 below in range
+ for (int i = 0; i < kResultElementCount; i += 2) {
+ float f[2];  // BoxMullerFloat writes floats; they are converted to bfloat16 below
+ // Box-Muller transform requires processing 2 elements at a time.
+ BoxMullerFloat(sample[i], sample[i + 1], &f[0], &f[1]);
+ result[i] = bfloat16(f[0]);
+ result[i + 1] = bfloat16(f[1]);
+ }
+ return result;
+ }
+};
+
+template <class Generator>
class NormalDistribution<Generator, float> {
public:
// The number of elements that will be returned.
}
};
+template <class SingleSampleGenerator>  // Partial specialization for bfloat16.
+class TruncatedNormalDistribution<SingleSampleGenerator, bfloat16> {
+ public:
+ // The number of elements that will be returned by one call to operator().
+ static const int kResultElementCount =
+ SingleSampleGenerator::kNativeElementCount;
+ // Cost of generation of a single element (in cycles).
+ static const int kElementCost = 90;
+ // Indicate that this distribution may take variable number of samples
+ // during the runtime. True here: operator() keeps sampling until enough values are accepted.
+ static const bool kVariableSamplesPerOutput = true;
+ // The threshold where the normal distribution is truncated: values f with |f| >= kTruncateValue are rejected.
+ const float kTruncateValue = 2.0f;
+
+ typedef Array<bfloat16, kResultElementCount> ResultType;
+ typedef bfloat16 ResultElementType;
+ // Fills a ResultType by rejection sampling; returns once kResultElementCount values have been accepted.
+ PHILOX_DEVICE_INLINE
+ ResultType operator()(SingleSampleGenerator* gen) {
+ ResultType results;
+ int index = 0;  // number of accepted values stored so far
+ while (true) {
+ // Repeatedly take samples from the normal distribution, until we have
+ // the desired number of elements that fall within the pre-defined cutoff
+ // threshold.
+ const uint32 x0 = (*gen)();
+ const uint32 x1 = (*gen)();
+ float f[2];
+ BoxMullerFloat(x0, x1, &f[0], &f[1]);
+ // Accept each value strictly inside the cutoff; if f[0] fills the last slot, f[1] is dropped.
+ for (int i = 0; i < 2; ++i) {
+ if (Eigen::numext::abs(f[i]) < kTruncateValue) {
+ results[index++] = bfloat16(f[i]);
+ if (index >= kResultElementCount) {
+ return results;
+ }
+ }
+ }
+ }
+ }
+};
+
// Partial specialization for float.
template <class SingleSampleGenerator>
class TruncatedNormalDistribution<SingleSampleGenerator, float> {
return result - Eigen::half(1.0);
}
+// Helper function to convert a 16-bit integer to a bfloat16 between [0..1).
+// This can create a uniform distribution of values between [0..1).
+PHILOX_DEVICE_INLINE bfloat16 Uint16ToGfloat16(uint16 x) {
+ // bfloat are formatted as follows (MSB first):
+ // sign(1) exponent(8) mantissa(7)
+ // Conceptually construct the following:
+ // sign == 0
+ // exponent == 127 -- an excess 127 representation of a zero exponent
+ // mantissa == 7 random bits
+ const uint16 man = x & 0x7fu; // 7 bit mantissa; the upper 9 bits of x are ignored
+ const uint16 exp = static_cast<uint16>(127);
+ const uint16 val = (exp << 7) | man;  // 0x3f80 | man, so the sign bit stays 0
+ // Copy the raw bit pattern into a bfloat16 (bit reinterpretation, not a numeric conversion).
+ bfloat16 result;
+ memcpy(&result, &val, sizeof(val));  // assumes sizeof(bfloat16) == sizeof(uint16) -- TODO confirm
+ // The mantissa has an implicit leading 1, so the above code creates a value
+ // in [1, 2). The minus will not cause a rounding that makes the result 1.
+ // Instead it will just be close to 1.
+ return result - bfloat16(1.0);
+}
+
// Helper function to convert a 32-bit integer to a float between [0..1).
PHILOX_DEVICE_INLINE float Uint32ToFloat(uint32 x) {
// IEEE754 floats are formatted as follows (MSB first):
// unit normal distribution, it should almost definitely never exceed 6.
static constexpr float kZLimit = 6.0;
+// As bfloat16 has much less precision than float32, the largest allowed
+// z-value should be larger than the one used for float32 (kZLimit).
+static constexpr float kZLimitBfloat16 = 20.0;
+
// A utility function to fill the given array with samples from the given
// distribution, using the single adapter of the underlying generator
template <class Distribution>
// mode, given the large number of samples.
moments_data[i] += moment;
++moments_sample_count_data[i];
- moment *= samples_data[index];
+ moment *= static_cast<double>(samples_data[index]);
}
}
const double z_test =
fabs((moments[i] - moments_i_mean) / sqrt(total_variance));
- if (z_test > z_limit) {
+ if (z_test > static_cast<double>(z_limit)) {
LOG(ERROR) << "failing z_test:"
<< " moment: " << i << " stride: " << stride
<< " z_test: " << z_test << " z_limit: " << z_limit
}
}
+TEST(PhiloxRandomTest, UniformBfloat16MomentsTest) {  // same arguments as UniformFloatMomentsTest except for the z-limit
+ const std::vector<int> strides = {0, 1, 4, 17};
+ UniformMomentsTest<bfloat16>(1 << 20, 40, strides, bfloat16(kZLimitBfloat16));  // uses the looser bfloat16 limit, passed as a bfloat16
+}
+
+TEST(PhiloxRandomTest, NormalBfloat16MomentsTest) {  // runs NormalMomentsTest instantiated for bfloat16
+ const std::vector<int> strides = {0, 1, 4, 17};
+ NormalMomentsTest<bfloat16>(8 << 20, 25, strides, bfloat16(kZLimitBfloat16));  // uses the looser bfloat16 limit, passed as a bfloat16
+}
+
+TEST(PhiloxRandomTest, RandomParametersBfloat16MomentsTest) {  // runs RandomParametersMomentsTest instantiated for bfloat16
+ const std::vector<int> strides = {0, 1, 4, 17};
+ RandomParametersMomentsTest<bfloat16>(1 << 20, 40, strides,
+ bfloat16(kZLimitBfloat16));  // uses the looser bfloat16 limit, passed as a bfloat16
+}
+
TEST(PhiloxRandomTest, UniformFloatMomentsTest) {
const std::vector<int> strides = {0, 1, 4, 17};
UniformMomentsTest<float>(1 << 20, 40, strides, kZLimit);