[Fix] Raise error when empty index tensor is passed (gather) (#65006)
author: Kushashwa Ravi Shrimali <kushashwaravishrimali@gmail.com>
Thu, 16 Sep 2021 17:12:50 +0000 (10:12 -0700)
committer: Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 16 Sep 2021 17:14:26 +0000 (10:14 -0700)
Summary:
See https://github.com/pytorch/pytorch/pull/63312#issuecomment-919330081 for context.

cc: ezyang ysiraichi

Pull Request resolved: https://github.com/pytorch/pytorch/pull/65006

Reviewed By: mruberry

Differential Revision: D30937730

Pulled By: ezyang

fbshipit-source-id: a8f77b1f40d07e7e3bef6caaafa119685f297638

aten/src/ATen/native/ScatterGatherChecks.h
aten/src/ATen/native/TensorAdvancedIndexing.cpp
torch/testing/_internal/common_methods_invocations.py

index 0fc38d5..4518952 100644 (file)
@@ -16,10 +16,12 @@ static void scatter_gather_dtype_check(
   const Tensor& index,
   const c10::optional<Tensor>& src_opt = c10::nullopt
 ) {
-  TORCH_CHECK(
-    index.scalar_type() == at::ScalarType::Long,
-    method_name, "(): Expected dtype int64 for index"
-  );
+  if (index.numel() != 0) {
+    TORCH_CHECK(
+      index.scalar_type() == at::ScalarType::Long,
+      method_name, "(): Expected dtype int64 for index"
+    );
+  }
 
   if (src_opt.has_value()) {
     auto src = src_opt.value();
index 3fb38cc..aecc98c 100644 (file)
@@ -101,10 +101,14 @@ TORCH_META_FUNC(gather)
     at::assert_no_partial_overlap(result, index);
   }
 
-  TORCH_CHECK(
-    index.scalar_type() == at::ScalarType::Long,
-    "gather", "(): Expected dtype int64 for index"
-  );
+  auto is_index_empty = index.numel() == 0;
+  if (!is_index_empty) {
+    TORCH_CHECK(
+      index.scalar_type() == at::ScalarType::Long,
+      "gather", "(): Expected dtype int64 for index"
+    );
+  }
+  if (is_index_empty) return;
   at::native::gather_shape_check(self, wrapped_dim, index);
 }
 
index b8501f9..ab9852f 100644 (file)
@@ -2235,9 +2235,10 @@ def sample_inputs_gather(op_info, device, dtype, requires_grad, **kwargs):
         SampleInput(
             make_tensor((), device, dtype, low=None, high=None, requires_grad=requires_grad),
             args=(0, torch.tensor([0], dtype=torch.int64, device=device))),
+        # Empty index tensor case, see: https://github.com/pytorch/pytorch/pull/65006
         SampleInput(
             make_tensor((S,), device, dtype, low=None, high=None, requires_grad=requires_grad),
-            args=(0, torch.tensor(0, dtype=torch.int64, device=device))),
+            args=(0, torch.tensor([], dtype=torch.uint8, device=device))),
         SampleInput(
             make_tensor((), device, dtype, low=None, high=None, requires_grad=requires_grad),
             args=(0, torch.tensor(0, dtype=torch.int64, device=device))),