Add BFloat16 support for cross, tril, triu, tril_indices, triu_indices and cumsum operators on CPU
author CaoE <e.cao@intel.com>
Tue, 14 Sep 2021 00:58:20 +0000 (17:58 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 14 Sep 2021 00:59:43 +0000 (17:59 -0700)
Summary:
Add BFloat16 support for cross, tril, triu, tril_indices, triu_indices and cumsum operators on CPU.
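
For reference, a minimal sketch of what the change enables (illustrative only; shapes and values are arbitrary, and each call below previously failed dtype dispatch for bfloat16 on CPU). Per the summary, cumsum is covered as well, although its hunk is not shown in this excerpt:

    import torch

    a = torch.randn(3, 3).bfloat16()   # CPU bfloat16 tensor
    b = torch.randn(3, 3).bfloat16()

    torch.tril(a)                                      # lower triangle, stays bfloat16
    torch.triu(a, diagonal=1)                          # strict upper triangle
    torch.tril_indices(4, 3, 1, dtype=torch.bfloat16)  # index tensor in bfloat16
    torch.triu_indices(4, 3, 1, dtype=torch.bfloat16)
    torch.cross(a, b, dim=0)                           # 3-vector cross products
    torch.cumsum(a, dim=0)                             # cumulative sum along dim 0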

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62454

Reviewed By: albanD

Differential Revision: D30845805

Pulled By: heitorschueroff

fbshipit-source-id: f83836862e38109ec929e83567133e9e88096b8b

aten/src/ATen/native/TensorFactories.cpp
aten/src/ATen/native/TriangularOps.cpp
aten/src/ATen/native/cpu/CrossKernel.cpp
test/test_tensor_creation_ops.py
torch/testing/_internal/common_methods_invocations.py

aten/src/ATen/native/TensorFactories.cpp
index 4712c3d..67ef8b6 100644
@@ -979,7 +979,7 @@ Tensor tril_indices_cpu(
   //
   // 3. sequential RAM + transpose: create an n X 2 Tensor, fill the Tensor
   //    sequentially, and then transpose it.
-  AT_DISPATCH_ALL_TYPES(result.scalar_type(), "tril_indices", [&]() -> void {
+  AT_DISPATCH_ALL_TYPES_AND(kBFloat16, result.scalar_type(), "tril_indices", [&]() -> void {
     // fill the Tensor with correct values
     scalar_t* result_data = result.data_ptr<scalar_t>();
     int64_t i = 0;
@@ -1017,7 +1017,7 @@ Tensor triu_indices_cpu(
   // create an empty Tensor with correct size
   auto result = at::native::empty_cpu({2, triu_size}, dtype_opt, layout_opt, device_opt, pin_memory_opt);
 
-  AT_DISPATCH_ALL_TYPES(result.scalar_type(), "triu_indices", [&]() -> void {
+  AT_DISPATCH_ALL_TYPES_AND(kBFloat16, result.scalar_type(), "triu_indices", [&]() -> void {
     // fill the Tensor with correct values
     scalar_t* result_data = result.data_ptr<scalar_t>();
     int64_t i = 0;
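
The only change in both factory functions is the dispatch macro: AT_DISPATCH_ALL_TYPES_AND(kBFloat16, ...) extends AT_DISPATCH_ALL_TYPES with one extra scalar type, so the lambda is also instantiated with scalar_t = c10::BFloat16. A quick check of the resulting behavior (hypothetical interactive session):

    >>> import torch
    >>> torch.tril_indices(4, 3, 1, dtype=torch.bfloat16).dtype
    torch.bfloat16
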
aten/src/ATen/native/TriangularOps.cpp
index ec1741d..765069b 100644
@@ -99,7 +99,7 @@ Tensor& tril_cpu_(Tensor &self, int64_t k) {
   Tensor self_c;
   std::tie(inplace, self_c) = checkTrilTriuBatchContiguous(self, true);
   Tensor result = inplace ? self : at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "tril", [&]{
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::BFloat16, at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "tril", [&]{
     apply_triu_tril<scalar_t, false>(result, self_c, inplace, k);
   });
   if (!inplace) self.copy_(result);
@@ -113,7 +113,7 @@ Tensor& tril_cpu_out(const Tensor& self, int64_t k, Tensor &result) {
   }
   Tensor self_c;
   std::tie(std::ignore, self_c) = checkTrilTriuBatchContiguous(self, false);
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "tril", [&]{
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::BFloat16, at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "tril", [&]{
     apply_triu_tril<scalar_t, false>(result, self_c, false, k);
   });
   return result;
@@ -134,7 +134,7 @@ Tensor& triu_cpu_(Tensor &self, int64_t k) {
   Tensor self_c;
   std::tie(inplace, self_c) = checkTrilTriuBatchContiguous(self, true);
   Tensor result = inplace ? self : at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "triu", [&]{
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::BFloat16, at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "triu", [&]{
     apply_triu_tril<scalar_t, true>(result, self_c, inplace, k);
   });
   if (!inplace) self.copy_(result);
@@ -148,7 +148,7 @@ Tensor& triu_cpu_out(const Tensor& self, int64_t k, Tensor &result) {
   }
   Tensor self_c;
   std::tie(std::ignore, self_c) = checkTrilTriuBatchContiguous(self, false);
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "triu", [&]{
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::BFloat16, at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "triu", [&]{
     apply_triu_tril<scalar_t, true>(result, self_c, false, k);
   });
   return result;
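
Here AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(Half, Bool, ...) becomes ..._AND3(BFloat16, Half, Bool, ...) in all four CPU entry points (the in-place and out= variants of tril and triu). A small sketch of the newly supported calls (illustrative values):

    x = torch.randn(3, 3).bfloat16().requires_grad_(True)
    y = torch.triu(x, diagonal=1)
    y.sum().backward()                 # gradient comes back in bfloat16

    z = torch.randn(3, 3).bfloat16()
    z.tril_()                          # in-place variant dispatches as well
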
aten/src/ATen/native/cpu/CrossKernel.cpp
index 55e0229..d5bbc81 100644
@@ -65,7 +65,7 @@ static void apply_cross(Tensor& result, const Tensor& a, const Tensor& b, const
 }
 
 static void cross_kernel_impl(Tensor& result, const Tensor& a, const Tensor& b, const int64_t dim) {
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX(result.scalar_type(), "cross", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, result.scalar_type(), "cross", [&]() {
     apply_cross<scalar_t>(result, a, b, dim);
   });
 }
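
For cross, the ALL_TYPES_AND_COMPLEX dispatch likewise grows a single extra type via the _AND(kBFloat16, ...) form. Illustratively:

    u = torch.randn(5, 3).bfloat16()
    v = torch.randn(5, 3).bfloat16()
    w = torch.cross(u, v, dim=1)       # batched 3-vector cross product, bfloat16
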
test/test_tensor_creation_ops.py
index 2404f02..e698768 100644
@@ -311,6 +311,21 @@ class TestTensorCreation(TestCase):
         for s, d, dtype in product(shapes, diagonals, dtypes):
             run_test(s, device, d, dtype)
 
+    @onlyCPU
+    def test_triu_tril_bfloat16(self, device):
+        op_funcs = [torch.tril, torch.triu]
+        for op_fun in op_funcs:
+            input = torch.randn(3, 3, dtype=torch.float32, device=device).bfloat16().requires_grad_(True)
+            input2 = input.detach().clone().float().requires_grad_(True)
+            out = op_fun(input)
+            out.sum().backward()
+            out2 = op_fun(input2)
+            out2.sum().backward()
+            self.assertEqual(out.dtype, torch.bfloat16)
+            self.assertEqual(input.grad.dtype, torch.bfloat16)
+            self.assertEqual(out, out2.bfloat16())
+            self.assertEqual(input.grad, input2.grad.bfloat16(), atol=0.01, rtol=0)
+
     def test_diagflat(self, device):
         dtype = torch.float32
         # Basic sanity test
@@ -1213,6 +1228,15 @@ class TestTensorCreation(TestCase):
         self.assertEqual(b.triu(2), output)
         self.assertRaises(RuntimeError, lambda: b.triu_(2))
 
+    @onlyCPU
+    def test_triu_tril_indices_bfloat16(self, device):
+        op_funcs = [torch.tril_indices, torch.triu_indices]
+        for op_fun in op_funcs:
+            out = op_fun(4, 3, 1, dtype=torch.bfloat16)
+            out2 = op_fun(4, 3, 1, dtype=torch.float)
+            self.assertEqual(out.dtype, torch.bfloat16)
+            self.assertEqual(out, out2.bfloat16())
+
     # TODO: update to work on CUDA, too
     @onlyCPU
     def test_stack(self, device):
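
The new tests check forward parity against a float32 reference (exact once the reference is cast back down, since tril/triu only mask elements) and gradient parity within atol=0.01. Outside the test harness, the same check looks roughly like:

    x16 = torch.randn(3, 3).bfloat16().requires_grad_(True)
    x32 = x16.detach().clone().float().requires_grad_(True)
    out16, out32 = torch.tril(x16), torch.tril(x32)
    out16.sum().backward()
    out32.sum().backward()
    assert torch.equal(out16, out32.bfloat16())
    assert torch.allclose(x16.grad, x32.grad.bfloat16(), atol=0.01, rtol=0)
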
torch/testing/_internal/common_methods_invocations.py
index a2b9fea..a4281bb 100644
@@ -6437,6 +6437,7 @@ op_db: List[OpInfo] = [
            skips=(DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),)),
     OpInfo('cross',
            dtypes=all_types_and_complex(),
+           dtypesIfCPU=all_types_and_complex_and(torch.bfloat16),
            dtypesIfCUDA=all_types_and_complex_and(torch.half),
            sample_inputs_func=sample_inputs_cross,
            supports_forward_ad=True,
@@ -9012,10 +9013,12 @@ op_db: List[OpInfo] = [
            supports_forward_ad=True,
            sample_inputs_func=sample_inputs_transpose_swapdims),
     OpInfo('tril',
+           dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
            dtypes=all_types_and_complex_and(torch.bool, torch.half),
            supports_forward_ad=True,
            sample_inputs_func=sample_inputs_tril_triu),
     OpInfo('triu',
+           dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
            dtypes=all_types_and_complex_and(torch.bool, torch.half),
            supports_forward_ad=True,
            sample_inputs_func=sample_inputs_tril_triu),
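
In the OpInfo database, dtypesIfCPU takes precedence over the generic dtypes list when tests run on CPU, which is how the new bfloat16 coverage is advertised without touching the CUDA entries. Conceptually (a simplified sketch, not the actual OpInfo resolution code):

    # hypothetical helper; the real logic lives in torch/testing/_internal/common_methods_invocations.py
    def supported_dtypes(op, device_type):
        if device_type == 'cpu' and op.dtypesIfCPU is not None:
            return op.dtypesIfCPU
        if device_type == 'cuda' and op.dtypesIfCUDA is not None:
            return op.dtypesIfCUDA
        return op.dtypes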