[fix] tensor_split : non-contiguous indices tensor (#63390)
author    kshitij12345 <kshitijkalambarkar@gmail.com>
Wed, 18 Aug 2021 23:08:48 +0000 (16:08 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 18 Aug 2021 23:10:17 +0000 (16:10 -0700)
Summary:
When `indices_or_sections` is a `Tensor`, `tensor_split` copied the split
points straight out of `data_ptr`, implicitly assuming a contiguous tensor,
so a non-contiguous indices tensor (e.g. a strided slice) produced wrong
splits. The copy loop now steps by `stride(0)`, and the OpInfo for
`tensor_split` gains a non-contiguous sample input plus a NumPy reference
(`np.array_split`).

Fixes https://github.com/pytorch/pytorch/issues/63281
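
A minimal repro, built from the same strided tensor the new OpInfo sample
uses (the exact tensors in issue #63281 may differ):

    import torch

    x = torch.arange(10)
    # A step-2 slice is a non-contiguous view: it reads as [1, 2, 3],
    # but the backing storage still holds [1, 4, 2, 5, 3, 6].
    idx = torch.tensor([1, 4, 2, 5, 3, 6])[::2]
    assert not idx.is_contiguous() and idx.stride(0) == 2

    # The old code read the data pointer contiguously, so the split
    # points came out as [1, 4, 2]; with the fix, the strided tensor
    # behaves like its contiguous copy:
    for a, b in zip(torch.tensor_split(x, idx),
                    torch.tensor_split(x, idx.contiguous())):
        assert torch.equal(a, b)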

Pull Request resolved: https://github.com/pytorch/pytorch/pull/63390

Reviewed By: ejguan

Differential Revision: D30362649

Pulled By: mruberry

fbshipit-source-id: 3ea3ad02199e4345beb0b580d056babd56112309

aten/src/ATen/native/TensorShape.cpp
torch/testing/_internal/common_methods_invocations.py

diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp
index e915078..2545ec4 100644
@@ -609,7 +609,13 @@ std::vector<Tensor> tensor_split(const Tensor& self, const Tensor& tensor_indice
     return self.tensor_split(sections, dim);
   } else {
     auto indices_data = tensor_indices_or_sections.data_ptr<int64_t>();
-    std::vector<int64_t> indices(indices_data, indices_data + tensor_indices_or_sections.numel());
+    auto stride = tensor_indices_or_sections.stride(0);
+    auto numel = tensor_indices_or_sections.numel();
+    std::vector<int64_t> indices(numel);
+    for (size_t offset = 0; offset < numel; offset++) {
+      // indices tensor could be non-contiguous
+      indices[offset] = *(indices_data + offset * stride);
+    }
     return self.tensor_split(indices, dim);
   }
 }
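
ATen strides are measured in elements, not bytes, so element `offset` of the
1-D indices tensor lives at `indices_data + offset * stride`. A rough Python
sketch of the same walk (illustration only; `storage()` and
`storage_offset()` stand in for the raw `data_ptr`):

    import torch

    def read_1d(t):
        # Mirror the patched loop: index the flat backing storage by
        # stride instead of assuming the elements are adjacent.
        s = t.storage()
        off = t.storage_offset()
        step = t.stride(0)
        return [s[off + i * step] for i in range(t.numel())]

    idx = torch.tensor([1, 4, 2, 5, 3, 6])[::2]
    print(read_1d(idx))  # [1, 2, 3] -- matches idx.tolist()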
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 5d55f0e..7e57d5d 100644
@@ -859,6 +859,7 @@ def sample_inputs_tensor_split(op_info, device, dtype, requires_grad, **kwargs):
         (torch.tensor([1, 2, 3]),),
         (torch.tensor(1),),
         (torch.tensor([1, 2, 3]), 1),
+        (torch.tensor([1, 4, 2, 5, 3, 6])[::2], 1),
         # Cases with list of indices.
         ((2, 4),),
         ((2, 4), 1),
@@ -7590,6 +7591,7 @@ op_db: List[OpInfo] = [
                                 active_if=(IS_MACOS or IS_WINDOWS)),
                    )),
     OpInfo('tensor_split',
+           ref=np.array_split,
            dtypes=all_types_and_complex_and(torch.bool),
            dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
            dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
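
The new `ref=np.array_split` entry points the OpInfo machinery at the NumPy
counterpart of `tensor_split`, so sample outputs are checked against it. A
sketch of what that comparison amounts to (not the actual test-suite code):

    import numpy as np
    import torch

    x = torch.arange(10)
    for indices in ([2, 4], torch.tensor([1, 4, 2, 5, 3, 6])[::2]):
        torch_out = torch.tensor_split(x, indices)
        np_out = np.array_split(x.numpy(), np.asarray(indices))
        for t, n in zip(torch_out, np_out):
            assert torch.equal(t, torch.from_numpy(n))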