cat_shape_check: Fixes dimension in the error message for CUDA cat shape check and...
authorPalwisha Akhtar <msoyturk20@ku.edu.tr>
Thu, 9 Sep 2021 19:49:03 +0000 (12:49 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 9 Sep 2021 19:51:11 +0000 (12:51 -0700)
Summary:
Fixes: https://github.com/pytorch/pytorch/issues/64207

Thank you, SsnL for providing the reproducing script.

cc ngimel

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64556

Reviewed By: albanD

Differential Revision: D30843859

Pulled By: ngimel

fbshipit-source-id: 457ebe80eaef793d9f5d35ee962b6697e5de1907

aten/src/ATen/native/TensorShape.cpp
aten/src/ATen/native/TensorShape.h [new file with mode: 0644]
aten/src/ATen/native/cuda/Shape.cu

index 6fea912..0d2fdf3 100644 (file)
@@ -9,6 +9,7 @@
 #include <ATen/native/Resize.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/TypeProperties.h>
+#include <ATen/native/TensorShape.h>
 #include <ATen/native/cpu/CatKernel.h>
 #include <ATen/native/cpu/StackKernel.h>
 #include <ATen/NativeFunctions.h>
@@ -93,25 +94,6 @@ std::vector<Tensor> broadcast_tensors(TensorList tensors) {
   return expand_outplace(tensors);
 }
 
-// Check to see if the shape of tensors is compatible
-// for being concatenated along a given dimension.
-static inline void check_cat_shape_except_dim(const Tensor & first, const Tensor & second, int64_t dimension, int64_t index) {
-  int64_t first_dims = first.dim();
-  int64_t second_dims = second.dim();
-  TORCH_CHECK(first_dims == second_dims, "torch.cat(): Tensors must have same number of dimensions: got ",
-              first_dims, " and ", second_dims);
-  for (int64_t dim = 0; dim < first_dims; dim++) {
-    if (dim == dimension) {
-      continue;
-    }
-    int64_t first_dim_size = first.sizes()[dim];
-    int64_t second_dim_size = second.sizes()[dim];
-    TORCH_CHECK(first_dim_size == second_dim_size, "torch.cat(): Sizes of tensors must match except in dimension ",
-                dimension, ". Got ", first_dim_size, " and ", second_dim_size, " in dimension ", dim,
-                " (The offending index is ", index, ")");
-  }
-}
-
 static bool should_skip(const Tensor& t) {
   return t.numel() == 0 && t.dim() == 1;
 }
diff --git a/aten/src/ATen/native/TensorShape.h b/aten/src/ATen/native/TensorShape.h
new file mode 100644 (file)
index 0000000..9d5db6d
--- /dev/null
@@ -0,0 +1,24 @@
+#include <ATen/ATen.h>
+
+namespace at {
+namespace native {
+
+ // Check to see if the shape of tensors is compatible
+ // for being concatenated along a given dimension.
+inline void check_cat_shape_except_dim(const Tensor & first, const Tensor & second, int64_t dimension, int64_t index) {
+   int64_t first_dims = first.dim();
+   int64_t second_dims = second.dim();
+   TORCH_CHECK(first_dims == second_dims, "Tensors must have same number of dimensions: got ",
+               first_dims, " and ", second_dims);
+   for (int64_t dim = 0; dim < first_dims; dim++) {
+     if (dim == dimension) {
+       continue;
+     }
+     int64_t first_dim_size = first.sizes()[dim];
+     int64_t second_dim_size = second.sizes()[dim];
+     TORCH_CHECK(first_dim_size == second_dim_size, "Sizes of tensors must match except in dimension ",
+                 dimension, ". Expected size ", static_cast<long long>(first_dim_size), " but got size ", static_cast<long long>(second_dim_size), " for tensor number ", index, " in the list.");
+   }
+ }
+
+}} // namespace at::native
index aec9531..1fd151c 100644 (file)
@@ -4,6 +4,7 @@
 #include <ATen/cuda/detail/IndexUtils.cuh>
 #include <ATen/native/Resize.h>
 #include <ATen/native/TypeProperties.h>
+#include <ATen/native/TensorShape.h>
 #include <ATen/Dispatch.h>
 #include <c10/core/MemoryFormat.h>
 #include <c10/util/Optional.h>
@@ -172,28 +173,6 @@ __global__ void CatArrayBatchedCopy(
     }
 }
 
-void check_shape_except_dim(const Tensor &first, const Tensor &second,
-                            int dimension, int index)
-{
-  int first_dims = first.dim();
-  int second_dims = second.dim();
-  TORCH_CHECK(first_dims == second_dims,
-      "Tensors must have same number of dimensions: got ", first_dims,
-      " and ", second_dims);
-  for (int dim = 0; dim < first_dims; dim++) {
-    if (dim == dimension) {
-      continue;
-    }
-    int64_t first_dim_size = at::native::size(first, dim);
-    int64_t second_dim_size = at::native::size(second, dim);
-    TORCH_CHECK(first_dim_size == second_dim_size,
-        "Sizes of tensors must match except in dimension ", dim, ". Got ",
-        static_cast<long long>(first_dim_size), " and ",
-        static_cast<long long>(second_dim_size), " (The offending index is ",
-        index, ")");
-  }
-}
-
 template <typename scalar_t>
 void hip_parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension,
                   int nDims, c10::MemoryFormat memory_format) {
@@ -489,6 +468,7 @@ Tensor& cat_out_cuda(TensorList inputs, int64_t dimension, Tensor& out) {
     }
     nDims = inputs[i].dim();
     notSkippedTensor = &inputs[i];
+    break;
   }
 
   // If all inputs are empty tensors, return an empty tensor
@@ -521,7 +501,7 @@ Tensor& cat_out_cuda(TensorList inputs, int64_t dimension, Tensor& out) {
     if (should_skip(tensor)) {
       continue;
     }
-    check_shape_except_dim(*notSkippedTensor, tensor, dimension, i);
+    check_cat_shape_except_dim(*notSkippedTensor, tensor, dimension, i);
     cat_dim_size += at::native::size(tensor, dimension);
   }