        self.assertEqual(outputs, m(*inputs))

    @unittest.skipIf(not RUN_CUDA, "test_dropout_cuda requires CUDA")
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @skipIfRocm
    def test_dropout_cuda(self):
-        # Dropout AD is dispatched to _fused_dropout in CUDA case,
-        # which is not included in TestJitGeneratedFunctional
+        # Dropout AD uses a CUDA-specific lowering that is not covered
+        # by TestJitGeneratedFunctional, so it needs a dedicated test
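The body of test_dropout_cuda is elided from this excerpt. For orientation only, here is a minimal sketch of the kind of property such a test can assert; the shapes, the reliance on dropout's default p = 0.5, and the helper name scripted_dropout are illustrative assumptions, not the actual test code:

import torch

# With inverted dropout at p = 0.5, the gradient of out.sum() w.r.t. x
# must equal mask / (1 - p), and the keep-mask is recoverable from the
# nonzero entries of the output (x is all ones, so x itself is nonzero).
@torch.jit.script
def scripted_dropout(x):
    # p defaults to 0.5 and training defaults to True
    return torch.nn.functional.dropout(x)

x = torch.ones(4, 4, device="cuda", requires_grad=True)  # needs a CUDA device
out = scripted_dropout(x)
grad, = torch.autograd.grad(out.sum(), x)
mask = (out != 0).type_as(out)
assert torch.allclose(grad, mask * 2.0)  # 1 / (1 - 0.5) == 2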
def AD_fused_dropout_backward(grad,
                              mask,
                              p1m: float):
    p1r = 1. / p1m
-    if grad.requires_grad:
-        grad_input = grad * (mask.type_as(grad) * p1r)
-    else:
-        grad_input = torch._masked_scale(grad, mask, p1r)
+    grad_input = grad * (mask.type_as(grad) * p1r)
    return grad_input
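The surviving branch is just the chain rule for inverted dropout: scale the incoming gradient by the saved keep-mask and by 1 / p1m. A minimal standalone sketch (plain PyTorch, outside TorchScript) confirming that autograd through the forward reproduces this hand-written backward; all names here are illustrative:

import torch

def ad_dropout_backward(grad, mask, p1m):
    # Same formula as the retained branch above.
    return grad * (mask.type_as(grad) * (1. / p1m))

p1m = 0.5
x = torch.randn(4, 4, requires_grad=True)
mask = torch.rand_like(x) < p1m          # keep-mask, as in the CUDA path below
out = mask.type_as(x) * x * (1. / p1m)   # inverted-dropout forward
out.backward(torch.ones_like(out))
assert torch.allclose(x.grad, ad_dropout_backward(torch.ones_like(out), mask, p1m))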
def dropout(input,
            p: float,
            train: bool):
    use_cuda = input.is_cuda
-    # CUDA has a fused dropout implementation
+    # The lowering is specialized for CUDA because the CUDA fuser can
+    # efficiently fuse these pointwise operations; on the CPU backend,
+    # where fusion is disabled, a different lowering that is more
+    # efficient in the absence of fusion is used.
    p1m = 1. - p
    if use_cuda:
-        res, mask = torch._fused_dropout(input, p1m)
+        mask = torch.rand_like(input) < p1m
+        res = mask.type_as(input) * input * (1. / p1m)
    else:
        mask = torch.empty_like(input)
        mask.bernoulli_(p1m)
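The CPU branch is truncated above right after mask.bernoulli_(p1m). For reference, the whole lowering can be read as an ordinary Python function; this is a sketch, and the final CPU scaling line plus the (res, mask) return are assumptions based on the standard inverted-dropout formula rather than lines shown in this diff:

import torch

def dropout_lowering(input, p):
    # Sketch of the lowering above; the train flag's handling is not
    # visible in the excerpt, so it is omitted here.
    p1m = 1. - p
    if input.is_cuda:
        # CUDA path: rand_like, the comparison, and the multiplies are
        # all pointwise, so the CUDA fuser can combine them into one kernel.
        mask = torch.rand_like(input) < p1m
        res = mask.type_as(input) * input * (1. / p1m)
    else:
        # CPU path: sample the keep-mask in place with bernoulli_.
        mask = torch.empty_like(input)
        mask.bernoulli_(p1m)
        res = mask * input / p1m  # assumed continuation, truncated above
    return res, mask

Both branches save the mask, so the backward helper above can rescale the incoming gradient without re-sampling.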