OpInfo for nn.functional.interpolate (#61956)
author     Richard Zou <zou3519@gmail.com>
           Mon, 30 Aug 2021 22:58:50 +0000 (15:58 -0700)
committer  Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
           Mon, 30 Aug 2021 23:00:43 +0000 (16:00 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61956

Each interpolation mode goes through a different implementation, so the modes are
listed as separate OpInfo variants.
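
For reference, the positional arguments exercised by the new SampleInputs follow the
signature F.interpolate(input, size, scale_factor, mode, align_corners). A minimal
sketch of the kinds of calls being covered (illustrative only, not part of the patch):

    import torch
    import torch.nn.functional as F

    x = torch.rand(2, 3, 4, 4)  # (N, C, H, W) input, as the sampler builds for rank-2 modes
    # Resize by explicit output size; align_corners only applies to the *linear/bicubic modes.
    out_sized = F.interpolate(x, (3, 3), None, 'bilinear', False)
    # Resize by scale_factor; 'nearest' requires align_corners=None.
    out_scaled = F.interpolate(x, None, 1.7, 'nearest', None)
    print(out_sized.shape)   # torch.Size([2, 3, 3, 3])
    print(out_scaled.shape)  # torch.Size([2, 3, 6, 6])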

Test Plan: - run tests

Reviewed By: malfet

Differential Revision: D30013751

Pulled By: zou3519

fbshipit-source-id: 4253b40b55667d7486ef2d98b441c13d807ab292

torch/testing/_internal/common_methods_invocations.py

index 52e8d73..04db52b 100644
@@ -2535,6 +2535,48 @@ def sample_inputs_hardswish(self, device, dtype, requires_grad):
                requires_grad=requires_grad, low=-5, high=5)) for _ in range(1, N)]
     return tensors
 
+def sample_inputs_interpolate(mode, self, device, dtype, requires_grad):
+    N, C = 2, 3
+    D = 4
+    S = 3
+    L = 5
+
+    align_corners_options: Tuple[Any, ...] = (None,)
+    if mode in ('linear', 'bilinear', 'bicubic', 'trilinear'):
+        align_corners_options = (True, False, None)
+    ranks_for_mode = {
+        'nearest': [1, 2, 3],
+        'linear': [1],
+        'bilinear': [2],
+        'bicubic': [2],
+        'trilinear': [3],
+        'area': [1, 2, 3]
+    }
+
+    def shape(size, rank, with_batch_channel=True):
+        if with_batch_channel:
+            return tuple([N, C] + ([size] * rank))
+        return tuple([size] * rank)
+
+    make_arg = partial(make_tensor, device=device, dtype=dtype,
+                       requires_grad=requires_grad, low=-1, high=1)
+
+    sample_inputs = []
+    for align_corners in align_corners_options:
+        for rank in ranks_for_mode[mode]:
+            sample_inputs.extend([
+                SampleInput(make_arg(shape(D, rank)),
+                            args=(shape(S, rank, False), None, mode, align_corners)),
+                SampleInput(make_arg(shape(D, rank)),
+                            args=(shape(L, rank, False), None, mode, align_corners)),
+                SampleInput(make_arg(shape(D, rank)),
+                            args=(None, 1.7, mode, align_corners)),
+                SampleInput(make_arg(shape(D, rank)),
+                            args=(None, 0.6, mode, align_corners)),
+            ])
+
+    return sample_inputs
+
 def sample_inputs_gelu(self, device, dtype, requires_grad):
     N = 5
     tensors = [SampleInput(make_tensor((N * 2, N * 2), device=device, dtype=dtype,
@@ -7227,6 +7269,78 @@ op_db: List[OpInfo] = [
            dtypes=floating_types_and(torch.half),
            sample_inputs_func=sample_inputs_nn_unfold,
            skips=(
+               # JIT alias info internal asserts here
+               SkipInfo('TestJit', 'test_variant_consistency_jit'),
+           ),
+           supports_out=False),
+    OpInfo('nn.functional.interpolate',
+           aten_name="interpolate",
+           variant_test_name='nearest',
+           supports_autograd=True,
+           dtypesIfCPU=floating_types_and(torch.uint8),
+           dtypesIfCUDA=floating_types_and(torch.half, torch.uint8),
+           sample_inputs_func=partial(sample_inputs_interpolate, 'nearest'),
+           skips=(
+               # JIT alias info internal asserts here
+               SkipInfo('TestJit', 'test_variant_consistency_jit'),
+           ),
+           supports_out=False),
+    OpInfo('nn.functional.interpolate',
+           aten_name="interpolate",
+           variant_test_name='linear',
+           supports_autograd=True,
+           dtypesIfCUDA=floating_types_and(torch.half),
+           sample_inputs_func=partial(sample_inputs_interpolate, 'linear'),
+           skips=(
+               # JIT alias info internal asserts here
+               SkipInfo('TestJit', 'test_variant_consistency_jit'),
+           ),
+           supports_out=False),
+    OpInfo('nn.functional.interpolate',
+           aten_name="interpolate",
+           variant_test_name='bilinear',
+           supports_autograd=True,
+           dtypesIfCUDA=floating_types_and(torch.half),
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+           sample_inputs_func=partial(sample_inputs_interpolate, 'bilinear'),
+           skips=(
+               # JIT alias info internal asserts here
+               SkipInfo('TestJit', 'test_variant_consistency_jit'),
+           ),
+           supports_out=False),
+    OpInfo('nn.functional.interpolate',
+           aten_name="interpolate",
+           variant_test_name='bicubic',
+           supports_autograd=True,
+           dtypesIfCUDA=floating_types_and(torch.half),
+           sample_inputs_func=partial(sample_inputs_interpolate, 'bicubic'),
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+           skips=(
+               # JIT alias info internal asserts here
+               SkipInfo('TestJit', 'test_variant_consistency_jit'),
+           ),
+           supports_out=False),
+    OpInfo('nn.functional.interpolate',
+           aten_name="interpolate",
+           variant_test_name='trilinear',
+           supports_autograd=True,
+           dtypesIfCUDA=floating_types_and(torch.half),
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+           sample_inputs_func=partial(sample_inputs_interpolate, 'trilinear'),
+           skips=(
+               # JIT alias info internal asserts here
+               SkipInfo('TestJit', 'test_variant_consistency_jit'),
+           ),
+           supports_out=False),
+    OpInfo('nn.functional.interpolate',
+           aten_name="interpolate",
+           variant_test_name='area',
+           supports_autograd=True,
+           dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
+           sample_inputs_func=partial(sample_inputs_interpolate, 'area'),
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+           skips=(
+               # JIT alias info internal asserts here
                SkipInfo('TestJit', 'test_variant_consistency_jit'),
            ),
            supports_out=False),
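
As an illustration of how these entries are wired (a hypothetical sketch, not code from
the patch; it assumes the helpers above from common_methods_invocations.py are in scope):
partial(sample_inputs_interpolate, '<mode>') yields a callable with the standard
(op_info, device, dtype, requires_grad) signature, and each SampleInput it produces maps
directly onto an interpolate call.

    from functools import partial
    import torch
    import torch.nn.functional as F

    # Bind the mode the same way the OpInfo entries above do.
    sample_fn = partial(sample_inputs_interpolate, 'bilinear')

    # The test framework invokes sample_fn(op_info, device, dtype, requires_grad);
    # this sampler never reads op_info, so None stands in for it here.
    for sample in sample_fn(None, 'cpu', torch.float32, False):
        # sample.args is (size, scale_factor, mode, align_corners).
        out = F.interpolate(sample.input, *sample.args)
        assert out.dim() == sample.input.dim()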