add torch.meshgrid() OpInfo (#62720)
authorMichael Dagitses <mikeyd@fb.com>
Tue, 17 Aug 2021 11:03:02 +0000 (04:03 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 17 Aug 2021 11:04:24 +0000 (04:04 -0700)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/62719

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62720

Reviewed By: astaff

Differential Revision: D30344574

Pulled By: dagitses

fbshipit-source-id: ed42d9fe20741df98018efb08e640fca370583fb

torch/testing/_internal/common_methods_invocations.py

index 27349a1..5a878b3 100644 (file)
@@ -4149,6 +4149,45 @@ def sample_inputs_matmul(op_info, device, dtype, requires_grad):
     return tuple(sample_inputs)
 
 
+def sample_inputs_meshgrid(op_info: OpInfo, device: torch.device, dtype: torch.dtype,
+                           requires_grad: bool,
+                           *, variant: str) -> List[SampleInput]:
+    if variant == 'variadic':
+        def make_inputs(
+                tensors: List[torch.Tensor]) -> Tuple[Union[torch.Tensor,
+                                                            List[torch.Tensor]],
+                                                      Tuple[torch.Tensor, ...]]:
+            return tensors[0], tuple(tensors[1:])
+    elif variant == 'list':
+        def make_inputs(
+                tensors: List[torch.Tensor]) -> Tuple[Union[torch.Tensor,
+                                                            List[torch.Tensor]],
+                                                      Tuple[torch.Tensor, ...]]:
+            return tensors, ()
+    else:
+        raise ValueError(
+            'Unsupported variant, must be one of {"variadic", "list"}. '
+            f'Got "{variant}".')
+
+    SCALAR = torch.Size([])
+    VECTOR = torch.Size([3])
+    test_cases: List[List[torch.Size]] = [
+        [SCALAR],
+        [VECTOR],
+        [VECTOR, SCALAR],
+        [VECTOR, SCALAR, VECTOR],
+        [VECTOR, SCALAR, VECTOR, SCALAR],
+    ]
+
+    sample_inputs = []
+    for shapes in test_cases:
+        input, args = make_inputs(
+            [make_tensor(shape, device, dtype, requires_grad=requires_grad)
+             for shape in shapes])
+        sample_inputs.append(SampleInput(input=input, args=args))
+    return sample_inputs
+
+
 def sample_inputs_polar(op_info, device, dtype, requires_grad, **kwargs):
     def _make_tensor_helper(shape, low=None, high=None):
         return make_tensor(shape, device, dtype, low=low, high=high, requires_grad=requires_grad)
@@ -6759,6 +6798,42 @@ op_db: List[OpInfo] = [
                SkipInfo('TestGradients', 'test_fn_grad'),
                SkipInfo('TestGradients', 'test_fn_gradgrad'),
                SkipInfo('TestGradients', 'test_forward_mode_AD'))),
+    OpInfo('meshgrid',
+           variant_test_name='variadic_tensors',
+           # Our implementation corresponds to "ij" indexing for
+           # numpy.meshgrid, but its default value is "xy".
+           ref=lambda *tensors: np.meshgrid(*tensors, indexing='ij'),
+           dtypes=all_types_and_complex_and(torch.bfloat16, torch.bool, torch.float16),
+           sample_inputs_func=partial(sample_inputs_meshgrid, variant='variadic'),
+           skips=[
+               # JIT does not support variadic tensors.
+               SkipInfo('TestJit', 'test_variant_consistency_jit'),
+               # meshgrid is defined in torch.functional to take a
+               # variadic list of tensors. Variadic parameters are not
+               # compatible with the normalize operator tests.
+               SkipInfo('TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+               # Skip operator schema test because this is a functional and not an operator
+               SkipInfo('TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
+           ],
+           supports_out=False,
+           supports_forward_ad=True),
+    OpInfo('meshgrid',
+           variant_test_name='list_of_tensors',
+           # Unlike the variant above, we do not use np.meshgrid as a
+           # ref since it does not officially support list of numpy
+           # arrays.
+           dtypes=all_types_and_complex_and(torch.bfloat16, torch.bool, torch.float16),
+           sample_inputs_func=partial(sample_inputs_meshgrid, variant='list'),
+           skips=[
+               # meshgrid is defined in torch.functional to take a
+               # variadic list of tensors. Variadic parameters are not
+               # compatible with the normalize operator tests.
+               SkipInfo('TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+           ],
+           assert_autodiffed=True,
+           supports_out=False,
+           autodiff_nonfusible_nodes=[],
+           supports_forward_ad=True),
     OpInfo('min',
            op=torch.min,
            variant_test_name='binary',