OpInfo: `nn.functional.conv_transpose2d` (#62882)
author Kushashwa Ravi Shrimali <kushashwaravishrimali@gmail.com>
Sat, 14 Aug 2021 00:10:07 +0000 (17:10 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Sat, 14 Aug 2021 00:11:23 +0000 (17:11 -0700)
Summary:
Adds an OpInfo entry for `nn.functional.conv_transpose2d`. See https://github.com/facebookresearch/functorch/issues/78 and https://github.com/pytorch/pytorch/issues/54261 for context.

cc: mruberry zou3519 Chillee

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62882

Reviewed By: bdhirsh

Differential Revision: D30280804

Pulled By: zou3519

fbshipit-source-id: e40cdf43e98c1f11e45df6b8bc13110b4d29c45f

torch/testing/_internal/common_methods_invocations.py

index 36e9b9a..14fd5d1 100644
@@ -2261,6 +2261,33 @@ def sample_inputs_normalize(self, device, dtype, requires_grad, **kwargs):
 
     return list(generator())
 
+def sample_inputs_conv_transpose2d(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # Each case is ordered as: input shape, weight shape, bias shape,
+    # and a dict of kwargs (stride, padding, output_padding, groups, dilation)
+    cases: Tuple[Tuple[int], Tuple[int], Tuple[int], dict] = (  # type: ignore[assignment]
+        ((1, 3, 4, 4), (3, 3, 3, 3), (3,),
+         {'stride': (2, 2), 'padding': 2, 'output_padding': (1, 1), 'groups': 1}),
+        ((2, 2, 4, 4), (2, 2, 4, 5), (4,),
+         {'stride': (3, 2), 'padding': (1, 2), 'output_padding': (2, 3), 'groups': 2, 'dilation': (4, 4)}),
+        ((1, 1, 4, 5), (1, 1, 4, 3), (1,),
+         {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1, 'dilation': (2, 3)}),
+        ((1, 1, 4, 3), (1, 2, 3, 4), None,
+         {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1}),
+        ((1, 4, 5, 5), (4, 8, 3, 3), None,
+         {})
+    )
+
+    def generator():
+        for input_shape, weight, bias, kwargs in cases:
+            yield SampleInput(make_arg(input_shape), args=(
+                make_arg(weight),
+                make_arg(bias) if bias is not None else bias
+            ), kwargs=kwargs)
+
+    return list(generator())
+
 def sample_inputs_hardswish(self, device, dtype, requires_grad):
     N = 5
     # make sure we are testing -3 -> 3 range. default is -10 -> 10 so maybe unnecessary ?
@@ -6745,6 +6772,19 @@ op_db: List[OpInfo] = [
            dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
            sample_inputs_func=sample_inputs_nn_activation_relu,
            supports_out=False),
+    OpInfo('nn.functional.conv_transpose2d',
+           aten_name='conv_transpose2d',
+           aliases=('conv_transpose2d',),
+           dtypesIfCPU=floating_types_and(torch.int64),
+           dtypesIfCUDA=floating_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []),
+           sample_inputs_func=sample_inputs_conv_transpose2d,
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+           skips=(
+               # RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at
+               # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":104, please report a bug to PyTorch.
+               SkipInfo('TestJit', 'test_variant_consistency_jit'),
+           ),
+           supports_out=False,),
     OpInfo('nn.functional.hardswish',
            aten_name="hardswish",
            supports_autograd=True,
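
For reference, a minimal sketch (not part of the PR) of how the second sample case above maps onto torch.nn.functional.conv_transpose2d. The expected spatial size follows the standard transposed-convolution formula, out = (in - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1:

    # Minimal sketch using the second sample case; shapes and kwargs
    # come from sample_inputs_conv_transpose2d above.
    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 2, 4, 4)       # input:  (N, C_in, H_in, W_in)
    weight = torch.randn(2, 2, 4, 5)  # weight: (C_in, C_out // groups, kH, kW)
    bias = torch.randn(4)             # C_out = (C_out // groups) * groups = 2 * 2

    out = F.conv_transpose2d(x, weight, bias, stride=(3, 2), padding=(1, 2),
                             output_padding=(2, 3), groups=2, dilation=(4, 4))

    # H_out = (4 - 1) * 3 - 2 * 1 + 4 * (4 - 1) + 2 + 1 = 22
    # W_out = (4 - 1) * 2 - 2 * 2 + 4 * (5 - 1) + 3 + 1 = 22
    print(out.shape)  # torch.Size([2, 4, 22, 22])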