fix syntax error in bfloat16 PR (#64122)
author    Rishi Puri <puririshi98@berkeley.edu>
          Tue, 31 Aug 2021 20:47:29 +0000 (13:47 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
          Tue, 31 Aug 2021 21:33:12 +0000 (14:33 -0700)
Summary:
Fixes a syntax error from the prior bfloat16 PR (cc ngimel).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64122

Reviewed By: H-Huang

Differential Revision: D30643596

Pulled By: ngimel

fbshipit-source-id: 0a2d5a40fb6dc7339cd03112e57ef0e1bf8a000e
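
For context, the one-line change in torch/cuda/__init__.py (second diff below) swaps the
separator used to parse torch.version.cuda. A minimal standalone sketch of why the old
split failed, using an assumed version string rather than anything taken from the patch:

    # Assumed example of torch.version.cuda; not part of the patch.
    cu_vers = "11.3"

    # The old code split on ',', so the whole string survived and int() raised.
    try:
        int(cu_vers.split(',')[0])
    except ValueError as err:
        print("old split(','):", err)   # invalid literal for int() with base 10: '11.3'

    # The fixed code splits on '.', yielding the CUDA major version.
    print("fixed split('.'):", int(cu_vers.split('.')[0]) >= 11)   # True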

test/test_cuda.py
torch/cuda/__init__.py

diff --git a/test/test_cuda.py b/test/test_cuda.py
index 70f5a6e..6f742ec 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -46,12 +46,15 @@ if not TEST_CUDA:
 TEST_LARGE_TENSOR = TEST_CUDA
 TEST_MEDIUM_TENSOR = TEST_CUDA
 TEST_CUDNN = TEST_CUDA
+TEST_BF16 = False
 if TEST_CUDA:
     torch.ones(1).cuda()  # initialize cuda context
     TEST_CUDNN = TEST_CUDA and (TEST_WITH_ROCM or
                                 torch.backends.cudnn.is_acceptable(torch.tensor(1., device=torch.device('cuda:0'))))
     TEST_LARGE_TENSOR = torch.cuda.get_device_properties(0).total_memory >= 12e9
     TEST_MEDIUM_TENSOR = torch.cuda.get_device_properties(0).total_memory >= 6e9
+    TEST_BF16 = torch.cuda.is_bf16_supported()
+
 
 types = [
     torch.FloatTensor,
@@ -2707,9 +2710,9 @@ torch.cuda.synchronize()
 
         if add_kwargs is None:
             add_kwargs = {}
-
+        fast_dtype = torch.bfloat16 if run_as_type == torch.bfloat16 else torch.float16
         self.assertFalse(torch.is_autocast_enabled())
-        with torch.autocast('cuda', ):
+        with torch.autocast('cuda', dtype=fast_dtype):
             self.assertTrue(torch.is_autocast_enabled())
 
             out_type = out_type if out_type is not None else run_as_type
@@ -2785,6 +2788,27 @@ torch.cuda.synchronize()
                     self._run_autocast_outofplace(op, args, torch.float16)
 
     @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
+    def test_autocast_torch_bf16(self):
+        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
+            for op_with_args in self.autocast_lists.torch_fp16:
+                skip_test = False
+                op, args = op_with_args[0], op_with_args[1]
+                if len(op_with_args) == 3:
+                    skip_test = op_with_args[2]  # TEST_WITH_ROCM
+                should_error_from_not_implemented = 'cudnn' in op or 'prelu' in op or 'thnn' in op \
+                    or 'fused' in op or 'gru' in op or op == '_thnn_fused_lstm_cell' or op == 'lstm_cell'
+                if not skip_test:
+                    if should_error_from_not_implemented:
+                        with self.assertRaises(RuntimeError, msg=str(op) + ' should not be supported for bfloat16!'):
+                            self._run_autocast_outofplace(op, args, torch.bfloat16)
+                    else:
+                        if torch.cuda.is_bf16_supported():
+                            self._run_autocast_outofplace(op, args, torch.bfloat16)
+                        else:
+                            with self.assertRaisesRegex(RuntimeError, 'Device does not support bfloat16'):
+                                self._run_autocast_outofplace(op, args, torch.bfloat16)
+
+    @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
     def test_autocast_torch_fp32(self):
         for op_with_args in self.autocast_lists.torch_fp32:
             op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
@@ -2806,6 +2830,18 @@ torch.cuda.synchronize()
             for op, args in self.autocast_lists.nn_fp16:
                 self._run_autocast_outofplace(op, args, torch.float16, module=torch._C._nn)
 
+
+
+    @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
+    def test_autocast_nn_bf16(self):
+        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
+            for op, args in self.autocast_lists.nn_fp16:
+                if torch.cuda.is_bf16_supported():
+                    self._run_autocast_outofplace(op, args, torch.bfloat16, module=torch._C._nn)
+                else:
+                    with self.assertRaisesRegex(RuntimeError, 'Device does not support bfloat16'):
+                        self._run_autocast_outofplace(op, args, torch.bfloat16, module=torch._C._nn)
+
     @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
     def test_autocast_nn_fp32(self):
         for op, args in self.autocast_lists.nn_fp32:
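
The test changes above drive torch.autocast('cuda', ...) with an explicit low-precision
dtype. A standalone sketch of that pattern, assuming a bfloat16-capable CUDA device (the
tensors are made up for illustration):

    import torch

    x = torch.randn(8, 16, device='cuda')
    w = torch.randn(16, 4, device='cuda')

    with torch.autocast('cuda', dtype=torch.bfloat16):
        assert torch.is_autocast_enabled()
        y = torch.mm(x, w)   # mm is on autocast's low-precision op list
    assert y.dtype == torch.bfloat16
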
diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py
index 924782d..80d9e10 100644
--- a/torch/cuda/__init__.py
+++ b/torch/cuda/__init__.py
@@ -83,7 +83,8 @@ def is_bf16_supported():
     r"""Returns a bool indicating if the current CUDA device supports dtype bfloat16"""
     cu_vers = torch.version.cuda
     if cu_vers is not None:
-        cuda_maj_decide = int(cu_vers.split(',')[0]) >= 11
+        cuda_maj_decide = int(cu_vers.split('.')[0]) >= 11
+
     else:
         cuda_maj_decide = False
     return torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8 and cuda_maj_decide
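
With the '.'-based split, is_bf16_supported() can parse a version string such as "11.3"
and returns True only on CUDA >= 11 with compute capability >= 8.0. A sketch of using it
to pick an autocast dtype, mirroring the fast_dtype fallback added in test_cuda.py
(illustrative only, not part of the patch):

    import torch

    if torch.cuda.is_available():
        fast_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        with torch.autocast('cuda', dtype=fast_dtype):
            out = torch.mm(torch.randn(4, 4, device='cuda'),
                           torch.randn(4, 4, device='cuda'))
        print(out.dtype)   # torch.bfloat16 on supported devices, torch.float16 otherwise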