Reverts cat and stack warnings when out= is not the expected shape (#64714)
authorMike Ruberry <mruberry@fb.com>
Thu, 9 Sep 2021 17:02:03 +0000 (10:02 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 9 Sep 2021 17:03:22 +0000 (10:03 -0700)
Summary:
These warnings are being thrown too aggressively at the moment. See https://github.com/pytorch/pytorch/issues/64709 for a follow-up to re-enable them once internal call sites are reviewed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64714

Reviewed By: ngimel

Differential Revision: D30822965

Pulled By: mruberry

fbshipit-source-id: 3ad7c92d381d42ac6187ed84afab477c579a8f35

aten/src/ATen/native/TensorShape.cpp
aten/src/ATen/native/cuda/Shape.cu
torch/testing/_internal/common_methods_invocations.py

index 8f39786..6fea912 100644 (file)
@@ -195,7 +195,10 @@ Tensor & _cat_out_cpu(TensorList tensors, int64_t dim, Tensor& result) {
   // raise a warning while resizing if output has one or more elements
   // See https://github.com/pytorch/pytorch/pull/62560#discussion_r687363362
   // for understanding why at::native::resize_output is not called directly.
-  if (at::native::resize_output_check(result, result_size)) {
+  // if (at::native::resize_output_check(result, result_size)) {
+  // TODO: restore the above, see https://github.com/pytorch/pytorch/issues/64709
+
+  if (result.sizes() != result_size) {
     result.resize_(result_size, first_tensor_mem_format);
   }
 
@@ -1517,7 +1520,13 @@ bool inline maybe_native_stack(Tensor& result, TensorList tensors, int64_t dim)
 
     // skip resizing if size of result is same as expected
     // raise a warning while resizing if output has one or more elements
-    at::native::resize_output(result, result_sizes);
+    // at::native::resize_output(result, result_sizes);
+    // TODO: restore the above, see https://github.com/pytorch/pytorch/issues/64709
+
+    if (result.sizes() != result_sizes) {
+      result.resize_(result_sizes);
+    }
+
     stack_serial_stub(kCPU, result, tensors, dim);
     return true;
   }
index 05fa4c6..aec9531 100644 (file)
@@ -532,7 +532,10 @@ Tensor& cat_out_cuda(TensorList inputs, int64_t dimension, Tensor& out) {
   // raise a warning while resizing if output has one or more elements
   // See https://github.com/pytorch/pytorch/pull/62560#discussion_r687363362
   // for understanding why at::native::resize_output is not called directly.
-  if (at::native::resize_output_check(out, size)) {
+  // if (at::native::resize_output_check(out, size)) {
+  // TODO: restore the above, see https://github.com/pytorch/pytorch/issues/64709
+
+  if (out.sizes() != size) {
     out.resize_(size, memory_format);
   }
 
index 40ae6b4..f0c163c 100644 (file)
@@ -8683,11 +8683,19 @@ op_db: List[OpInfo] = [
     OpInfo('stack',
            dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
            sample_inputs_func=sample_inputs_stack,
-           assert_autodiffed=True),
+           assert_autodiffed=True,
+           skips=(
+               # TODO: see https://github.com/pytorch/pytorch/issues/64709
+               SkipInfo('TestCommon', 'test_out'),
+           )),
     OpInfo('hstack',
            dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
            sample_inputs_func=sample_inputs_hstack_dstack_vstack,
-           supports_forward_ad=True),
+           supports_forward_ad=True,
+           skips=(
+               # TODO: see https://github.com/pytorch/pytorch/issues/64709
+               SkipInfo('TestCommon', 'test_out'),
+           )),
     OpInfo('hypot',
            dtypes=floating_types(),
            dtypesIfCPU=floating_types_and(torch.bfloat16),
@@ -8712,6 +8720,8 @@ op_db: List[OpInfo] = [
            supports_forward_ad=True,
            assert_autodiffed=True,
            skips=(
+               # TODO: see https://github.com/pytorch/pytorch/issues/64709
+               SkipInfo('TestCommon', 'test_out'),
                # RuntimeError: Arguments for call not valid.
                #               Expected a value of type 'List[Tensor]' for argument
                #               'tensors' but instead found type 'Tensor (inferred)'.
@@ -8722,13 +8732,19 @@ op_db: List[OpInfo] = [
            sample_inputs_func=sample_inputs_hstack_dstack_vstack,
            supports_forward_ad=True,
            skips=(
+               # TODO: see https://github.com/pytorch/pytorch/issues/64709
+               SkipInfo('TestCommon', 'test_out'),
                # RuntimeError: _fn() Expected a value of type
                #   'Tensor (inferred)' for argument 't0' but instead found type 'tuple'.
                SkipInfo('TestJit', 'test_jit_alias_remapping'),)),
     OpInfo('dstack',
            dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
            sample_inputs_func=sample_inputs_hstack_dstack_vstack,
-           supports_forward_ad=True),
+           supports_forward_ad=True,
+           skips=(
+               # TODO: see https://github.com/pytorch/pytorch/issues/64709
+               SkipInfo('TestCommon', 'test_out'),
+           )),
     OpInfo('unfold',
            op=lambda x, *args: x.unfold(*args),
            dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),