add BFloat16 support for fold and unfold on CPU (#62880)
author CaoE <e.cao@intel.com>
Tue, 31 Aug 2021 02:12:23 +0000 (19:12 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 31 Aug 2021 02:14:10 +0000 (19:14 -0700)
Summary:
Add BFloat16 support for the fold (col2im) and unfold (im2col) operators on CPU.
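
A rough usage sketch of what this enables (shapes chosen for illustration, mirroring the kernel_size used in the new test; not taken verbatim from the PR):

    import torch
    import torch.nn.functional as F

    # unfold (im2col) on a bfloat16 CPU tensor
    inp = torch.randn(1, 3, 4, 5, dtype=torch.bfloat16)
    cols = F.unfold(inp, kernel_size=(2, 2))                    # (1, 12, 12), bfloat16

    # fold (col2im) reassembles the patches, also in bfloat16
    out = F.fold(cols, output_size=(4, 5), kernel_size=(2, 2))  # (1, 3, 4, 5), bfloat16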

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62880

Reviewed By: iramazanli

Differential Revision: D30576387

Pulled By: zou3519

fbshipit-source-id: c48f6e56702bfea34448db1b3a1634c49c5d8ec8

aten/src/ATen/native/Col2Im.cpp
aten/src/ATen/native/Im2Col.cpp
test/test_nn.py
torch/testing/_internal/common_methods_invocations.py

diff --git a/aten/src/ATen/native/Col2Im.cpp b/aten/src/ATen/native/Col2Im.cpp
index e1cc31d..7e11b1b 100644
@@ -136,7 +136,7 @@ static void col2im_out_cpu_template(
   output.resize_({batch_size, n_output_plane, output_height, output_width});
   output.zero_();
 
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kHalf,
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf,
       input.scalar_type(), "col2im_out_cpu", [&] {
         Tensor input_n = Tensor();
         Tensor output_n = Tensor();
diff --git a/aten/src/ATen/native/Im2Col.cpp b/aten/src/ATen/native/Im2Col.cpp
index 0970095..586b961 100644
@@ -86,7 +86,7 @@ static void im2col_out_cpu_template(
   output.resize_({batch_size, n_output_plane, output_length});
   output.zero_();
 
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kHalf,
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf,
       input.scalar_type(), "im2col_out_cpu", [&] {
         Tensor input_n;
         Tensor output_n;
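
The enabling change in both CPU kernels is the dispatch macro: moving from AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kHalf, ...) to AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, ...) instantiates the per-type lambda for BFloat16 in addition to the types already covered. Previously a bfloat16 CPU input failed at dispatch time; a hedged sketch of the before/after behavior (error text approximate, not copied from a log):

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 12, 12, dtype=torch.bfloat16)  # CPU tensor
    # Before this PR the CPU kernel had no BFloat16 instantiation, roughly:
    #   F.fold(x, output_size=(4, 5), kernel_size=(2, 2))
    #   RuntimeError: "col2im_out_cpu" not implemented for 'BFloat16'
    # After this PR the same call succeeds and returns a bfloat16 tensor.
    y = F.fold(x, output_size=(4, 5), kernel_size=(2, 2))
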
diff --git a/test/test_nn.py b/test/test_nn.py
index c6d0e78..96321ba 100644
@@ -17438,14 +17438,30 @@ class TestNNDeviceType(NNTestCase):
             m(input)
 
     def test_fold(self, device):
+        def test_dtype(fn, input, dtype):
+            input = input.detach().clone().to(dtype=dtype).requires_grad_(True)
+            input2 = input.detach().clone().float().requires_grad_(True)
+            out = fn(input)
+            out.sum().backward()
+            out2 = fn(input2)
+            out2.sum().backward()
+            self.assertEqual(out.dtype, dtype)
+            self.assertEqual(input.grad.dtype, dtype)
+            self.assertEqual(out, out2.to(dtype=dtype), atol=0.05, rtol=0)
+            self.assertEqual(input.grad, input2.grad.to(dtype=dtype))
+
         def func(x):
             return F.fold(x, output_size=(4, 5), kernel_size=(2, 2))
+
         seeds = (44, 83, 71, 25, 999)
         for sd in seeds:
             torch.manual_seed(sd)
             x = torch.randn(1, 12, 12, device=device, requires_grad=True)
             gradcheck(func, [x])
             gradgradcheck(func, [x])
+            if device == 'cpu':
+                test_dtype(func, x, torch.bfloat16)
+
 
     def test_logsigmoid_out(self, device):
         # this isn't actually documented, but was broken previously:
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 04db52b..e7d9380 100644
@@ -7267,6 +7267,7 @@ op_db: List[OpInfo] = [
     OpInfo('nn.functional.unfold',
            aten_name='im2col',
            dtypes=floating_types_and(torch.half),
+           dtypesIfCPU=floating_types_and(torch.half, torch.bfloat16),
            sample_inputs_func=sample_inputs_nn_unfold,
            skips=(
                # JIT alias info internal asserts here