Add numpy like repeat as torch.repeat_interleave (#18395)
authorGao, Xiang <qasdfgtyuiop@gmail.com>
Sat, 6 Apr 2019 01:13:39 +0000 (18:13 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sat, 6 Apr 2019 01:16:25 +0000 (18:16 -0700)
Summary:
Fixes: https://github.com/pytorch/pytorch/issues/14093
cc: SsnL
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18395

Differential Revision: D14599509

Pulled By: umanwizard

fbshipit-source-id: 2391a1cc135fe5bab38475f1c8ed87c4a96222f3

12 files changed:
aten/src/ATen/core/Tensor.h
aten/src/ATen/core/TensorMethods.h
aten/src/ATen/core/Type.h
aten/src/ATen/native/Repeat.cpp [new file with mode: 0644]
aten/src/ATen/native/Repeat.h [new file with mode: 0644]
aten/src/ATen/native/cuda/Repeat.cu [new file with mode: 0644]
aten/src/ATen/native/native_functions.yaml
docs/source/tensors.rst
docs/source/torch.rst
test/test_torch.py
torch/_tensor_docs.py
torch/_torch_docs.py

index a171eab..e00e6ff 100644 (file)
@@ -466,6 +466,8 @@ class CAFFE2_API Tensor {
   Tensor pin_memory() const;
   Tensor pinverse(double rcond=1e-15) const;
   Tensor repeat(IntArrayRef repeats) const;
+  Tensor repeat_interleave(const Tensor & repeats, c10::optional<int64_t> dim=c10::nullopt) const;
+  Tensor repeat_interleave(int64_t repeats, c10::optional<int64_t> dim=c10::nullopt) const;
   Tensor reshape(IntArrayRef shape) const;
   Tensor reshape_as(const Tensor & other) const;
   Tensor round() const;
index efe387d..4096f15 100644 (file)
@@ -454,6 +454,12 @@ inline Tensor Tensor::pinverse(double rcond) const {
 inline Tensor Tensor::repeat(IntArrayRef repeats) const {
     return dispatch_type().repeat(*this, repeats);
 }
+inline Tensor Tensor::repeat_interleave(const Tensor & repeats, c10::optional<int64_t> dim) const {
+    return dispatch_type().repeat_interleave(*this, repeats, dim);
+}
+inline Tensor Tensor::repeat_interleave(int64_t repeats, c10::optional<int64_t> dim) const {
+    return dispatch_type().repeat_interleave(*this, repeats, dim);
+}
 inline Tensor Tensor::reshape(IntArrayRef shape) const {
     return dispatch_type().reshape(*this, shape);
 }
index 9a1d6fc..0eea7bb 100644 (file)
@@ -346,6 +346,9 @@ struct CAFFE2_API Type {
   virtual Tensor pin_memory(const Tensor & self) const = 0;
   virtual Tensor pinverse(const Tensor & self, double rcond) const = 0;
   virtual Tensor repeat(const Tensor & self, IntArrayRef repeats) const = 0;
+  virtual Tensor repeat_interleave(const Tensor & repeats) const = 0;
+  virtual Tensor repeat_interleave(const Tensor & self, const Tensor & repeats, c10::optional<int64_t> dim) const = 0;
+  virtual Tensor repeat_interleave(const Tensor & self, int64_t repeats, c10::optional<int64_t> dim) const = 0;
   virtual Tensor reshape(const Tensor & self, IntArrayRef shape) const = 0;
   virtual Tensor reshape_as(const Tensor & self, const Tensor & other) const = 0;
   virtual Tensor round(const Tensor & self) const = 0;
diff --git a/aten/src/ATen/native/Repeat.cpp b/aten/src/ATen/native/Repeat.cpp
new file mode 100644 (file)
index 0000000..0137c80
--- /dev/null
@@ -0,0 +1,47 @@
+#include <ATen/ATen.h>
+#include <ATen/native/Repeat.h>
+#include <ATen/Parallel.h>
+
+static void compute_cpu(int64_t *repeat_ptr, int64_t *cumsum_ptr, int64_t *result_ptr, int64_t size) {
+    at::parallel_for(0, size, 1, [&](int64_t i_begin, int64_t i_end) {
+        for(int64_t i = i_begin; i < i_end; i++) {
+            int64_t end = cumsum_ptr[i];
+            int64_t size = repeat_ptr[i];
+            int64_t start = end - size;
+            for(int64_t j = start; j < end; j++) {
+                result_ptr[j] = i;
+            }
+        }
+    });
+}
+
+namespace at { namespace native {
+
+Tensor repeat_interleave_cpu(const Tensor &repeat) {
+    return repeat_interleave_common<compute_cpu>(repeat);
+}
+
+Tensor repeat_interleave(const Tensor &self, const Tensor &repeats, c10::optional<int64_t> dim) {
+    Tensor input = self;
+    if(!dim) {
+        input = self.flatten();
+        dim = 0;
+    }
+
+    Tensor repeats_ = repeats;
+    if (repeats.dim() == 0 || (repeats.dim() == 1 && repeats.size(0) == 1)) {
+        repeats_ = repeats.reshape({1}).expand({input.size(dim.value())});
+    } else if (repeats.dim() == 1) {
+        AT_CHECK(repeats.size(0) == input.size(dim.value()), "repeats must have the same size as input along dim")
+    } else {
+        AT_ERROR("repeats must be 0-dim or 1-dim tensor");
+    }
+
+    return input.index_select(dim.value(), at::repeat_interleave(repeats_));
+}
+
+Tensor repeat_interleave(const Tensor &self, int64_t repeats, c10::optional<int64_t> dim) {
+    return at::native::repeat_interleave(self, at::tensor({repeats}, self.options().dtype(kLong)), dim);
+}
+
+}}
diff --git a/aten/src/ATen/native/Repeat.h b/aten/src/ATen/native/Repeat.h
new file mode 100644 (file)
index 0000000..a1ba075
--- /dev/null
@@ -0,0 +1,23 @@
+#pragma once
+
+#include <ATen/ATen.h>
+
+namespace at { namespace native {
+
+template <void compute(int64_t *, int64_t *, int64_t *, int64_t)>
+static inline Tensor repeat_interleave_common(const Tensor &repeats) {
+    AT_CHECK(repeats.dim() == 1, "repeat_interleave only accept 1D vector as repeat");
+    AT_CHECK(repeats.scalar_type() == at::kLong, "repeats has to be Long tensor");
+    AT_CHECK((repeats >= 0).all().item<uint8_t>(), "repeats can not be negative");
+    Tensor repeats_ = repeats.contiguous();
+    Tensor cumsum = repeats.cumsum(0);
+    int64_t total = cumsum[-1].item<int64_t>();
+    Tensor result = at::empty({total}, repeats.options());
+    int64_t *repeat_ptr = repeats_.data<int64_t>();
+    int64_t *cumsum_ptr = cumsum.data<int64_t>();
+    int64_t *result_ptr = result.data<int64_t>();
+    compute(repeat_ptr, cumsum_ptr, result_ptr, repeats.size(0));
+    return result;
+}
+
+}}
diff --git a/aten/src/ATen/native/cuda/Repeat.cu b/aten/src/ATen/native/cuda/Repeat.cu
new file mode 100644 (file)
index 0000000..f7b783e
--- /dev/null
@@ -0,0 +1,29 @@
+#include <ATen/ATen.h>
+#include <ATen/native/Repeat.h>
+
+__global__ static void compute_cuda_kernel(int64_t *repeat_ptr, int64_t *cumsum_ptr, int64_t *result_ptr, int64_t size) {
+    int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+    int64_t stride = blockDim.x * gridDim.x;
+    for (int64_t i = idx; i < size; i += stride) {
+        int64_t end = cumsum_ptr[i];
+        int64_t repeat = repeat_ptr[i];
+        int64_t start = end - repeat;
+        for(int64_t j = start; j < end; j++) {
+            result_ptr[j] = i;
+        }
+    }
+}
+
+static void compute_cuda(int64_t *repeat_ptr, int64_t *cumsum_ptr, int64_t *result_ptr, int64_t size) {
+    int64_t block = 512;
+    int64_t grid = std::min<int64_t>((size + block - 1) / block, 2048L);
+    compute_cuda_kernel<<<grid, block>>>(repeat_ptr, cumsum_ptr, result_ptr, size);
+}
+
+namespace at { namespace native {
+
+Tensor repeat_interleave_cuda(const Tensor &repeat) {
+    return repeat_interleave_common<compute_cuda>(repeat);
+}
+
+}}
index 8623999..be6825e 100644 (file)
   matches_jit_signature: True
   variants: method  # This is method-only to match the previous tensor API. In the future we could make this a function too.
 
+- func: repeat_interleave(Tensor repeats) -> Tensor
+  matches_jit_signature: True
+  variants: function
+  dispatch:
+    CPU: repeat_interleave_cpu
+    CUDA: repeat_interleave_cuda
+
+- func: repeat_interleave(Tensor self, Tensor repeats, int? dim=None) -> Tensor
+  matches_jit_signature: True
+  variants: function, method
+
+- func: repeat_interleave(Tensor self, int repeats, int? dim=None) -> Tensor
+  matches_jit_signature: True
+  variants: function, method
+
 - func: reshape(Tensor self, int[] shape) -> Tensor
   matches_jit_signature: True
   variants: function, method
index 68d6997..c1dcf8b 100644 (file)
@@ -367,6 +367,7 @@ view of a storage and defines numeric operations on it.
    .. automethod:: renorm
    .. automethod:: renorm_
    .. automethod:: repeat
+   .. automethod:: repeat_interleave
    .. automethod:: requires_grad
    .. automethod:: requires_grad_
    .. automethod:: reshape
index b34f188..85bf6d5 100644 (file)
@@ -280,6 +280,7 @@ Other Operations
 .. autofunction:: histc
 .. autofunction:: meshgrid
 .. autofunction:: renorm
+.. autofunction:: repeat_interleave
 .. autofunction:: roll
 .. autofunction:: tensordot
 .. autofunction:: trace
index ef91b84..1eeb5a8 100644 (file)
@@ -8274,6 +8274,47 @@ class _TestTorchMixin(object):
         self.assertEqual(result.size(), target, 'Error in repeat using result and LongStorage')
         self.assertEqual(result.mean(0).view(8, 4), tensor, 'Error in repeat (not equal)')
 
+    def test_repeat_interleave(self):
+        x = torch.tensor([0, 1, 2, 3])
+        expected = torch.tensor([1, 2, 2, 3, 3, 3])
+        self.assertEqual(torch.repeat_interleave(x), expected)
+
+        with self.assertRaises(RuntimeError):
+            torch.repeat_interleave(torch.arange(4).reshape(2, 2))
+
+        with self.assertRaises(RuntimeError):
+            torch.repeat_interleave(torch.arange(4.0))
+
+        with self.assertRaises(RuntimeError):
+            torch.repeat_interleave(torch.tensor([1, 2, -1, 3, 4]))
+
+        y = torch.tensor([[1, 2], [3, 4]])
+
+        y1_v1 = torch.repeat_interleave(y, 2)
+        y1_v2 = torch.repeat_interleave(y, torch.tensor(2))
+        y1_v3 = torch.repeat_interleave(y, torch.tensor([2]))
+        y1_expect = torch.tensor([1, 1, 2, 2, 3, 3, 4, 4])
+        self.assertEqual(y1_v1, y1_expect)
+        self.assertEqual(y1_v2, y1_expect)
+        self.assertEqual(y1_v3, y1_expect)
+
+        y2 = torch.repeat_interleave(y, 3, dim=1)
+        y2_expect = torch.tensor([[1, 1, 1, 2, 2, 2],
+                                  [3, 3, 3, 4, 4, 4]])
+        self.assertEqual(y2, y2_expect)
+
+        y3 = torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0)
+        y3_expect = torch.tensor([[1, 2],
+                                  [3, 4],
+                                  [3, 4]])
+        self.assertEqual(y3, y3_expect)
+
+        with self.assertRaises(RuntimeError):
+            torch.repeat_interleave(y, torch.tensor([1, 2, 3]), dim=0)
+
+        with self.assertRaises(RuntimeError):
+            torch.repeat_interleave(y, torch.arange(9).reshape(3, 3), dim=0)
+
     @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
     def test_repeat_tile(self):
 
index db58fbb..da61942 100644 (file)
@@ -1873,6 +1873,7 @@ Unlike :meth:`~Tensor.expand`, this function copies the tensor's data.
     `numpy.repeat <https://docs.scipy.org/doc/numpy/reference/generated/numpy.repeat.html>`_,
     but is more similar to
     `numpy.tile <https://docs.scipy.org/doc/numpy/reference/generated/numpy.tile.html>`_.
+    For the operator similar to `numpy.repeat`, see :func:`torch.repeat_interleave`.
 
 Args:
     sizes (torch.Size or int...): The number of times to repeat this tensor along each
@@ -1890,6 +1891,13 @@ Example::
     torch.Size([4, 2, 3])
 """)
 
+add_docstr_all('repeat_interleave',
+               r"""
+repeat_interleave(repeats, dim=None) -> Tensor
+
+See :func:`torch.repeat_interleave`.
+""")
+
 add_docstr_all('requires_grad_',
                r"""
 requires_grad_(requires_grad=True) -> Tensor
index 580481f..4ac08dc 100644 (file)
@@ -6396,3 +6396,49 @@ Example::
             [2, 3],
             [3, 3]])
 """)
+
+
+add_docstr(torch.repeat_interleave,
+           r"""
+.. function:: repeat_interleave(input, repeats, dim=None) -> Tensor
+
+Repeat elements of a tensor.
+
+.. warning::
+
+    This is different from :func:`torch.repeat` but similar to `numpy.repeat`.
+
+Args:
+    input (Tensor): The input tensor
+    repeats (Tensor or int): The number of repetitions for each element.
+        repeats is broadcasted to fit the shape of the given axis.
+    dim (int, optional): The dimension along which to repeat values.
+        By default, use the flattened input array, and return a flat output
+        array.
+
+Returns:
+    Tensor: Repeated tensor which has the same shape as input, except along the
+     given axis.
+
+Example::
+
+    >>> x = torch.tensor([1, 2, 3])
+    >>> x.repeat_interleave(2)
+    tensor([1, 1, 2, 2, 3, 3])
+    >>> y = torch.tensor([[1, 2], [3, 4]])
+    >>> torch.repeat_interleave(y, 2)
+    tensor([1, 1, 2, 2, 3, 3, 4, 4])
+    >>> torch.repeat_interleave(y, 3, dim=1)
+    tensor([[1, 1, 1, 2, 2, 2],
+            [3, 3, 3, 4, 4, 4]])
+    >>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0)
+    tensor([[1, 2],
+            [3, 4],
+            [3, 4]])
+
+.. function:: repeat_interleave(repeats) -> Tensor
+
+If the `repeats` is `tensor([n1, n2, n3, ...])`, then the output will be
+`tensor([0, 0, ..., 1, 1, ..., 2, 2, ..., ...])` where `0` appears `n1` times,
+`1` appears `n2` times, `2` appears `n3` times, etc.
+""")