add support for indexing to meshgrid (#62722)
authorMichael Dagitses <mikeyd@fb.com>
Thu, 16 Sep 2021 16:58:09 +0000 (09:58 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 16 Sep 2021 16:59:49 +0000 (09:59 -0700)
Summary:
This is step 3/7 of https://github.com/pytorch/pytorch/issues/50276. It only adds support for the argument but doesn't implement new indexing modes yet.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62722

Test Plan:
Verified this is not FC-breaking by adding logging to both meshgrid
overloads and then calling meshgrid twice:

`meshgrid(*tensors)`
  and
`meshgrid(*tensors, indexing='ij')`

This confirmed that the former signature triggered the original native
function and the latter signature triggered the new native function.

Reviewed By: H-Huang

Differential Revision: D30394313

Pulled By: dagitses

fbshipit-source-id: e265cb114d8caae414ee2305dc463b34fdb57fa6

aten/src/ATen/native/TensorShape.cpp
aten/src/ATen/native/native_functions.yaml
test/test_tensor_creation_ops.py
torch/functional.py
torch/onnx/symbolic_opset9.py
torch/testing/_internal/common_methods_invocations.py

index f063944..388684a 100644 (file)
@@ -2153,8 +2153,21 @@ std::vector<Tensor> unbind(const Tensor& self, Dimname dim) {
 }
 
 std::vector<Tensor> meshgrid(TensorList tensors) {
+  TORCH_WARN_ONCE("torch.meshgrid: in an upcoming release, it will be required to pass the "
+                  "indexing argument.");
+  return native::meshgrid(tensors, /*indexing=*/"ij");
+}
+
+std::vector<Tensor> meshgrid(TensorList tensors,
+                             c10::string_view indexing) {
   int64_t size = tensors.size();
   TORCH_CHECK(size > 0, "meshgrid expects a non-empty TensorList");
+
+  TORCH_CHECK(
+      indexing == "ij",
+      "torch.meshgrid: only \"ij\" indexing is supported at this time, but "
+      "received: ", indexing);
+
   std::vector<int64_t> shape(size);
   for(const auto i: c10::irange(size)){
     switch (tensors[i].dim()) {
index edcccaa..afa9af3 100644 (file)
 
 - func: meshgrid(Tensor[] tensors) -> Tensor[]
 
+# TODO: Two weeks after this lands, combine these two overloads,
+#       making "indexing" optional. These are temporarily distinct for
+#       forward-compatibility reasons.
+- func: meshgrid.indexing(Tensor[] tensors, *, str indexing) -> Tensor[]
+
 - func: cartesian_prod(Tensor[] tensors) -> Tensor
   variants: function
 
index e698768..511730d 100644 (file)
@@ -1402,7 +1402,39 @@ class TestTensorCreation(TestCase):
         y = x[2:]
         self.assertEqual(int(y), 3)
 
-    def test_meshgrid(self, device):
+    def test_meshgrid_empty(self):
+        with self.assertRaisesRegex(RuntimeError,
+                                    'expects a non-empty TensorList'):
+            torch.meshgrid()
+
+    def test_meshgrid_unsupported_indexing(self):
+        with self.assertRaisesRegex(RuntimeError,
+                                    'only "ij" indexing is supported'):
+            torch.meshgrid(torch.tensor([1, 2]), indexing='')
+
+    def test_meshgrid_non_1d_tensor(self):
+        with self.assertRaisesRegex(RuntimeError,
+                                    'Expected scalar or 1D tensor'):
+            torch.meshgrid(torch.tensor([[1, 2], [3, 4]]))
+
+    def test_meshgrid_inconsistent_dtype(self):
+        with self.assertRaisesRegex(
+                RuntimeError, 'expects all tensors to have the same dtype'):
+            torch.meshgrid(torch.tensor([1], dtype=torch.int),
+                           torch.tensor([2], dtype=torch.float))
+
+    def test_meshgrid_inconsistent_device(self):
+        with self.assertRaisesRegex(
+                RuntimeError, 'expects all tensors to have the same device'):
+            torch.meshgrid(torch.tensor([1], device='cpu'),
+                           torch.tensor([2], device='meta'))
+
+    def test_meshgrid_warns_if_no_indexing(self):
+        with self.assertWarnsOnceRegex(
+                UserWarning, '.*will be required to pass the indexing arg.*'):
+            torch.meshgrid(torch.tensor([1, 2]))
+
+    def test_meshgrid_default_indexing(self, device):
         a = torch.tensor(1, device=device)
         b = torch.tensor([1, 2, 3], device=device)
         c = torch.tensor([1, 2], device=device)
@@ -1428,6 +1460,81 @@ class TestTensorCreation(TestCase):
         self.assertTrue(grid_b2.equal(expected_grid_b))
         self.assertTrue(grid_c2.equal(expected_grid_c))
 
+    def test_meshgrid_ij_indexing(self, device):
+        a = torch.tensor(1, device=device)
+        b = torch.tensor([1, 2, 3], device=device)
+        c = torch.tensor([1, 2], device=device)
+        grid_a, grid_b, grid_c = torch.meshgrid([a, b, c], indexing='ij')
+        self.assertEqual(grid_a.shape, torch.Size([1, 3, 2]))
+        self.assertEqual(grid_b.shape, torch.Size([1, 3, 2]))
+        self.assertEqual(grid_c.shape, torch.Size([1, 3, 2]))
+        grid_a2, grid_b2, grid_c2 = torch.meshgrid(a, b, c, indexing='ij')
+        self.assertEqual(grid_a2.shape, torch.Size([1, 3, 2]))
+        self.assertEqual(grid_b2.shape, torch.Size([1, 3, 2]))
+        self.assertEqual(grid_c2.shape, torch.Size([1, 3, 2]))
+        expected_grid_a = torch.ones(1, 3, 2, dtype=torch.int64, device=device)
+        expected_grid_b = torch.tensor([[[1, 1],
+                                         [2, 2],
+                                         [3, 3]]], device=device)
+        expected_grid_c = torch.tensor([[[1, 2],
+                                         [1, 2],
+                                         [1, 2]]], device=device)
+        self.assertTrue(grid_a.equal(expected_grid_a))
+        self.assertTrue(grid_b.equal(expected_grid_b))
+        self.assertTrue(grid_c.equal(expected_grid_c))
+        self.assertTrue(grid_a2.equal(expected_grid_a))
+        self.assertTrue(grid_b2.equal(expected_grid_b))
+        self.assertTrue(grid_c2.equal(expected_grid_c))
+
+    def test_meshgrid_ij_indexing_is_default(self, device):
+        a = torch.tensor(1, device=device)
+        b = torch.tensor([1, 2, 3], device=device)
+        c = torch.tensor([1, 2], device=device)
+        grid_a, grid_b, grid_c = torch.meshgrid(a, b, c, indexing='ij')
+        grid_a2, grid_b2, grid_c2 = torch.meshgrid(a, b, c)
+        self.assertTrue(grid_a.equal(grid_a2))
+        self.assertTrue(grid_b.equal(grid_b2))
+        self.assertTrue(grid_c.equal(grid_c2))
+
+    @skipMeta
+    def test_meshgrid_vs_numpy(self, device):
+        # Shapes of the random tensors. Each line is a test case, and
+        # each list within that line is the shape of a single
+        # tensor. The shapes are restricted to 0D (represented by [])
+        # and 1D tensors.
+        cases = [
+            [[]],
+            [[1], [1], [1]],
+            [[], [], []],
+            [[3], [5], [7]],
+            [[3], [], [7]],
+            [[11], [13]],
+            [[15]],
+        ]
+
+        # We also need to test the different indexing modes. We can't
+        # just enumerate them because we don't presently support the
+        # same modes as numpy.meshgrid, nor does our default
+        # correspond to their default.
+        #
+        # TODO Eliminate this and replace it with a list of all
+        # supported indexing modes when we have full compatibility.
+        indexing_correspondence = [
+            # No indexing in PyTorch corresponds to "ij" indexing in
+            # NumPy.
+            ({}, {'indexing': 'ij'}),
+            # "ij" is implemented identically in both.
+            ({'indexing': 'ij'}, {'indexing': 'ij'}),
+            # TODO Test "xy" when it is supported in PyTorch.
+        ]
+        for shapes, (torch_kwargs, numpy_kwargs) in product(cases, indexing_correspondence):
+            with self.subTest(shapes=shapes, torch_kwargs=torch_kwargs, numpy_kwargs=numpy_kwargs):
+                tensors = [make_tensor(shape, device=device, dtype=torch.int) for shape in shapes]
+                torch_grids = torch.meshgrid(*tensors, **torch_kwargs)
+                numpy_grids = np.meshgrid(*(tensor.cpu().numpy() for tensor in tensors), **numpy_kwargs)
+                self.assertEqual(torch_grids, numpy_grids)
+
+
     def test_cartesian_prod(self, device):
         a = torch.tensor([1], device=device)
         b = torch.tensor([1, 2, 3], device=device)
index a773333..469f45d 100644 (file)
@@ -330,10 +330,11 @@ def einsum(*args):
 # This wrapper exists to support variadic args.
 if TYPE_CHECKING:
     # The JIT doesn't understand Union, so only add type annotation for mypy
-    def meshgrid(*tensors: Union[Tensor, List[Tensor]]) -> Tuple[Tensor, ...]:
-        return _meshgrid(*tensors)
+    def meshgrid(*tensors: Union[Tensor, List[Tensor]],
+                 indexing: Optional[str] = None) -> Tuple[Tensor, ...]:
+        return _meshgrid(*tensors, indexing=indexing)
 else:
-    def meshgrid(*tensors):
+    def meshgrid(*tensors, indexing: Optional[str] = None) -> Tuple[Tensor, ...]:
         r"""Creates grids of coordinates specified by the 1D inputs in :attr:`tensors`.
 
         This is helpful when you want to visualize data over some
@@ -351,10 +352,12 @@ else:
             single element.
 
         .. warning::
-            `torch.meshgrid` has the same behavior as calling
-            `numpy.meshgrid(..., indexing='ij')`, and in the future
-            `torch.meshgrid` will also support the `indexing`
-            argument.
+            `torch.meshgrid(*tensors)` currently has the same behavior
+            as calling `numpy.meshgrid(*arrays, indexing='ij')`.
+
+            In the future, `torch.meshgrid` will support
+            `indexing='xy'` and will eventually transition to that as
+            the default.
 
             https://github.com/pytorch/pytorch/issues/50276 tracks
             this issue with the goal of migrating to NumPy's behavior.
@@ -368,6 +371,9 @@ else:
             tensors (list of Tensor): list of scalars or 1 dimensional tensors. Scalars will be
                 treated as tensors of size :math:`(1,)` automatically
 
+            indexing (str, optional): the indexing mode requested.
+                Only "ij" is currently supported.
+
         Returns:
             seq (sequence of Tensors): If the input has :math:`N`
             tensors of size :math:`S_0 \ldots S_{N-1}`, then the
@@ -382,7 +388,7 @@ else:
             Observe the element-wise pairings across the grid, (1, 4),
             (1, 5), ..., (3, 6). This is the same thing as the
             cartesian product.
-            >>> grid_x, grid_y = torch.meshgrid(x, y)
+            >>> grid_x, grid_y = torch.meshgrid(x, y, indexing='ij')
             >>> grid_x
             tensor([[1, 1, 1],
                     [2, 2, 2],
@@ -414,16 +420,22 @@ else:
             :width: 512
 
         """
-        return _meshgrid(*tensors)
+        return _meshgrid(*tensors, indexing=indexing)
 
 
-def _meshgrid(*tensors):
+def _meshgrid(*tensors, indexing: Optional[str]):
     if has_torch_function(tensors):
-        return handle_torch_function(meshgrid, tensors, *tensors)
+        return handle_torch_function(meshgrid, tensors, *tensors, indexing=indexing)
     if len(tensors) == 1 and isinstance(tensors[0], (list, tuple)):
         # the old interface of passing the operands as one list argument
         tensors = tensors[0]  # type: ignore[assignment]
-    return _VF.meshgrid(tensors)  # type: ignore[attr-defined]
+
+    # Continue allowing call of old method that takes no indexing
+    # kwarg for forward compatibility reasons.
+    #
+    # Remove this two weeks after landing.
+    kwargs = {} if indexing is None else {'indexing': indexing}
+    return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
 
 
 def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None,
index 70bb828..d823b68 100644 (file)
@@ -2897,7 +2897,12 @@ def baddbmm(g, self, batch1, batch2, beta, alpha):
     return add(g, mul_a, mul_b)
 
 
-def meshgrid(g, tensor_list):
+@parse_args('v', 's')
+def meshgrid(g, tensor_list, indexing: Optional[str] = None):
+    if indexing is None:
+        indexing = 'ij'
+    elif indexing != 'ij':
+        raise ValueError(f'Unsupported indexing: {indexing}')
     tensors = [sym_help._reshape_helper(g, t, g.op("Constant", value_t=torch.LongTensor([-1])))
                for t in sym_help._unpack_list(tensor_list)]
     tensors_shape = [g.op("Shape", t) for t in tensors]
index 8189dce..b8501f9 100644 (file)
@@ -4608,11 +4608,12 @@ def sample_inputs_meshgrid(op_info: OpInfo, device: torch.device, dtype: torch.d
     ]
 
     sample_inputs = []
-    for shapes in test_cases:
+    for shapes, indexing in itertools.product(test_cases, {'ij'}):
         input, args = make_inputs(
             [make_tensor(shape, device, dtype, requires_grad=requires_grad)
              for shape in shapes])
-        sample_inputs.append(SampleInput(input=input, args=args))
+        sample_inputs.append(SampleInput(input=input, args=args,
+                                         kwargs=dict(indexing=indexing)))
     return sample_inputs
 
 
@@ -7396,9 +7397,7 @@ op_db: List[OpInfo] = [
                DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD'))),
     OpInfo('meshgrid',
            variant_test_name='variadic_tensors',
-           # Our implementation corresponds to "ij" indexing for
-           # numpy.meshgrid, but its default value is "xy".
-           ref=lambda *tensors: np.meshgrid(*tensors, indexing='ij'),
+           ref=np.meshgrid,
            dtypes=all_types_and_complex_and(torch.bfloat16, torch.bool, torch.float16),
            sample_inputs_func=partial(sample_inputs_meshgrid, variant='variadic'),
            skips=[