From 533668d7e482ec5a88aaa274027de9cbefad5632 Mon Sep 17 00:00:00 2001 From: Sam Gross Date: Wed, 28 Nov 2018 17:51:01 -0800 Subject: [PATCH] Ensure that indices are on the same device as self Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/14504 Reviewed By: wat3rBro Differential Revision: D13242200 Pulled By: colesbury fbshipit-source-id: 82731cee808681ec612d406342070640eb26e519 --- aten/src/ATen/native/Indexing.cpp | 6 ++++++ aten/src/ATen/native/cuda/Loops.cuh | 1 + aten/src/ATen/test/basic.cpp | 7 +++++++ 3 files changed, 14 insertions(+) diff --git a/aten/src/ATen/native/Indexing.cpp b/aten/src/ATen/native/Indexing.cpp index 163db86..ff9a3c8 100644 --- a/aten/src/ATen/native/Indexing.cpp +++ b/aten/src/ATen/native/Indexing.cpp @@ -383,6 +383,12 @@ static AdvancedIndex make_info(Tensor self, TensorList orig) { if (!hasContiguousSubspace(indices)) { std::tie(self, indices) = transposeToFront(self, indices); } + // Ensure indices are on the same device as self + for (size_t i = 0; i < indices.size(); i++) { + if (indices[i].defined() && indices[i].device() != self.device()) { + indices[i] = indices[i].to(self.device()); + } + } return AdvancedIndex(self, indices); } diff --git a/aten/src/ATen/native/cuda/Loops.cuh b/aten/src/ATen/native/cuda/Loops.cuh index cd16364..461ed9a 100644 --- a/aten/src/ATen/native/cuda/Loops.cuh +++ b/aten/src/ATen/native/cuda/Loops.cuh @@ -59,6 +59,7 @@ static void launch_kernel(int64_t N, const func_t& f) { dim3 grid((N + block.x * vt - 1) / (block.x * vt)); auto stream = at::cuda::getCurrentCUDAStream(); elementwise_kernel<<>>(N, f); + AT_CUDA_CHECK(cudaGetLastError()); } template diff --git a/aten/src/ATen/test/basic.cpp b/aten/src/ATen/test/basic.cpp index f114e72..b3f5aab 100644 --- a/aten/src/ATen/test/basic.cpp +++ b/aten/src/ATen/test/basic.cpp @@ -265,6 +265,12 @@ void TestIndexingByZerodimTensor() { // Throw StartsWith("Can only index with tensors that are scalars (zero-dim)") ASSERT_ANY_THROW(tensor[ones({2, 3, 4}, kInt)].equal(one)); } +void TestIndexingMixedDevice(Type& type) { + Tensor tensor = randn({20, 20}, type); + Tensor index = arange(10, kLong).cpu(); + Tensor result = tensor.index({index}); + ASSERT_TRUE(result[0].equal(tensor[0])); +} void TestDispatch() { Tensor tensor = randn({20, 20}); Tensor other = randn({20, 20}); @@ -301,6 +307,7 @@ void test(Type& type) { TestToString(); TestIndexingByScalar(); TestIndexingByZerodimTensor(); + TestIndexingMixedDevice(type); TestDispatch(); TestCore(); } -- 2.7.4