[opinfo] nn.functional.pad (#62814)
authorkshitij12345 <kshitijkalambarkar@gmail.com>
Mon, 16 Aug 2021 20:26:46 +0000 (13:26 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 16 Aug 2021 20:29:34 +0000 (13:29 -0700)
Summary:
Reference: https://github.com/facebookresearch/functorch/issues/78

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62814

Reviewed By: VitalyFedyunin

Differential Revision: D30307492

Pulled By: zou3519

fbshipit-source-id: 4f6062eb4a3c91ed1795df1f82846afa0abafcdc

test/test_nn.py
torch/testing/_internal/common_methods_invocations.py

index b6eb297..ccf6f6e 100644 (file)
@@ -13066,24 +13066,6 @@ class TestNNDeviceType(NNTestCase):
     @onlyOnCPUAndCUDA
     @dtypes(torch.float64, torch.complex128)
     def test_pad(self, device, dtype):
-        inputs = torch.randn(1, 3, 4, 4, device=device, dtype=dtype, requires_grad=True)
-        _assertGradAndGradgradChecks(self, lambda x: F.pad(x, (1, 1, 1, 1)), (inputs,),
-                                     nondet_tol=GRADCHECK_NONDET_TOL)
-        _assertGradAndGradgradChecks(self, lambda x: F.pad(x, (-1, 1, -2, 1)), (inputs,),
-                                     nondet_tol=GRADCHECK_NONDET_TOL)
-        _assertGradAndGradgradChecks(self, lambda x: F.pad(x, (-1, 1, -2, 1), value=2), (inputs,),
-                                     nondet_tol=GRADCHECK_NONDET_TOL)
-        self.assertTrue(gradcheck(lambda x: F.pad(x, (-1, 1, -2, 1), mode='replicate'), (inputs,),
-                                  nondet_tol=GRADCHECK_NONDET_TOL))
-        self.assertTrue(gradcheck(lambda x: F.pad(x, (-1, 1, -2, 1), mode='reflect'), (inputs,),
-                                  nondet_tol=GRADCHECK_NONDET_TOL))
-        self.assertTrue(gradcheck(lambda x: F.pad(x, (-1, 1, -2, 1), mode='circular'), (inputs,),
-                                  nondet_tol=GRADCHECK_NONDET_TOL))
-
-        inputs = torch.randn(1, 2, 3, 4, 4, device=device, dtype=dtype, requires_grad=True)
-        self.assertTrue(gradcheck(lambda x: F.pad(x, (1, 1, 1, 1, 1, 1), mode='replicate'), (inputs,),
-                                  nondet_tol=GRADCHECK_NONDET_TOL))
-
         # Assert assertion errors are raised for invalid circular padding values
         inputs = torch.randn(1, 1, 4, device=device, dtype=dtype, requires_grad=True)
         # Should raise error when trying to wrap around more than once
index 7e923be..27349a1 100644 (file)
@@ -2765,6 +2765,90 @@ def sample_inputs_squeeze(op_info, device, dtype, requires_grad, **kwargs):
     return list(generator())
 
 
+def sample_inputs_nn_pad(op_info, device, dtype, requires_grad, mode, **kwargs):
+    assert mode in ('constant', 'reflect', 'replicate', 'circular')
+    if mode in ['reflect', 'replicate']:
+        cases: tuple = (  # ignore
+            ((1, 3), (1, 2)),
+            ((1, 3), (0, 1)),
+            ((0, 3, 3), (1, 2)),
+            ((0, 3, 3), (0, 1)),
+            ((1, 3, 3), (1, 2)),
+            ((1, 3, 3), (0, 1)),
+            ((1, 3, 3), (0, 2, 0, 1)),
+            ((0, 3, 3, 3), (0, 2, 0, 1)),
+            ((3, 3, 5, 5), (0, 2, 0, 1)),
+            ((3, 3, 5, 5), (1, 1, 1, 1, 1, 1)),
+            ((1, 3, 3, 3, 3), (1, 1, 1, 1, 1, 1)),
+            ((1, 3, 4, 4), (-1, 1, -2, 1)),
+        )
+    elif mode == 'constant':
+        cases = (
+            ((1, 3), (1, 2)),
+            ((1, 3), (0, 1)),
+            ((1, 3), (0, 2, 0, 1)),
+            ((0, 3, 3), (1, 2)),
+            ((0, 3, 3), (0, 1)),
+            ((0, 3, 3), (0, 2, 0, 1)),
+            ((0, 3, 3), (1, 1, 1, 1, 1, 1)),
+            ((1, 3, 3), (1, 2)),
+            ((1, 3, 3), (0, 1)),
+            ((1, 3, 3), (0, 2, 0, 1)),
+            ((1, 3, 3), (1, 1, 1, 1, 1, 1)),
+            ((0, 3, 3, 3), (1, 2)),
+            ((0, 3, 3, 3), (0, 1)),
+            ((0, 3, 3, 3), (0, 2, 0, 1)),
+            ((0, 3, 3, 3), (1, 1, 1, 1, 1, 1)),
+            ((3, 3, 5, 5), (1, 2)),
+            ((3, 3, 5, 5), (0, 1)),
+            ((3, 3, 5, 5), (0, 2, 0, 1)),
+            ((3, 3, 5, 5), (1, 1, 1, 1, 1, 1)),
+            ((1, 3, 3, 3, 3), (1, 2)),
+            ((1, 3, 3, 3, 3), (0, 1)),
+            ((1, 3, 3, 3, 3), (0, 2, 0, 1)),
+            ((1, 3, 3, 3, 3), (1, 1, 1, 1, 1, 1)),
+            ((1, 3, 4, 4), (-1, 1, -2, 1)),
+        )
+    else:  # mode == 'circular'
+        if dtype == torch.bool:
+            # test_dtypes fails on ASAN with for the case ab
+            # runtime error: load of value 190, which is not a valid value for type 'bool'
+            # Reference: https://github.com/pytorch/pytorch/pull/62814#issuecomment-894156562
+            # Reference Issue: https://github.com/pytorch/pytorch/issues/63034
+            cases = (
+                ((2, 3, 3), (1, 2)),
+                ((1, 3, 3), (1, 2)),
+            )
+        else:
+            cases = (
+                ((0, 3, 3), (1, 2)),
+                ((0, 3, 3), (0, 1)),
+                ((1, 3, 3), (1, 2)),
+                ((1, 3, 3), (0, 1)),
+                ((0, 3, 3, 3), (0, 2, 0, 1)),
+                ((3, 3, 5, 5), (0, 2, 0, 1)),
+                ((1, 3, 3, 3, 3), (1, 1, 1, 1, 1, 1)),
+                ((1, 3, 4, 4), (-1, 1, -2, 1)),
+            )
+
+    make_inp = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    def generator():
+        if mode == 'constant':
+            # Default args
+            yield SampleInput(make_inp((1, 3, 3)), args=((2, 2),))
+
+        if mode in ['reflect', 'replicate', 'circular']:
+            for shape, pad in cases:
+                yield SampleInput(make_inp(shape), args=(pad, mode))
+        else:  # mode == 'constant'
+            for pad_value in (1., 2.):
+                for shape, pad in cases:
+                    yield SampleInput(make_inp(shape), args=(pad, mode, pad_value))
+
+    return list(generator())
+
+
 # TODO: reconcile with torch.linalg.det and torch.linalg.slogdet
 # Creates matrices with a positive nonzero determinant
 def sample_inputs_logdet(op_info, device, dtype, requires_grad, **kwargs):
@@ -6801,6 +6885,52 @@ op_db: List[OpInfo] = [
                SkipInfo('TestJit', 'test_variant_consistency_jit'),
            ),
            supports_out=False,),
+    OpInfo('nn.functional.pad',
+           variant_test_name='constant',
+           aten_name='constant_pad_nd',
+           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half),
+           sample_inputs_func=partial(sample_inputs_nn_pad, mode='constant'),
+           supports_out=False),
+    OpInfo('nn.functional.pad',
+           variant_test_name='reflect',
+           dtypes=floating_and_complex_types(),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.half),
+           sample_inputs_func=partial(sample_inputs_nn_pad, mode='reflect'),
+           skips=(
+               # op name not found in JIT graph
+               # There are multiple aten ops, namely reflection_pad_{1,2,3}d
+               # so we can't use aten_name argument in opinfo
+               # RuntimeError: aliasOp != torch::jit::getOperatorAliasMap().end()
+               SkipInfo('TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
+           ),
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+           supports_out=False),
+    OpInfo('nn.functional.pad',
+           variant_test_name='replicate',
+           dtypes=floating_and_complex_types(),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.half),
+           sample_inputs_func=partial(sample_inputs_nn_pad, mode='replicate'),
+           skips=(
+               # op name not found in JIT graph
+               # There are multiple aten ops, namely replication_pad_{1,2,3}d
+               # so we can't use aten_name argument in opinfo
+               # RuntimeError: aliasOp != torch::jit::getOperatorAliasMap().end()
+               SkipInfo('TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
+           ),
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+           supports_out=False),
+    OpInfo('nn.functional.pad',
+           variant_test_name='circular',
+           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half),
+           sample_inputs_func=partial(sample_inputs_nn_pad, mode='circular'),
+           supports_forward_ad=True,
+           check_batched_grad=False,
+           skips=(
+               # Doesn't have a corresponding aten operator.
+               # RuntimeError: aliasOp != torch::jit::getOperatorAliasMap().end()
+               SkipInfo('TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
+           ),
+           supports_out=False),
     OpInfo('nn.functional.hardswish',
            aten_name="hardswish",
            supports_autograd=True,