TEST_LARGE_TENSOR = TEST_CUDA
TEST_MEDIUM_TENSOR = TEST_CUDA
TEST_CUDNN = TEST_CUDA
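+# TEST_BF16 defaults to False; it is set from torch.cuda.is_bf16_supported() below once the CUDA context exists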
+TEST_BF16 = False
if TEST_CUDA:
    torch.ones(1).cuda()  # initialize cuda context
    TEST_CUDNN = TEST_CUDA and (TEST_WITH_ROCM or
                                torch.backends.cudnn.is_acceptable(torch.tensor(1., device=torch.device('cuda:0'))))
    TEST_LARGE_TENSOR = torch.cuda.get_device_properties(0).total_memory >= 12e9
    TEST_MEDIUM_TENSOR = torch.cuda.get_device_properties(0).total_memory >= 6e9
+    TEST_BF16 = torch.cuda.is_bf16_supported()
+
types = [
    torch.FloatTensor,
        if add_kwargs is None:
            add_kwargs = {}
-
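+        # autocast should run the op in bfloat16 only when that is the dtype under test; otherwise keep the float16 default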
+        fast_dtype = torch.bfloat16 if run_as_type == torch.bfloat16 else torch.float16
        self.assertFalse(torch.is_autocast_enabled())
-        with torch.autocast('cuda'):
+        with torch.autocast('cuda', dtype=fast_dtype):
            self.assertTrue(torch.is_autocast_enabled())
            out_type = out_type if out_type is not None else run_as_type
                    self._run_autocast_outofplace(op, args, torch.float16)
    @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
+    def test_autocast_torch_bf16(self):
+        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
+            for op_with_args in self.autocast_lists.torch_fp16:
+                skip_test = False
+                op, args = op_with_args[0], op_with_args[1]
+                if len(op_with_args) == 3:
+                    skip_test = op_with_args[2]  # TEST_WITH_ROCM
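+                # ops whose names point at cuDNN, fused, or legacy THNN kernels have no bfloat16 autocast path,
+                # so the test expects them to raise a RuntimeError instead of running in bfloat16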
+                should_error_from_not_implemented = 'cudnn' in op or 'prelu' in op or 'thnn' in op \
+                    or 'fused' in op or 'gru' in op or op == '_thnn_fused_lstm_cell' or op == 'lstm_cell'
+                if not skip_test:
+                    if should_error_from_not_implemented:
+                        with self.assertRaises(RuntimeError, msg=str(op) + ' should not be supported for bfloat16!'):
+                            self._run_autocast_outofplace(op, args, torch.bfloat16)
+                    else:
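+                        # GPUs without native bfloat16 support should reject the run rather than silently fall back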
+                        if torch.cuda.is_bf16_supported():
+                            self._run_autocast_outofplace(op, args, torch.bfloat16)
+                        else:
+                            with self.assertRaisesRegex(RuntimeError, 'Device does not support bfloat16'):
+                                self._run_autocast_outofplace(op, args, torch.bfloat16)
+
+    @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
    def test_autocast_torch_fp32(self):
        for op_with_args in self.autocast_lists.torch_fp32:
            op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
            for op, args in self.autocast_lists.nn_fp16:
                self._run_autocast_outofplace(op, args, torch.float16, module=torch._C._nn)
+
+    @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
+    def test_autocast_nn_bf16(self):
+        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
+            for op, args in self.autocast_lists.nn_fp16:
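+                # reuse the fp16 nn op list: each op should run under bfloat16 autocast when the device supports it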
+                if torch.cuda.is_bf16_supported():
+                    self._run_autocast_outofplace(op, args, torch.bfloat16, module=torch._C._nn)
+                else:
+                    with self.assertRaisesRegex(RuntimeError, 'Device does not support bfloat16'):
+                        self._run_autocast_outofplace(op, args, torch.bfloat16, module=torch._C._nn)
+
    @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
    def test_autocast_nn_fp32(self):
        for op, args in self.autocast_lists.nn_fp32: