#include <ATen/native/Resize.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/TypeProperties.h>
+#include <ATen/native/TensorShape.h>
#include <ATen/native/cpu/CatKernel.h>
#include <ATen/native/cpu/StackKernel.h>
#include <ATen/NativeFunctions.h>
return expand_outplace(tensors);
}
-// Check to see if the shape of tensors is compatible
-// for being concatenated along a given dimension.
-static inline void check_cat_shape_except_dim(const Tensor & first, const Tensor & second, int64_t dimension, int64_t index) {
- int64_t first_dims = first.dim();
- int64_t second_dims = second.dim();
- TORCH_CHECK(first_dims == second_dims, "torch.cat(): Tensors must have same number of dimensions: got ",
- first_dims, " and ", second_dims);
- for (int64_t dim = 0; dim < first_dims; dim++) {
- if (dim == dimension) {
- continue;
- }
- int64_t first_dim_size = first.sizes()[dim];
- int64_t second_dim_size = second.sizes()[dim];
- TORCH_CHECK(first_dim_size == second_dim_size, "torch.cat(): Sizes of tensors must match except in dimension ",
- dimension, ". Got ", first_dim_size, " and ", second_dim_size, " in dimension ", dim,
- " (The offending index is ", index, ")");
- }
-}
-
static bool should_skip(const Tensor& t) {
return t.numel() == 0 && t.dim() == 1;
}
--- /dev/null
+#include <ATen/ATen.h>
+
+namespace at {
+namespace native {
+
+// Check that `second` can be concatenated with `first` along `dimension`:
+// both tensors must have the same number of dimensions, and every dimension
+// other than `dimension` must match in size. `index` is the position of
+// `second` in the input tensor list, used only to make the error message
+// actionable. Raises (via TORCH_CHECK) on any mismatch; returns normally
+// when the shapes are compatible.
+inline void check_cat_shape_except_dim(const Tensor & first, const Tensor & second, int64_t dimension, int64_t index) {
+  int64_t first_dims = first.dim();
+  int64_t second_dims = second.dim();
+  TORCH_CHECK(first_dims == second_dims, "Tensors must have same number of dimensions: got ",
+              first_dims, " and ", second_dims);
+  // Every dimension except the concatenation dimension must agree.
+  for (int64_t dim = 0; dim < first_dims; dim++) {
+    if (dim == dimension) {
+      continue;
+    }
+    int64_t first_dim_size = first.sizes()[dim];
+    int64_t second_dim_size = second.sizes()[dim];
+    TORCH_CHECK(first_dim_size == second_dim_size, "Sizes of tensors must match except in dimension ",
+                dimension, ". Expected size ", static_cast<long long>(first_dim_size), " but got size ",
+                static_cast<long long>(second_dim_size), " for tensor number ", index, " in the list.");
+  }
+}
+
+}} // namespace at::native
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/native/Resize.h>
#include <ATen/native/TypeProperties.h>
+#include <ATen/native/TensorShape.h>
#include <ATen/Dispatch.h>
#include <c10/core/MemoryFormat.h>
#include <c10/util/Optional.h>
}
}
-void check_shape_except_dim(const Tensor &first, const Tensor &second,
- int dimension, int index)
-{
- int first_dims = first.dim();
- int second_dims = second.dim();
- TORCH_CHECK(first_dims == second_dims,
- "Tensors must have same number of dimensions: got ", first_dims,
- " and ", second_dims);
- for (int dim = 0; dim < first_dims; dim++) {
- if (dim == dimension) {
- continue;
- }
- int64_t first_dim_size = at::native::size(first, dim);
- int64_t second_dim_size = at::native::size(second, dim);
- TORCH_CHECK(first_dim_size == second_dim_size,
- "Sizes of tensors must match except in dimension ", dim, ". Got ",
- static_cast<long long>(first_dim_size), " and ",
- static_cast<long long>(second_dim_size), " (The offending index is ",
- index, ")");
- }
-}
-
template <typename scalar_t>
void hip_parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension,
int nDims, c10::MemoryFormat memory_format) {
}
nDims = inputs[i].dim();
notSkippedTensor = &inputs[i];
+ break;
}
// If all inputs are empty tensors, return an empty tensor
if (should_skip(tensor)) {
continue;
}
- check_shape_except_dim(*notSkippedTensor, tensor, dimension, i);
+ check_cat_shape_except_dim(*notSkippedTensor, tensor, dimension, i);
cat_dim_size += at::native::size(tensor, dimension);
}