Support `torch.concat` alias, add `cat` OpInfo & remove OpInfo test_out skips {cat...
authorAnirudh Dagar <anirudhdagar6@gmail.com>
Tue, 7 Sep 2021 06:55:53 +0000 (23:55 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 7 Sep 2021 06:57:18 +0000 (23:57 -0700)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/61767

## Changes

- [x] Add `torch.concat` alias to `torch.cat`
- [x] Add OpInfo for `cat`/`concat`
- [x] Remove `test_out` skips (use `at::native::resize_output` or `at::native::resize_output_check`; see the sketch after this list)
  - [x] `cat`/`concat`
  - [x] `stack`
  - [x] `hstack`
  - [x] `dstack`
  - [x] `vstack`/`row_stack`
- [x] Remove redundant tests for `cat`/`stack`
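
A minimal sketch of the user-visible behavior these changes target (shapes here are purely illustrative; the exact warning text comes from `resize_output_check`):

```python
import torch

a = torch.ones(2, 3)
b = torch.ones(2, 3)

# The new alias behaves exactly like torch.cat.
assert torch.equal(torch.concat((a, b), dim=0), torch.cat((a, b), dim=0))

# With the resize_output fix, passing a non-empty out= tensor of the wrong
# shape still works, but now emits a UserWarning before resizing.
out = torch.empty(5, 5)  # wrong shape, one or more elements
torch.cat((a, b), dim=0, out=out)  # warns, then resizes out to (4, 3)
```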

~I haven't added `cat`/`concat` to the OpInfo `op_db` yet, since `cat` is a little trickier than other OpInfos (it should have a lot of tests) and there are currently no comparable OpInfos to follow. I can try to add it in a subsequent PR, or here itself, whatever is suggested.~
**Edit**: cat/concat OpInfo has been added.

**Note**: I've added named tensor support for the `concat` alias as well; that may be out of spec for the `array-api`, but it is still useful for consistency within PyTorch.
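
For example, a quick sketch of the named-tensor path (assuming two tensors sharing named dimensions):

```python
import torch

x = torch.randn(2, 3, names=('N', 'C'))
y = torch.randn(2, 3, names=('N', 'C'))

# The concat alias accepts a Dimname just like cat does.
z = torch.concat([x, y], dim='C')  # same as torch.cat([x, y], dim='C')
assert z.names == ('N', 'C') and z.shape == (2, 6)
```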

Thanks to krshrimali for guidance on my first PR :))

cc mruberry rgommers pmeier asmeurer leofang AnirudhDagar asi1024 emcastillo kmaehashi heitorschueroff krshrimali

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62560

Reviewed By: saketh-are

Differential Revision: D30762069

Pulled By: mruberry

fbshipit-source-id: 6985159d1d9756238890488a0ab3ae7699d94337

15 files changed:
aten/src/ATen/core/aten_interned_strings.h
aten/src/ATen/core/interned_strings.h
aten/src/ATen/native/Resize.cpp
aten/src/ATen/native/Resize.h
aten/src/ATen/native/TensorShape.cpp
aten/src/ATen/native/cuda/Shape.cu
aten/src/ATen/native/native_functions.yaml
docs/source/torch.rst
test/test_autograd.py
test/test_fx_experimental.py
test/test_tensor_creation_ops.py
torch/_torch_docs.py
torch/csrc/jit/passes/normalize_ops.cpp
torch/overrides.py
torch/testing/_internal/common_methods_invocations.py

index abdf397..6da99df 100644 (file)
@@ -36,7 +36,6 @@ _(aten, _cast_Half) \
 _(aten, _cast_Int) \
 _(aten, _cast_Long) \
 _(aten, _cast_Short) \
-_(aten, _cat) \
 _(aten, _ceil) \
 _(aten, _clamp_max) \
 _(aten, _clamp_min) \
@@ -224,7 +223,6 @@ _(aten, bmm) \
 _(aten, broadcast_tensors) \
 _(aten, broadcast_to) \
 _(aten, cartesian_prod) \
-_(aten, cat) \
 _(aten, cauchy) \
 _(aten, ceil) \
 _(aten, celu) \
index 69e5f97..8d49d82 100644 (file)
@@ -306,6 +306,9 @@ namespace c10 {
   _(aten, bin)                       \
   _(aten, pop)                       \
   _(aten, insert)                    \
+  _(aten, _cat)                      \
+  _(aten, cat)                       \
+  _(aten, concat)                    \
   _(aten, vstack)                    \
   _(aten, row_stack)                 \
   _(prim, unchecked_unwrap_optional) \
index f4bff47..1937a8b 100644 (file)
@@ -8,7 +8,7 @@ namespace at { namespace native {
 
 // Returns true if resize is necessary
 bool resize_output_check(const Tensor& output, IntArrayRef shape) {
-  // Tests for resizing of tensors with one more elements
+  // Tests for resizing of tensors with one or more elements
   if (output.sizes().equals(shape)) {
     return false;
   }
index 5e391a0..6fb52bc 100644 (file)
 namespace at { namespace native {
 
 // TODO: make all operations that resize given outputs use this function
-//   for consistency and maintainability
+//   for consistency and maintainability.
+//   Some operations like `cat` might not be able to use
+//   resize_output directly. For more details on how this works in `cat`,
+//   see https://github.com/pytorch/pytorch/pull/62560#discussion_r687363362
 // Resizes outputs
 // Functions accepting output tensors, like with the "out" kwarg, should
 //   call this function to handle resizing their output tensor.
@@ -20,6 +23,9 @@ namespace at { namespace native {
 // Returns a bool saying whether or not the resize actually happened
 TORCH_API bool resize_output(const Tensor& output, IntArrayRef shape);
 
+// Utility for resize_output
+//  Returns a bool saying whether a resize should happen, and
+//  raises a warning when resizing an output with one or more elements
 TORCH_API bool resize_output_check(const Tensor& output, IntArrayRef shape);
 
 TORCH_API void resize_bytes_cpu(StorageImpl* storage, size_t size_bytes);
index edbfa23..8f39786 100644 (file)
@@ -6,7 +6,6 @@
 #include <ATen/NamedTensorUtils.h>
 #include <ATen/core/DimVector.h>
 #include <ATen/native/Copy.h>
-#include <ATen/native/cpu/CatKernel.h>
 #include <ATen/native/Resize.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/TypeProperties.h>
@@ -193,7 +192,10 @@ Tensor & _cat_out_cpu(TensorList tensors, int64_t dim, Tensor& result) {
   result_size[dim] = cat_dim_size;
 
   // skip resizing if size of result is same as expected
-  if (result.sizes() != result_size) {
+  // raise a warning while resizing if output has one or more elements
+  // See https://github.com/pytorch/pytorch/pull/62560#discussion_r687363362
+  // to understand why at::native::resize_output is not called directly.
+  if (at::native::resize_output_check(result, result_size)) {
     result.resize_(result_size, first_tensor_mem_format);
   }
 
@@ -301,6 +303,23 @@ Tensor cat(TensorList tensors, Dimname dim) {
   return at::cat(tensors, dimname_to_position(tensors[0], dim));
 }
 
+// torch.concat, alias for torch.cat
+Tensor& concat_out(TensorList tensors, Dimname dim, Tensor& result) {
+  return at::cat_out(result, tensors, dimname_to_position(tensors[0], dim));
+}
+
+Tensor concat(TensorList tensors, Dimname dim) {
+  return at::cat(tensors, dimname_to_position(tensors[0], dim));
+}
+
+Tensor & concat_out(TensorList tensors, int64_t dim, Tensor & result) {
+  return at::cat_out(result, tensors, dim);
+}
+
+Tensor concat(TensorList tensors, int64_t dim) {
+  return at::cat(tensors, dim);
+}
+
 static bool sizes_match_except(IntArrayRef s1, IntArrayRef s2, int64_t dim_except /* should already be wrapped */) {
   if (s1.size() != s2.size()) {
     return false;
@@ -1497,9 +1516,8 @@ bool inline maybe_native_stack(Tensor& result, TensorList tensors, int64_t dim)
     result_sizes.insert(result_sizes.begin() + dim, tensors.size());
 
     // skip resizing if size of result is same as expected
-    if (result.sizes() != result_sizes) {
-      result.resize_(result_sizes);
-    }
+    // raise a warning while resizing if output has one or more elements
+    at::native::resize_output(result, result_sizes);
     stack_serial_stub(kCPU, result, tensors, dim);
     return true;
   }
index dec9854..05fa4c6 100644 (file)
@@ -2,6 +2,7 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/MemoryOverlap.h>
 #include <ATen/cuda/detail/IndexUtils.cuh>
+#include <ATen/native/Resize.h>
 #include <ATen/native/TypeProperties.h>
 #include <ATen/Dispatch.h>
 #include <c10/core/MemoryFormat.h>
@@ -528,7 +529,10 @@ Tensor& cat_out_cuda(TensorList inputs, int64_t dimension, Tensor& out) {
   size[dimension] = cat_dim_size;
 
   // skip resizing if size of result is same as expected
-  if (out.sizes() != size) {
+  // raise a warning while resizing if output has one or more elements
+  // See https://github.com/pytorch/pytorch/pull/62560#discussion_r687363362
+  // to understand why at::native::resize_output is not called directly.
+  if (at::native::resize_output_check(out, size)) {
     out.resize_(size, memory_format);
   }
 
index ca13e05..3a1f75c 100644 (file)
 
 - func: cat.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!)
 
+# alias for torch.cat
+- func: concat(Tensor[] tensors, int dim=0) -> Tensor
+
+- func: concat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)
+
+- func: concat.names(Tensor[] tensors, Dimname dim) -> Tensor
+
+- func: concat.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!)
+
 - func: block_diag(Tensor[] tensors) -> Tensor
   variants: function
 
index 88cbc69..5aa5dbc 100644 (file)
@@ -88,6 +88,7 @@ Indexing, Slicing, Joining, Mutating Ops
     :nosignatures:
 
     cat
+    concat
     conj
     chunk
     dsplit
index 2da74cb..61a46b4 100644 (file)
@@ -2735,36 +2735,6 @@ class TestAutograd(TestCase):
                               lambda a, b, c: torch.block_diag(a, b, c),
                               True, f_args_variable, f_args_tensor)
 
-    def test_cat(self):
-        f_args_variable = (torch.randn(1, S, S, dtype=torch.double, requires_grad=True),
-                           torch.randn(2, S, S, dtype=torch.double, requires_grad=True),
-                           torch.randn(3, S, S, dtype=torch.double, requires_grad=True),
-                           0)
-        f_args_tensor = deepcopy(unpack_variables(f_args_variable))
-        run_functional_checks(self, "test_cat", "cat",
-                              lambda a, b, c, dim: torch.cat((a, b, c), dim),
-                              True, f_args_variable, f_args_tensor, check_forward_ad=True)
-
-    def test_cat_negdim_1(self):
-        f_args_variable = (torch.randn(S, S, 1, dtype=torch.double, requires_grad=True),
-                           torch.randn(S, S, 2, dtype=torch.double, requires_grad=True),
-                           torch.randn(S, S, 3, dtype=torch.double, requires_grad=True),
-                           -1)
-        f_args_tensor = deepcopy(unpack_variables(f_args_variable))
-        run_functional_checks(self, "test_cat_negdim_1", "cat",
-                              lambda a, b, c, dim: torch.cat((a, b, c), dim),
-                              True, f_args_variable, f_args_tensor, check_forward_ad=True)
-
-    def test_cat_negdim_2(self):
-        f_args_variable = (torch.randn(S, 1, S, dtype=torch.double, requires_grad=True),
-                           torch.randn(S, 2, S, dtype=torch.double, requires_grad=True),
-                           torch.randn(S, 3, S, dtype=torch.double, requires_grad=True),
-                           -2)
-        f_args_tensor = deepcopy(unpack_variables(f_args_variable))
-        run_functional_checks(self, "test_cat_negdim_2", "cat",
-                              lambda a, b, c, dim: torch.cat((a, b, c), dim),
-                              True, f_args_variable, f_args_tensor, check_forward_ad=True)
-
     def test_cat_empty_legacy(self):
         f_args_variable = (torch.randn(0, dtype=torch.double, requires_grad=True),
                            torch.randn(S, S, dtype=torch.double, requires_grad=True))
@@ -2776,14 +2746,6 @@ class TestAutograd(TestCase):
                               False, f_args_variable, f_args_tensor, check_forward_ad=True)
         self.assertTrue(gradcheck(lambda a, b: torch.cat((a, b)), f_args_variable, eps=1e-6, atol=PRECISION))
 
-    def test_cat_empty(self):
-        f_args_variable = (torch.randn(0, S, dtype=torch.double, requires_grad=True),
-                           torch.randn(S, S, dtype=torch.double, requires_grad=True))
-        f_args_tensor = deepcopy(unpack_variables(f_args_variable))
-        run_functional_checks(self, "test_cat_empty", "cat",
-                              lambda a, b: torch.cat((a, b)),
-                              True, f_args_variable, f_args_tensor, check_forward_ad=True)
-
     def test_var_mean_differentiable(self):
         dim = [2, 4]
         keepdim = False
index e723ee4..fc90f49 100644 (file)
@@ -1497,7 +1497,7 @@ class TestNormalizeOperators(JitTestCase):
             return
 
         # These ops currently don't trace in FX for various reasons (e.g. they take a list of tensors)
-        fx_fail = {"stack", "hstack", "vstack", "dstack", "linalg.multi_dot"}
+        fx_fail = {"cat", "stack", "hstack", "vstack", "dstack", "linalg.multi_dot"}
         sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
         for sample_input in sample_inputs_itr:
             unsupported_arg_type = False
index dcb4938..a749691 100644 (file)
@@ -695,6 +695,47 @@ class TestTensorCreation(TestCase):
         self.assertEqual(res1, res2)
         self.assertTrue(res1.is_contiguous(memory_format=torch.channels_last))
 
+    @onlyCUDA
+    def test_cat_out_memory_format(self, device):
+        inp_size = (4, 4, 4, 4)
+        expected_size = (8, 4, 4, 4)
+        a_cuda = torch.randn(inp_size, device=device).contiguous(memory_format=torch.channels_last)
+        a_cpu = torch.randn(inp_size, device='cpu').contiguous(memory_format=torch.channels_last)
+        b_cuda = torch.randn(inp_size, device=device).contiguous(memory_format=torch.contiguous_format)
+        b_cpu = torch.randn(inp_size, device='cpu').contiguous(memory_format=torch.contiguous_format)
+        c_cuda = torch.randn(inp_size, device=device).contiguous(memory_format=torch.channels_last)
+
+        # Case 1: if out= is the correct shape then the memory format of out= is respected
+
+        out_cuda = torch.empty(expected_size, device=device).contiguous(memory_format=torch.contiguous_format)
+        res1_cuda = torch.cat((a_cuda, b_cuda), out=out_cuda)
+
+        out_cpu = torch.empty(expected_size, device='cpu').contiguous(memory_format=torch.contiguous_format)
+        res1_cpu = torch.cat((a_cpu, b_cpu), out=out_cpu)
+
+        self.assertTrue(res1_cuda.is_contiguous(memory_format=torch.contiguous_format))
+        self.assertTrue(res1_cpu.is_contiguous(memory_format=torch.contiguous_format))
+
+        # Case 2: if out= is not the correct shape then the output is resized internally
+        # - For the CPU variant the memory format is that of the first tensor
+        # - For the CUDA variant it only propagates memory format if all the tensors have
+        #   the same memory format, otherwise it just uses contiguous_format as a default
+
+        out_cuda = torch.empty((0), device=device).contiguous(memory_format=torch.contiguous_format)
+        # a_cuda and b_cuda have different memory_format
+        res2_cuda = torch.cat((a_cuda, b_cuda), out=out_cuda)
+
+        out_cpu = torch.empty((0), device='cpu').contiguous(memory_format=torch.contiguous_format)
+        res2_cpu = torch.cat((a_cpu, b_cpu), out=out_cpu)
+
+        self.assertTrue(res2_cuda.is_contiguous(memory_format=torch.contiguous_format))
+        self.assertTrue(res2_cpu.is_contiguous(memory_format=torch.channels_last))
+
+        out_cuda = torch.empty((0), device=device).contiguous(memory_format=torch.contiguous_format)
+        # a_cuda and c_cuda have same memory_format
+        res3_cuda = torch.cat((a_cuda, c_cuda), out=out_cuda)
+
+        self.assertTrue(res3_cuda.is_contiguous(memory_format=torch.channels_last))
 
     @onlyCUDA
     @deviceCountAtLeast(2)
@@ -713,8 +754,8 @@ class TestTensorCreation(TestCase):
     def test_cat_stack_cross_devices(self, device):
         cuda = torch.randn((3, 3), device=device)
         cpu = torch.randn((3, 3), device='cpu')
-        out_cpu = cpu.clone()
-        out_cuda = cuda.clone()
+
+        # cat
         with self.assertRaisesRegex(RuntimeError,
                                     "Expected all tensors to be on the same device"):
             torch.cat((cuda, cpu))
@@ -722,18 +763,6 @@ class TestTensorCreation(TestCase):
                                     "Expected all tensors to be on the same device"):
             torch.cat((cpu, cuda))
 
-        with self.assertRaisesRegex(RuntimeError,
-                                    "Expected all tensors to be on the same device"):
-            torch.cat((cpu, cuda), out=out_cuda)
-
-        with self.assertRaisesRegex(RuntimeError,
-                                    "Expected all tensors to be on the same device"):
-            torch.cat((cpu, cpu), out=out_cuda)
-
-        with self.assertRaisesRegex(RuntimeError,
-                                    "Expected all tensors to be on the same device"):
-            torch.cat((cuda, cuda), out=out_cpu)
-
         # Stack
         with self.assertRaisesRegex(RuntimeError,
                                     "Expected all tensors to be on the same device"):
@@ -742,18 +771,6 @@ class TestTensorCreation(TestCase):
                                     "Expected all tensors to be on the same device"):
             torch.stack((cpu, cuda))
 
-        with self.assertRaisesRegex(RuntimeError,
-                                    "Expected all tensors to be on the same device"):
-            torch.stack((cpu, cuda), out=out_cuda)
-
-        with self.assertRaisesRegex(RuntimeError,
-                                    "Expected all tensors to be on the same device"):
-            torch.stack((cpu, cpu), out=out_cuda)
-
-        with self.assertRaisesRegex(RuntimeError,
-                                    "Expected all tensors to be on the same device"):
-            torch.stack((cuda, cuda), out=out_cpu)
-
     # TODO: reconcile with other cat tests
     # TODO: Compare with a NumPy reference instead of CPU
     @onlyCUDA
index bbb8d98..7dca8a7 100644 (file)
@@ -1856,6 +1856,13 @@ Example::
              -0.5790,  0.1497]])
 """.format(**common_args))
 
+add_docstr(torch.concat,
+           r"""
+concat(tensors, dim=0, *, out=None) -> Tensor
+
+Alias of :func:`torch.cat`.
+""")
+
 add_docstr(torch.ceil,
            r"""
 ceil(input, *, out=None) -> Tensor
index cc6444e..5ac36e1 100644 (file)
@@ -104,6 +104,7 @@ const std::unordered_map<Symbol, Symbol>& getOperatorAliasMap() {
       {aten::multiply_, aten::mul_},
       {aten::true_divide, aten::div},
       {aten::true_divide_, aten::div_},
+      {aten::concat, aten::cat},
       {aten::row_stack, aten::vstack},
       {aten::swapdims, aten::transpose},
       {aten::swapdims_, aten::transpose_},
index 64b18b8..aca14a6 100644 (file)
@@ -360,6 +360,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         torch.bucketize: lambda input, boundaries, out_int32=False, right=False, out=None: -1,
         torch.cartesian_prod: lambda *tensors: -1,
         torch.cat: lambda tensors, dim=0, out=None: -1,
+        torch.concat: lambda tensors, dim=0, out=None: -1,  # alias for torch.cat
         torch.cdist: lambda x1, x2, p=2.0, compute_mode='use_mm_for_euclid_dist_if_necessary': -1,
         torch.ceil: lambda input, out=None: -1,
         torch.celu: lambda input, alpha=1., inplace=False: -1,
index 5e009ee..ace4fa1 100644 (file)
@@ -2180,6 +2180,25 @@ def sample_inputs_stack(op_info, device, dtype, requires_grad, **kwargs):
 
     return (SampleInput(tensors, args=(0,)),)
 
+def sample_inputs_cat_concat(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    cases: Tuple[tuple, tuple, dict] = (  # type: ignore[assignment]
+        ((S, S), (S, S), {'dim': -1}),
+        ((S, S), (S, S), {'dim': 1}),
+        ((M, S), (S, S), {'dim': 0}),  # different shapes
+        ((1, 2, 3), (1, 2, 3), {'dim': -2}),
+        ((0,), (0,), {'dim': 0}),  # empty tensor
+        ((0, S), (S, S), {'dim': 0}),
+        ((1,), (1,), {})  # dim not passed, fallback to default
+    )
+
+    def generator():
+        for input_shape1, input_shape2, kwargs in cases:
+            yield SampleInput([make_arg(input_shape1), make_arg(input_shape2)], kwargs=kwargs)
+
+    return list(generator())
+
 def sample_inputs_hstack_dstack_vstack(op_info, device, dtype, requires_grad, **kwargs):
     tensors = [
         make_tensor((S, S), device, dtype, requires_grad=requires_grad),
@@ -8582,17 +8601,11 @@ op_db: List[OpInfo] = [
     OpInfo('stack',
            dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
            sample_inputs_func=sample_inputs_stack,
-           assert_autodiffed=True,
-           skips=(
-               # stack does not correctly warn when resizing out= inputs
-               SkipInfo('TestCommon', 'test_out'),),),
+           assert_autodiffed=True),
     OpInfo('hstack',
            dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
            sample_inputs_func=sample_inputs_hstack_dstack_vstack,
-           supports_forward_ad=True,
-           skips=(
-               # hstack does not correctly warn when resizing out= inputs
-               SkipInfo('TestCommon', 'test_out'),),),
+           supports_forward_ad=True),
     OpInfo('hypot',
            dtypes=floating_types(),
            dtypesIfCPU=floating_types_and(torch.bfloat16),
@@ -8609,24 +8622,31 @@ op_db: List[OpInfo] = [
                # JIT tests don't work with Tensor keyword arguments
                # https://github.com/pytorch/pytorch/issues/58507
                SkipInfo('TestJit', 'test_variant_consistency_jit'),),),
+    OpInfo('cat',
+           ref=lambda input_seq, dim=0, **kwargs: np.concatenate(input_seq, axis=dim, **kwargs),
+           aliases=('concat',),
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+           sample_inputs_func=sample_inputs_cat_concat,
+           supports_forward_ad=True,
+           assert_autodiffed=True,
+           skips=(
+               # RuntimeError: Arguments for call not valid.
+               #               Expected a value of type 'List[Tensor]' for argument
+               #               'tensors' but instead found type 'Tensor (inferred)'.
+               SkipInfo('TestJit', 'test_jit_alias_remapping'),)),
     OpInfo('vstack',
            aliases=('row_stack',),
            dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
            sample_inputs_func=sample_inputs_hstack_dstack_vstack,
            supports_forward_ad=True,
            skips=(
-               # vstack does not correctly warn when resizing out= inputs
-               SkipInfo('TestCommon', 'test_out'),
                # RuntimeError: _fn() Expected a value of type
                #   'Tensor (inferred)' for argument 't0' but instead found type 'tuple'.
-               SkipInfo('TestJit', 'test_jit_alias_remapping'))),
+               SkipInfo('TestJit', 'test_jit_alias_remapping'),)),
     OpInfo('dstack',
            dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
            sample_inputs_func=sample_inputs_hstack_dstack_vstack,
-           supports_forward_ad=True,
-           skips=(
-               # dstack does not correctly warn when resizing out= inputs
-               SkipInfo('TestCommon', 'test_out'),)),
+           supports_forward_ad=True),
     OpInfo('unfold',
            op=lambda x, *args: x.unfold(*args),
            dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),