return self.tensor_split(sections, dim);
} else {
auto indices_data = tensor_indices_or_sections.data_ptr<int64_t>();
- std::vector<int64_t> indices(indices_data, indices_data + tensor_indices_or_sections.numel());
+ auto stride = tensor_indices_or_sections.stride(0);
+ auto numel = tensor_indices_or_sections.numel();
+ std::vector<int64_t> indices(numel);
+ for (size_t offset = 0; offset < numel; offset++) {
+ // The indices tensor may be non-contiguous (e.g. a strided slice), so
+ // read elements via its stride instead of assuming a dense layout.
+ indices[offset] = *(indices_data + offset * stride);
+ }
return self.tensor_split(indices, dim);
}
}
(torch.tensor([1, 2, 3]),),
(torch.tensor(1),),
(torch.tensor([1, 2, 3]), 1),
+ (torch.tensor([1, 4, 2, 5, 3, 6])[::2], 1),
# Cases with list of indices.
((2, 4),),
((2, 4), 1),
active_if=(IS_MACOS or IS_WINDOWS)),
)),
OpInfo('tensor_split',
+ ref=np.array_split,
dtypes=all_types_and_complex_and(torch.bool),
dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),