From 8c9caf185be66b996df5caf4b37a0e902ae3ca41 Mon Sep 17 00:00:00 2001
From: "Gao, Xiang"
Date: Fri, 5 Apr 2019 18:13:39 -0700
Subject: [PATCH] Add numpy like repeat as torch.repeat_interleave (#18395)

Summary:
Fixes: https://github.com/pytorch/pytorch/issues/14093

cc: SsnL

Pull Request resolved: https://github.com/pytorch/pytorch/pull/18395

Differential Revision: D14599509

Pulled By: umanwizard

fbshipit-source-id: 2391a1cc135fe5bab38475f1c8ed87c4a96222f3
---
 aten/src/ATen/core/Tensor.h                |  2 ++
 aten/src/ATen/core/TensorMethods.h         |  6 ++++
 aten/src/ATen/core/Type.h                  |  3 ++
 aten/src/ATen/native/Repeat.cpp            | 47 ++++++++++++++++++++++++++++++
 aten/src/ATen/native/Repeat.h              | 23 +++++++++++++++
 aten/src/ATen/native/cuda/Repeat.cu        | 29 ++++++++++++++++++
 aten/src/ATen/native/native_functions.yaml | 15 ++++++++++
 docs/source/tensors.rst                    |  1 +
 docs/source/torch.rst                      |  1 +
 test/test_torch.py                         | 41 ++++++++++++++++++++++++
 torch/_tensor_docs.py                      |  8 +++++
 torch/_torch_docs.py                       | 46 +++++++++++++++++++++++++++++
 12 files changed, 222 insertions(+)
 create mode 100644 aten/src/ATen/native/Repeat.cpp
 create mode 100644 aten/src/ATen/native/Repeat.h
 create mode 100644 aten/src/ATen/native/cuda/Repeat.cu

diff --git a/aten/src/ATen/core/Tensor.h b/aten/src/ATen/core/Tensor.h
index a171eab..e00e6ff 100644
--- a/aten/src/ATen/core/Tensor.h
+++ b/aten/src/ATen/core/Tensor.h
@@ -466,6 +466,8 @@ class CAFFE2_API Tensor {
   Tensor pin_memory() const;
   Tensor pinverse(double rcond=1e-15) const;
   Tensor repeat(IntArrayRef repeats) const;
+  Tensor repeat_interleave(const Tensor & repeats, c10::optional<int64_t> dim=c10::nullopt) const;
+  Tensor repeat_interleave(int64_t repeats, c10::optional<int64_t> dim=c10::nullopt) const;
   Tensor reshape(IntArrayRef shape) const;
   Tensor reshape_as(const Tensor & other) const;
   Tensor round() const;
diff --git a/aten/src/ATen/core/TensorMethods.h b/aten/src/ATen/core/TensorMethods.h
index efe387d..4096f15 100644
--- a/aten/src/ATen/core/TensorMethods.h
+++ b/aten/src/ATen/core/TensorMethods.h
@@ -454,6 +454,12 @@ inline Tensor Tensor::pinverse(double rcond) const {
 inline Tensor Tensor::repeat(IntArrayRef repeats) const {
     return dispatch_type().repeat(*this, repeats);
 }
+inline Tensor Tensor::repeat_interleave(const Tensor & repeats, c10::optional<int64_t> dim) const {
+    return dispatch_type().repeat_interleave(*this, repeats, dim);
+}
+inline Tensor Tensor::repeat_interleave(int64_t repeats, c10::optional<int64_t> dim) const {
+    return dispatch_type().repeat_interleave(*this, repeats, dim);
+}
 inline Tensor Tensor::reshape(IntArrayRef shape) const {
     return dispatch_type().reshape(*this, shape);
 }
diff --git a/aten/src/ATen/core/Type.h b/aten/src/ATen/core/Type.h
index 9a1d6fc..0eea7bb 100644
--- a/aten/src/ATen/core/Type.h
+++ b/aten/src/ATen/core/Type.h
@@ -346,6 +346,9 @@ struct CAFFE2_API Type {
   virtual Tensor pin_memory(const Tensor & self) const = 0;
   virtual Tensor pinverse(const Tensor & self, double rcond) const = 0;
   virtual Tensor repeat(const Tensor & self, IntArrayRef repeats) const = 0;
+  virtual Tensor repeat_interleave(const Tensor & repeats) const = 0;
+  virtual Tensor repeat_interleave(const Tensor & self, const Tensor & repeats, c10::optional<int64_t> dim) const = 0;
+  virtual Tensor repeat_interleave(const Tensor & self, int64_t repeats, c10::optional<int64_t> dim) const = 0;
   virtual Tensor reshape(const Tensor & self, IntArrayRef shape) const = 0;
   virtual Tensor reshape_as(const Tensor & self, const Tensor & other) const = 0;
   virtual Tensor round(const Tensor & self) const = 0;
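Taken together, the declarations above expose three overloads. Their user-facing shape, sketched for orientation (illustrative calls, not code from this patch)::

    import torch

    # 1-D Long tensor of counts only; dispatched to the new CPU/CUDA kernels below
    torch.repeat_interleave(torch.tensor([1, 2, 3]))   # -> tensor([0, 1, 1, 2, 2, 2])

    # Tensor repeats with an optional dim; also exposed as a Tensor method
    torch.repeat_interleave(torch.tensor([1, 2]), torch.tensor([2, 1]))   # -> tensor([1, 1, 2])

    # int repeats; wrapped into a one-element Long tensor internally
    torch.tensor([1, 2]).repeat_interleave(2)          # -> tensor([1, 1, 2, 2])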
diff --git a/aten/src/ATen/native/Repeat.cpp b/aten/src/ATen/native/Repeat.cpp
new file mode 100644
index 0000000..0137c80
--- /dev/null
+++ b/aten/src/ATen/native/Repeat.cpp
@@ -0,0 +1,47 @@
+#include <ATen/ATen.h>
+#include <ATen/Parallel.h>
+#include <ATen/native/Repeat.h>
+
+static void compute_cpu(int64_t *repeat_ptr, int64_t *cumsum_ptr, int64_t *result_ptr, int64_t size) {
+    at::parallel_for(0, size, 1, [&](int64_t i_begin, int64_t i_end) {
+        for(int64_t i = i_begin; i < i_end; i++) {
+            int64_t end = cumsum_ptr[i];
+            int64_t size = repeat_ptr[i];
+            int64_t start = end - size;
+            for(int64_t j = start; j < end; j++) {
+                result_ptr[j] = i;
+            }
+        }
+    });
+}
+
+namespace at { namespace native {
+
+Tensor repeat_interleave_cpu(const Tensor &repeat) {
+    return repeat_interleave_common<compute_cpu>(repeat);
+}
+
+Tensor repeat_interleave(const Tensor &self, const Tensor &repeats, c10::optional<int64_t> dim) {
+    Tensor input = self;
+    if(!dim) {
+        input = self.flatten();
+        dim = 0;
+    }
+
+    Tensor repeats_ = repeats;
+    if (repeats.dim() == 0 || (repeats.dim() == 1 && repeats.size(0) == 1)) {
+        repeats_ = repeats.reshape({1}).expand({input.size(dim.value())});
+    } else if (repeats.dim() == 1) {
+        AT_CHECK(repeats.size(0) == input.size(dim.value()), "repeats must have the same size as input along dim")
+    } else {
+        AT_ERROR("repeats must be 0-dim or 1-dim tensor");
+    }
+
+    return input.index_select(dim.value(), at::repeat_interleave(repeats_));
+}
+
+Tensor repeat_interleave(const Tensor &self, int64_t repeats, c10::optional<int64_t> dim) {
+    return at::native::repeat_interleave(self, at::tensor({repeats}, self.options().dtype(kLong)), dim);
+}
+
+}}
diff --git a/aten/src/ATen/native/Repeat.h b/aten/src/ATen/native/Repeat.h
new file mode 100644
index 0000000..a1ba075
--- /dev/null
+++ b/aten/src/ATen/native/Repeat.h
@@ -0,0 +1,23 @@
+#pragma once
+
+#include <ATen/ATen.h>
+
+namespace at { namespace native {
+
+template <void compute(int64_t *, int64_t *, int64_t *, int64_t)>
+static inline Tensor repeat_interleave_common(const Tensor &repeats) {
+    AT_CHECK(repeats.dim() == 1, "repeat_interleave only accept 1D vector as repeat");
+    AT_CHECK(repeats.scalar_type() == at::kLong, "repeats has to be Long tensor");
+    AT_CHECK((repeats >= 0).all().item<uint8_t>(), "repeats can not be negative");
+    Tensor repeats_ = repeats.contiguous();
+    Tensor cumsum = repeats.cumsum(0);
+    int64_t total = cumsum[-1].item<int64_t>();
+    Tensor result = at::empty({total}, repeats.options());
+    int64_t *repeat_ptr = repeats_.data<int64_t>();
+    int64_t *cumsum_ptr = cumsum.data<int64_t>();
+    int64_t *result_ptr = result.data<int64_t>();
+    compute(repeat_ptr, cumsum_ptr, result_ptr, repeats.size(0));
+    return result;
+}
+
+}}
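The helper in Repeat.h is the core of the operator: it turns a 1-D vector of counts into the flat index tensor that the `index_select` in Repeat.cpp consumes. A pure-Python model of that computation, for review purposes only (`repeat_index` is an illustrative name, not part of the patch)::

    import torch

    def repeat_index(repeats):
        # cumsum[i] is the end offset of the run of i's; the run's length is repeats[i]
        cumsum = repeats.cumsum(0)
        result = torch.empty(int(cumsum[-1]), dtype=torch.long)
        for i in range(repeats.size(0)):
            end = int(cumsum[i])
            result[end - int(repeats[i]):end] = i
        return result

    # repeat_index(torch.tensor([1, 2, 3]))  ->  tensor([0, 1, 1, 2, 2, 2])

Both backends fill the result the same way; they differ only in how the outer loop over `i` is parallelized (`at::parallel_for` on CPU, a grid-stride thread loop on CUDA, below).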
diff --git a/aten/src/ATen/native/cuda/Repeat.cu b/aten/src/ATen/native/cuda/Repeat.cu
new file mode 100644
index 0000000..f7b783e
--- /dev/null
+++ b/aten/src/ATen/native/cuda/Repeat.cu
@@ -0,0 +1,29 @@
+#include <ATen/ATen.h>
+#include <ATen/native/Repeat.h>
+
+__global__ static void compute_cuda_kernel(int64_t *repeat_ptr, int64_t *cumsum_ptr, int64_t *result_ptr, int64_t size) {
+    int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+    int64_t stride = blockDim.x * gridDim.x;
+    for (int64_t i = idx; i < size; i += stride) {
+        int64_t end = cumsum_ptr[i];
+        int64_t repeat = repeat_ptr[i];
+        int64_t start = end - repeat;
+        for(int64_t j = start; j < end; j++) {
+            result_ptr[j] = i;
+        }
+    }
+}
+
+static void compute_cuda(int64_t *repeat_ptr, int64_t *cumsum_ptr, int64_t *result_ptr, int64_t size) {
+    int64_t block = 512;
+    int64_t grid = std::min((size + block - 1) / block, 2048L);
+    compute_cuda_kernel<<<grid, block>>>(repeat_ptr, cumsum_ptr, result_ptr, size);
+}
+
+namespace at { namespace native {
+
+Tensor repeat_interleave_cuda(const Tensor &repeat) {
+    return repeat_interleave_common<compute_cuda>(repeat);
+}
+
+}}
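`compute_cuda` caps the launch at 2048 blocks of 512 threads and relies on the kernel's grid-stride loop to cover larger inputs, so each thread may own several runs. A quick check of that arithmetic (doctest-style, illustrative only)::

    >>> size, block = 10**7, 512
    >>> min((size + block - 1) // block, 2048)   # grid size, capped at 2048 blocks
    2048
    >>> -(-size // (2048 * 512))                 # runs handled per thread (ceil division)
    10

Because the runs `[start, end)` partition the output, threads never write overlapping ranges and the kernel needs no atomics.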
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 8623999..be6825e 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1822,6 +1822,21 @@
   matches_jit_signature: True
   variants: method  # This is method-only to match the previous tensor API. In the future we could make this a function too.
 
+- func: repeat_interleave(Tensor repeats) -> Tensor
+  matches_jit_signature: True
+  variants: function
+  dispatch:
+    CPU: repeat_interleave_cpu
+    CUDA: repeat_interleave_cuda
+
+- func: repeat_interleave(Tensor self, Tensor repeats, int? dim=None) -> Tensor
+  matches_jit_signature: True
+  variants: function, method
+
+- func: repeat_interleave(Tensor self, int repeats, int? dim=None) -> Tensor
+  matches_jit_signature: True
+  variants: function, method
+
 - func: reshape(Tensor self, int[] shape) -> Tensor
   matches_jit_signature: True
   variants: function, method
diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst
index 68d6997..c1dcf8b 100644
--- a/docs/source/tensors.rst
+++ b/docs/source/tensors.rst
@@ -367,6 +367,7 @@ view of a storage and defines numeric operations on it.
    .. automethod:: renorm
    .. automethod:: renorm_
    .. automethod:: repeat
+   .. automethod:: repeat_interleave
    .. automethod:: requires_grad
    .. automethod:: requires_grad_
    .. automethod:: reshape
diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index b34f188..85bf6d5 100644
--- a/docs/source/torch.rst
+++ b/docs/source/torch.rst
@@ -280,6 +280,7 @@ Other Operations
 .. autofunction:: histc
 .. autofunction:: meshgrid
 .. autofunction:: renorm
+.. autofunction:: repeat_interleave
 .. autofunction:: roll
 .. autofunction:: tensordot
 .. autofunction:: trace
diff --git a/test/test_torch.py b/test/test_torch.py
index ef91b84..1eeb5a8 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -8274,6 +8274,47 @@ class _TestTorchMixin(object):
         self.assertEqual(result.size(), target, 'Error in repeat using result and LongStorage')
         self.assertEqual(result.mean(0).view(8, 4), tensor, 'Error in repeat (not equal)')
 
+    def test_repeat_interleave(self):
+        x = torch.tensor([0, 1, 2, 3])
+        expected = torch.tensor([1, 2, 2, 3, 3, 3])
+        self.assertEqual(torch.repeat_interleave(x), expected)
+
+        with self.assertRaises(RuntimeError):
+            torch.repeat_interleave(torch.arange(4).reshape(2, 2))
+
+        with self.assertRaises(RuntimeError):
+            torch.repeat_interleave(torch.arange(4.0))
+
+        with self.assertRaises(RuntimeError):
+            torch.repeat_interleave(torch.tensor([1, 2, -1, 3, 4]))
+
+        y = torch.tensor([[1, 2], [3, 4]])
+
+        y1_v1 = torch.repeat_interleave(y, 2)
+        y1_v2 = torch.repeat_interleave(y, torch.tensor(2))
+        y1_v3 = torch.repeat_interleave(y, torch.tensor([2]))
+        y1_expect = torch.tensor([1, 1, 2, 2, 3, 3, 4, 4])
+        self.assertEqual(y1_v1, y1_expect)
+        self.assertEqual(y1_v2, y1_expect)
+        self.assertEqual(y1_v3, y1_expect)
+
+        y2 = torch.repeat_interleave(y, 3, dim=1)
+        y2_expect = torch.tensor([[1, 1, 1, 2, 2, 2],
+                                  [3, 3, 3, 4, 4, 4]])
+        self.assertEqual(y2, y2_expect)
+
+        y3 = torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0)
+        y3_expect = torch.tensor([[1, 2],
+                                  [3, 4],
+                                  [3, 4]])
+        self.assertEqual(y3, y3_expect)
+
+        with self.assertRaises(RuntimeError):
+            torch.repeat_interleave(y, torch.tensor([1, 2, 3]), dim=0)
+
+        with self.assertRaises(RuntimeError):
+            torch.repeat_interleave(y, torch.arange(9).reshape(3, 3), dim=0)
+
     @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
     def test_repeat_tile(self):
diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py
index db58fbb..da61942 100644
--- a/torch/_tensor_docs.py
+++ b/torch/_tensor_docs.py
@@ -1873,6 +1873,7 @@ Unlike :meth:`~Tensor.expand`, this function copies the tensor's data.
 `numpy.repeat <https://docs.scipy.org/doc/numpy/reference/generated/numpy.repeat.html>`_,
 but is more similar to
 `numpy.tile <https://docs.scipy.org/doc/numpy/reference/generated/numpy.tile.html>`_.
+For the operator similar to `numpy.repeat`, see :func:`torch.repeat_interleave`.
 
 Args:
     sizes (torch.Size or int...): The number of times to repeat this tensor along each
@@ -1890,6 +1891,13 @@ Example::
     torch.Size([4, 2, 3])
 """)
 
+add_docstr_all('repeat_interleave',
+               r"""
+repeat_interleave(repeats, dim=None) -> Tensor
+
+See :func:`torch.repeat_interleave`.
+""")
+
 add_docstr_all('requires_grad_',
                r"""
 requires_grad_(requires_grad=True) -> Tensor
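The note added to `repeat`'s docstring points to `repeat_interleave` as the `numpy.repeat` analogue; a side-by-side check makes the correspondence concrete (doctest-style sketch, not part of the patch)::

    >>> import numpy as np
    >>> import torch
    >>> np.repeat(np.array([[1, 2], [3, 4]]), 2)   # NumPy flattens when axis is None
    array([1, 1, 2, 2, 3, 3, 4, 4])
    >>> torch.repeat_interleave(torch.tensor([[1, 2], [3, 4]]), 2)
    tensor([1, 1, 2, 2, 3, 3, 4, 4])
    >>> np.repeat(np.array([[1, 2], [3, 4]]), [1, 2], axis=0)
    array([[1, 2],
           [3, 4],
           [3, 4]])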
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 580481f..4ac08dc 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -6396,3 +6396,49 @@ Example::
             [2, 3],
             [3, 3]])
 """)
+
+
+add_docstr(torch.repeat_interleave,
+           r"""
+.. function:: repeat_interleave(input, repeats, dim=None) -> Tensor
+
+Repeat elements of a tensor.
+
+.. warning::
+
+    This is different from :meth:`torch.Tensor.repeat` but similar to `numpy.repeat`.
+
+Args:
+    input (Tensor): The input tensor.
+    repeats (Tensor or int): The number of repetitions for each element.
+        repeats is broadcasted to fit the shape of the given axis.
+    dim (int, optional): The dimension along which to repeat values.
+        By default, use the flattened input array, and return a flat output
+        array.
+
+Returns:
+    Tensor: Repeated tensor which has the same shape as input, except along the
+        given axis.
+
+Example::
+
+    >>> x = torch.tensor([1, 2, 3])
+    >>> x.repeat_interleave(2)
+    tensor([1, 1, 2, 2, 3, 3])
+    >>> y = torch.tensor([[1, 2], [3, 4]])
+    >>> torch.repeat_interleave(y, 2)
+    tensor([1, 1, 2, 2, 3, 3, 4, 4])
+    >>> torch.repeat_interleave(y, 3, dim=1)
+    tensor([[1, 1, 1, 2, 2, 2],
+            [3, 3, 3, 4, 4, 4]])
+    >>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0)
+    tensor([[1, 2],
+            [3, 4],
+            [3, 4]])
+
+.. function:: repeat_interleave(repeats) -> Tensor
+
+If `repeats` is `tensor([n1, n2, n3, ...])`, then the output will be
+`tensor([0, 0, ..., 1, 1, ..., 2, 2, ..., ...])` where `0` appears `n1` times,
+`1` appears `n2` times, `2` appears `n3` times, etc.
+""")
-- 
2.7.4