static void scatter_gather_dtype_check(
  const std::string& method_name,
  const Tensor& self,
  const Tensor& index,
  const c10::optional<Tensor>& src_opt = c10::nullopt
) {
- TORCH_CHECK(
- index.scalar_type() == at::ScalarType::Long,
- method_name, "(): Expected dtype int64 for index"
- );
+ if (index.numel() != 0) {
+ TORCH_CHECK(
+ index.scalar_type() == at::ScalarType::Long,
+ method_name, "(): Expected dtype int64 for index"
+ );
+ }
  if (src_opt.has_value()) {
    auto src = src_opt.value();
    TORCH_CHECK(
      self.scalar_type() == src.scalar_type(),
      method_name, "(): Expected self.dtype to be equal to src.dtype"
    );
  }
}

TORCH_META_FUNC(gather)
(const Tensor& self, int64_t dim, const Tensor& index, bool sparse_grad) {
  const Tensor& result = maybe_get_output();
  int64_t wrapped_dim = at::maybe_wrap_dim(dim, self.dim());
  at::assert_no_partial_overlap(result, index);
- TORCH_CHECK(
-   index.scalar_type() == at::ScalarType::Long,
-   "gather", "(): Expected dtype int64 for index"
- );
+ auto is_index_empty = index.numel() == 0;
+ if (!is_index_empty) {
+   TORCH_CHECK(
+     index.scalar_type() == at::ScalarType::Long,
+     "gather", "(): Expected dtype int64 for index"
+   );
+ }
+ if (is_index_empty) return;
  at::native::gather_shape_check(self, wrapped_dim, index);
}
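
Together, the two hunks above let `gather` (and the scatter-family ops that share `scatter_gather_dtype_check`) accept an empty index tensor of any dtype: both the int64 dtype check and the shape check are skipped, and the call simply produces an empty result. A minimal sketch of the resulting behavior, assuming a build that includes this change:

```python
import torch

x = torch.arange(6.0).reshape(2, 3)

# A non-empty index must still be int64; this still raises a RuntimeError:
#   torch.gather(x, 0, torch.tensor([[0]], dtype=torch.uint8))

# An empty index now bypasses both checks, whatever its dtype, and the
# result is an empty tensor with self's dtype and index's shape.
out = torch.gather(x, 0, torch.tensor([], dtype=torch.uint8))
print(out.shape, out.dtype)  # torch.Size([0]) torch.float32
```

The matching test coverage is added to the OpInfo sample inputs below.
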
SampleInput(
make_tensor((), device, dtype, low=None, high=None, requires_grad=requires_grad),
args=(0, torch.tensor([0], dtype=torch.int64, device=device))),
+ # Empty index tensor case, see: https://github.com/pytorch/pytorch/pull/65006
+ SampleInput(
+     make_tensor((S,), device, dtype, low=None, high=None, requires_grad=requires_grad),
+     args=(0, torch.tensor([], dtype=torch.uint8, device=device))),
SampleInput(
make_tensor((), device, dtype, low=None, high=None, requires_grad=requires_grad),
args=(0, torch.tensor(0, dtype=torch.int64, device=device))),
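
The added sample pushes the empty-index case through the whole OpInfo pipeline, including the autograd checks. Roughly what those tests end up exercising, as a sketch (`S` stands in for the suite's size constant, pinned here only for illustration):

```python
import torch

S = 5  # stand-in for the test suite's size constant
inp = torch.rand(S, requires_grad=True)
idx = torch.tensor([], dtype=torch.uint8)

out = torch.gather(inp, 0, idx)  # empty result, no dtype error
out.sum().backward()             # backward must tolerate the empty index too
print(inp.grad)                  # zeros of shape (S,): nothing was gathered
```
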