Ensure that indices are on the same device as self
authorSam Gross <sgross@fb.com>
Thu, 29 Nov 2018 01:51:01 +0000 (17:51 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 29 Nov 2018 01:54:32 +0000 (17:54 -0800)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/14504

Reviewed By: wat3rBro

Differential Revision: D13242200

Pulled By: colesbury

fbshipit-source-id: 82731cee808681ec612d406342070640eb26e519

aten/src/ATen/native/Indexing.cpp
aten/src/ATen/native/cuda/Loops.cuh
aten/src/ATen/test/basic.cpp

index 163db86..ff9a3c8 100644 (file)
@@ -383,6 +383,12 @@ static AdvancedIndex make_info(Tensor self, TensorList orig) {
   if (!hasContiguousSubspace(indices)) {
     std::tie(self, indices) = transposeToFront(self, indices);
   }
+  // Ensure indices are on the same device as self
+  for (size_t i = 0; i < indices.size(); i++) {
+    if (indices[i].defined() && indices[i].device() != self.device()) {
+      indices[i] = indices[i].to(self.device());
+    }
+  }
   return AdvancedIndex(self, indices);
 }
 
index cd16364..461ed9a 100644 (file)
@@ -59,6 +59,7 @@ static void launch_kernel(int64_t N, const func_t& f) {
   dim3 grid((N + block.x * vt - 1) / (block.x * vt));
   auto stream = at::cuda::getCurrentCUDAStream();
   elementwise_kernel<nt, vt, func_t><<<grid, block, 0, stream>>>(N, f);
+  AT_CUDA_CHECK(cudaGetLastError());
 }
 
 template<typename func_t>
index f114e72..b3f5aab 100644 (file)
@@ -265,6 +265,12 @@ void TestIndexingByZerodimTensor() {
   // Throw StartsWith("Can only index with tensors that are scalars (zero-dim)")
   ASSERT_ANY_THROW(tensor[ones({2, 3, 4}, kInt)].equal(one));
 }
+void TestIndexingMixedDevice(Type& type) {
+  Tensor tensor = randn({20, 20}, type);
+  Tensor index = arange(10, kLong).cpu();
+  Tensor result = tensor.index({index});
+  ASSERT_TRUE(result[0].equal(tensor[0]));
+}
 void TestDispatch() {
   Tensor tensor = randn({20, 20});
   Tensor other = randn({20, 20});
@@ -301,6 +307,7 @@ void test(Type& type) {
   TestToString();
   TestIndexingByScalar();
   TestIndexingByZerodimTensor();
+  TestIndexingMixedDevice(type);
   TestDispatch();
   TestCore();
 }