add BFloat16 support for bernoulli and Dropout on CPU (#56372)
author mingfeima <mingfei.ma@intel.com>
Wed, 25 Aug 2021 18:53:52 +0000 (11:53 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 25 Aug 2021 19:01:27 +0000 (12:01 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/56372

Test Plan: Imported from OSS

Reviewed By: heitorschueroff

Differential Revision: D28836792

Pulled By: VitalyFedyunin

fbshipit-source-id: ede951d172a59276e11383fd767778ab959b5a6b
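
For context, a minimal usage sketch (not part of this commit) of the behavior the change enables: in-place Bernoulli sampling and nn.Dropout on bfloat16 CPU tensors. Tensor sizes and probabilities below are illustrative only.

import torch
import torch.nn as nn

# Scalar-probability path: fill a bfloat16 CPU tensor in place.
x = torch.empty(16, dtype=torch.bfloat16, device='cpu')
x.bernoulli_(0.5)              # dispatched through bernoulli_scalar_cpu_

# Dropout on a bfloat16 CPU input; the output keeps the input dtype.
drop = nn.Dropout(p=0.2)
y = drop(torch.ones(8, 8, dtype=torch.bfloat16))
print(x.dtype, y.dtype)        # torch.bfloat16 torch.bfloat16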

aten/src/ATen/native/cpu/DistributionTemplates.h
aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
test/test_nn.py
test/test_torch.py

diff --git a/aten/src/ATen/native/cpu/DistributionTemplates.h b/aten/src/ATen/native/cpu/DistributionTemplates.h
index 66bd31f..15b1916 100644
@@ -308,7 +308,7 @@ struct ExponentialKernel {
 
 template<typename RNG>
 void bernoulli_kernel(Tensor& self, const Tensor& p_, RNG generator) {
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool, self.scalar_type(), "bernoulli_tensor_cpu_self_", [&] {
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Bool, at::ScalarType::BFloat16, self.scalar_type(), "bernoulli_tensor_cpu_self_", [&] {
     // See Note [Acquire lock when using random generators]
     std::lock_guard<std::mutex> lock(generator->mutex_);
     using self_t = scalar_t;
@@ -325,7 +325,7 @@ void bernoulli_kernel(Tensor& self, const Tensor& p_, RNG generator) {
         return static_cast<self_t>(bernoulli(generator));
       });
     } else {
-      AT_DISPATCH_FLOATING_TYPES(p_.scalar_type(), "bernoulli_tensor_cpu_p_", [&] {
+      AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::BFloat16, p_.scalar_type(), "bernoulli_tensor_cpu_p_", [&] {
         using p_t = scalar_t;
         cpu_serial_kernel(iter, [&](const p_t p_val) -> self_t {
           at::bernoulli_distribution<float> bernoulli(p_val);
@@ -338,7 +338,7 @@ void bernoulli_kernel(Tensor& self, const Tensor& p_, RNG generator) {
 
 template<typename RNG>
 void bernoulli_kernel(Tensor& self, double p, RNG generator) {
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool, self.scalar_type(), "bernoulli_scalar_cpu_", [&] {
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Bool, at::ScalarType::BFloat16, self.scalar_type(), "bernoulli_scalar_cpu_", [&] {
     // See Note [Acquire lock when using random generators]
     std::lock_guard<std::mutex> lock(generator->mutex_);
     auto iter = TensorIterator::borrowing_nullary_op(self);
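
A hedged Python-level sketch of the two dispatch changes above: the outer dispatch now accepts a bfloat16 self tensor, and the inner dispatch accepts a bfloat16 probability tensor p_. The shapes and probability value are illustrative.

import torch

# Per-element probabilities supplied as a bfloat16 tensor (bernoulli_tensor_cpu_p_ path).
p = torch.full((4, 4), 0.3, dtype=torch.bfloat16)
out = torch.empty(4, 4, dtype=torch.bfloat16)
out.bernoulli_(p)                          # self and p are both bfloat16 on CPU
assert ((out == 0) | (out == 1)).all()     # samples are exactly 0.0 or 1.0
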
diff --git a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
index 007e444..f86f0a3 100644
@@ -488,7 +488,7 @@ void bernoulli_scalar_kernel(Tensor &self, double p, c10::optional<Generator> ge
     int64_t n = self.numel();
     bool contig = self.is_contiguous();
 
-    AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool, self.scalar_type(), "bernoulli_scalar_cpu_", [&] {
+    AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Bool, at::ScalarType::BFloat16, self.scalar_type(), "bernoulli_scalar_cpu_", [&] {
       at::Tensor tmp_int_tensor;
       if (std::is_same<scalar_t, int>::value && contig) {
         tmp_int_tensor = self;
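
The kernel above handles the contiguous scalar-probability case; a hedged sanity check (seed and size are arbitrary) is that the empirical mean of a large bfloat16 CPU tensor filled this way lands near p.

import torch

torch.manual_seed(0)
t = torch.empty(1_000_000, dtype=torch.bfloat16).bernoulli_(0.25)
print(t.float().mean().item())   # expected to be roughly 0.25
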
diff --git a/test/test_nn.py b/test/test_nn.py
index d577493..8c3541a 100644
@@ -12984,7 +12984,7 @@ class TestNNDeviceType(NNTestCase):
 
         self._test_dropout_stride_mean_preserve(nn.Dropout, device)
 
-        if self.device_type == 'cuda':
+        if self.device_type == 'cuda' or self.device_type == 'cpu':
             input = input.bfloat16()
             self._test_dropout(nn.Dropout, device, input)
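
In the spirit of _test_dropout above, a hedged standalone sketch of what the new CPU branch exercises: roughly p of the elements are zeroed and the survivors are rescaled by 1/(1-p), up to bfloat16 rounding. The dropout probability and tensor size are illustrative.

import torch
import torch.nn as nn

torch.manual_seed(0)
p = 0.4
inp = torch.ones(1000, 1000, dtype=torch.bfloat16)
out = nn.Dropout(p)(inp)
kept = out.ne(0)
print(kept.float().mean().item())        # close to 1 - p
print(out[kept].float().max().item())    # close to 1 / (1 - p)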
 
diff --git a/test/test_torch.py b/test/test_torch.py
index d0f631a..15e36c8 100644
@@ -4324,6 +4324,7 @@ else:
             self.assertEqual(a_with_output.size(), torch.Size([3, 2]))
 
     @dtypes(*(torch.testing.get_all_fp_dtypes(include_half=False, include_bfloat16=False)))
+    @dtypesIfCPU(*(torch.testing.get_all_fp_dtypes(include_half=False, include_bfloat16=True)))
     @dtypesIfCUDA(*(torch.testing.get_all_fp_dtypes(include_bfloat16=False)))
     def test_bernoulli_p(self, device, dtype):
         for trivial_p in ([0, 1], [1, 0, 1, 1, 0, 1]):
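
A hedged sketch of the extra coverage the new @dtypesIfCPU line adds: test_bernoulli_p now also runs with dtype=torch.bfloat16 on CPU, where trivial probabilities of 0 and 1 must be reproduced exactly.

import torch

for trivial_p in ([0., 1.], [1., 0., 1., 1., 0., 1.]):
    p = torch.tensor(trivial_p, dtype=torch.bfloat16, device='cpu')
    out = torch.bernoulli(p)
    assert torch.equal(out, p)          # p of 0 or 1 maps to exactly 0 or 1
    assert out.dtype == torch.bfloat16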