add `OpInfo` for `torch.nn.functional.dropout` (#62315)
author     Philip Meier <github.pmeier@posteo.de>
           Wed, 15 Sep 2021 14:16:29 +0000 (07:16 -0700)
committer  Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
           Wed, 15 Sep 2021 14:18:04 +0000 (07:18 -0700)
Summary:
Addresses facebookresearch/functorch#78.
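
Dropout is a randomized op, so its `OpInfo` needs a wrapper that reseeds the RNG
before every invocation; otherwise the test harness would compare outputs drawn
with different dropout masks. A minimal sketch of the idea behind the
`wrapper_set_seed` helper added in this diff (the standalone usage below is
illustrative, not part of the patch):

    import torch

    def wrapper_set_seed(op, input, *args, **kwargs):
        # Reseed so repeated calls to a randomized op draw the same mask.
        torch.manual_seed(42)
        return op(input, *args, **kwargs)

    x = torch.randn(5)
    a = wrapper_set_seed(torch.nn.functional.dropout, x, p=0.5)
    b = wrapper_set_seed(torch.nn.functional.dropout, x, p=0.5)
    assert torch.equal(a, b)  # identical masks, hence identical outputs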

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62315

Reviewed By: mruberry

Differential Revision: D30932765

Pulled By: zou3519

fbshipit-source-id: 481c67b59a966b4d640973d252b3e392d8db728e

test/test_fx_experimental.py
index d580d20..0727912 100644
@@ -1468,6 +1468,7 @@ class TestNormalizeOperators(JitTestCase):
             "igamma",
             "igammac",
             "index_put",
+            "nn.functional.dropout",
             "polygamma",
             "special.polygamma",
             "repeat",
torch/testing/_internal/common_methods_invocations.py
index 5aa8b67..f7294c2 100644
@@ -5331,6 +5331,16 @@ def sample_inputs_kthvalue(op_info, device, dtype, requires_grad, **kwargs):
 
     return [SampleInput(tensor, args=args) for tensor, args in test_cases]
 
+def sample_inputs_dropout(op_info, device, dtype, requires_grad, **kwargs):
+    input = make_tensor((S,), device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        SampleInput(input),
+        SampleInput(input, kwargs=dict(p=0.0)),
+        SampleInput(input, kwargs=dict(p=1.0)),
+        SampleInput(input, kwargs=dict(training=False)),
+    ]
+
 def sample_inputs_one_hot(op_info, device, dtype, requires_grad, **kwargs):
     def make_input(shape, *, low, high):
         return make_tensor(shape, device=device, dtype=dtype, low=low, high=high, requires_grad=requires_grad)
@@ -5735,6 +5745,14 @@ def reference_mse_loss(input, target, reduction="mean"):
         return se
 
 
+def wrapper_set_seed(op, input, *args, **kwargs):
+    """Wrapper to set seed manually for some functions like dropout
+    See: https://github.com/pytorch/pytorch/pull/62315#issuecomment-896143189 for more details.
+    """
+    torch.manual_seed(42)
+    return op(input, *args, **kwargs)
+
+
 def reference_layer_norm(inp: np.ndarray, normalized_shape: Tuple[int], weight=None, bias=None, eps=1e-5):
     feature_size = np.prod(normalized_shape)
     inp_view = inp.reshape(-1, feature_size)  # type: ignore[call-overload]
@@ -9318,6 +9336,30 @@ op_db: List[OpInfo] = [
                    dtypes=all_types_and(torch.bool),
                    safe_casts_outputs=True),
     OpInfo(
+        "nn.functional.dropout",
+        op=lambda input, *args, **kwargs:
+            wrapper_set_seed(torch.nn.functional.dropout, input, *args, **kwargs),
+        ref=_NOTHING,
+        dtypes=floating_types_and(torch.bfloat16),
+        dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+        skips=(
+            # A lambda is used as the op, so scripting it fails with:
+            # AssertionError: JIT Test does not execute any logic
+            SkipInfo('TestJit', 'test_variant_consistency_jit'),
+            # The inplace variant dispatches to the plain dropout kernel, while
+            # on CUDA the op (under a few more conditions) dispatches to
+            # _fused_dropout, producing different values; hence this skip.
+            SkipInfo('TestMathBits', 'test_neg_view', device_type='cuda'),
+            # On CUDA, the op (under a few more conditions) dispatches to
+            # _fused_dropout, which doesn't support forward AD.
+            SkipInfo('TestGradients', 'test_forward_mode_AD', device_type='cuda'),),
+        gradcheck_wrapper=wrapper_set_seed,
+        supports_forward_ad=True,
+        supports_out=False,
+        sample_inputs_func=sample_inputs_dropout,
+        inplace_variant=lambda input, *args, **kwargs:
+            wrapper_set_seed(torch.nn.functional.dropout, input, *args, **kwargs, inplace=True)),
+    OpInfo(
         "nn.functional.one_hot",
         ref=reference_one_hot,
         supports_out=False,
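
For a fixed mask, dropout is linear in its input, so gradient checks pass once
the mask is pinned by reseeding; that is what `gradcheck_wrapper=wrapper_set_seed`
accomplishes above. A standalone sketch of that mechanism (not the OpInfo harness
itself, just an illustration of why the wrapper makes gradcheck viable):

    import torch
    from torch.autograd import gradcheck

    def fn(inp):
        # Plays the role of gradcheck_wrapper=wrapper_set_seed: pin the mask.
        torch.manual_seed(42)
        return torch.nn.functional.dropout(inp, p=0.5)

    x = torch.randn(4, dtype=torch.double, requires_grad=True)
    assert gradcheck(fn, (x,))  # numeric and analytic grads agree for the fixed mask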