Tensor pin_memory() const;
Tensor pinverse(double rcond=1e-15) const;
Tensor repeat(IntArrayRef repeats) const;
+ Tensor repeat_interleave(const Tensor & repeats, c10::optional<int64_t> dim=c10::nullopt) const;
+ Tensor repeat_interleave(int64_t repeats, c10::optional<int64_t> dim=c10::nullopt) const;
Tensor reshape(IntArrayRef shape) const;
Tensor reshape_as(const Tensor & other) const;
Tensor round() const;
inline Tensor Tensor::repeat(IntArrayRef repeats) const {
return dispatch_type().repeat(*this, repeats);
}
+inline Tensor Tensor::repeat_interleave(const Tensor & repeats, c10::optional<int64_t> dim) const {
+ return dispatch_type().repeat_interleave(*this, repeats, dim);
+}
+inline Tensor Tensor::repeat_interleave(int64_t repeats, c10::optional<int64_t> dim) const {
+ return dispatch_type().repeat_interleave(*this, repeats, dim);
+}
inline Tensor Tensor::reshape(IntArrayRef shape) const {
return dispatch_type().reshape(*this, shape);
}
virtual Tensor pin_memory(const Tensor & self) const = 0;
virtual Tensor pinverse(const Tensor & self, double rcond) const = 0;
virtual Tensor repeat(const Tensor & self, IntArrayRef repeats) const = 0;
+ virtual Tensor repeat_interleave(const Tensor & repeats) const = 0;
+ virtual Tensor repeat_interleave(const Tensor & self, const Tensor & repeats, c10::optional<int64_t> dim) const = 0;
+ virtual Tensor repeat_interleave(const Tensor & self, int64_t repeats, c10::optional<int64_t> dim) const = 0;
virtual Tensor reshape(const Tensor & self, IntArrayRef shape) const = 0;
virtual Tensor reshape_as(const Tensor & self, const Tensor & other) const = 0;
virtual Tensor round(const Tensor & self) const = 0;
--- /dev/null
+#include <ATen/ATen.h>
+#include <ATen/native/Repeat.h>
+#include <ATen/Parallel.h>
+
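+// For each index i, write i into result[cumsum[i] - repeats[i], cumsum[i]).
+// The outer loop over indices is split across threads by at::parallel_for.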
+static void compute_cpu(int64_t *repeat_ptr, int64_t *cumsum_ptr, int64_t *result_ptr, int64_t size) {
+ at::parallel_for(0, size, 1, [&](int64_t i_begin, int64_t i_end) {
+    for (int64_t i = i_begin; i < i_end; i++) {
+      int64_t end = cumsum_ptr[i];
+      int64_t repeat = repeat_ptr[i];
+      int64_t start = end - repeat;
+      for (int64_t j = start; j < end; j++) {
+        result_ptr[j] = i;
+      }
+    }
+ });
+}
+
+namespace at { namespace native {
+
+Tensor repeat_interleave_cpu(const Tensor &repeat) {
+ return repeat_interleave_common<compute_cpu>(repeat);
+}
+
+Tensor repeat_interleave(const Tensor &self, const Tensor &repeats, c10::optional<int64_t> dim) {
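+  // With dim=None, behave like numpy.repeat: flatten the input and
+  // repeat along dim 0.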
+ Tensor input = self;
+  if (!dim) {
+ input = self.flatten();
+ dim = 0;
+ }
+
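+  // A 0-dim or single-element repeats tensor is broadcast to every
+  // position along dim; otherwise repeats must be 1-D with one entry
+  // per slice of `input` along dim.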
+ Tensor repeats_ = repeats;
+ if (repeats.dim() == 0 || (repeats.dim() == 1 && repeats.size(0) == 1)) {
+ repeats_ = repeats.reshape({1}).expand({input.size(dim.value())});
+ } else if (repeats.dim() == 1) {
+    AT_CHECK(repeats.size(0) == input.size(dim.value()), "repeats must have the same size as input along dim");
+  } else {
+    AT_ERROR("repeats must be a 0-dim or 1-dim tensor");
+ }
+
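+  // Expand repeats_ into an index tensor with the single-argument overload
+  // (e.g. [2, 1] -> [0, 0, 1]), then gather the slices with index_select.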
+ return input.index_select(dim.value(), at::repeat_interleave(repeats_));
+}
+
+Tensor repeat_interleave(const Tensor &self, int64_t repeats, c10::optional<int64_t> dim) {
+ return at::native::repeat_interleave(self, at::tensor({repeats}, self.options().dtype(kLong)), dim);
+}
+
+}}
--- /dev/null
+#pragma once
+
+#include <ATen/ATen.h>
+
+namespace at { namespace native {
+
+template <void compute(int64_t *, int64_t *, int64_t *, int64_t)>
+static inline Tensor repeat_interleave_common(const Tensor &repeats) {
+  AT_CHECK(repeats.dim() == 1, "repeat_interleave only accepts a 1D tensor of repeats");
+  AT_CHECK(repeats.scalar_type() == at::kLong, "repeats must be a Long tensor");
+  AT_CHECK((repeats >= 0).all().item<uint8_t>(), "repeats cannot be negative");
+ Tensor repeats_ = repeats.contiguous();
+  Tensor cumsum = repeats.cumsum(0);
+  // cumsum[-1] is the total output length; guard the empty case, where
+  // indexing a zero-size tensor with -1 would throw.
+  int64_t total = repeats.size(0) == 0 ? 0 : cumsum[-1].item<int64_t>();
+ Tensor result = at::empty({total}, repeats.options());
+ int64_t *repeat_ptr = repeats_.data<int64_t>();
+ int64_t *cumsum_ptr = cumsum.data<int64_t>();
+ int64_t *result_ptr = result.data<int64_t>();
+ compute(repeat_ptr, cumsum_ptr, result_ptr, repeats.size(0));
+ return result;
+}
+
+}}
--- /dev/null
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/native/Repeat.h>
+
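+// Grid-stride loop: each thread expands the indices i = idx, idx + stride, ...
+// by writing i into result[cumsum[i] - repeats[i], cumsum[i]).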
+__global__ static void compute_cuda_kernel(int64_t *repeat_ptr, int64_t *cumsum_ptr, int64_t *result_ptr, int64_t size) {
+ int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+ int64_t stride = blockDim.x * gridDim.x;
+ for (int64_t i = idx; i < size; i += stride) {
+ int64_t end = cumsum_ptr[i];
+ int64_t repeat = repeat_ptr[i];
+ int64_t start = end - repeat;
+    for (int64_t j = start; j < end; j++) {
+ result_ptr[j] = i;
+ }
+ }
+}
+
+static void compute_cuda(int64_t *repeat_ptr, int64_t *cumsum_ptr, int64_t *result_ptr, int64_t size) {
+  int64_t block = 512;
+  // Cap the grid at 2048 blocks; the grid-stride loop in the kernel covers the rest.
+  int64_t grid = std::min<int64_t>((size + block - 1) / block, 2048L);
+  // Launch on the current CUDA stream so the kernel is ordered with the
+  // rest of the work ATen has queued on this device.
+  compute_cuda_kernel<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(repeat_ptr, cumsum_ptr, result_ptr, size);
+}
+
+namespace at { namespace native {
+
+Tensor repeat_interleave_cuda(const Tensor &repeat) {
+ return repeat_interleave_common<compute_cuda>(repeat);
+}
+
+}}
matches_jit_signature: True
variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too.
+- func: repeat_interleave(Tensor repeats) -> Tensor
+ matches_jit_signature: True
+ variants: function
+ dispatch:
+ CPU: repeat_interleave_cpu
+ CUDA: repeat_interleave_cuda
+
+- func: repeat_interleave(Tensor self, Tensor repeats, int? dim=None) -> Tensor
+ matches_jit_signature: True
+ variants: function, method
+
+- func: repeat_interleave(Tensor self, int repeats, int? dim=None) -> Tensor
+ matches_jit_signature: True
+ variants: function, method
+
- func: reshape(Tensor self, int[] shape) -> Tensor
matches_jit_signature: True
variants: function, method
.. automethod:: renorm
.. automethod:: renorm_
.. automethod:: repeat
+ .. automethod:: repeat_interleave
.. automethod:: requires_grad
.. automethod:: requires_grad_
.. automethod:: reshape
.. autofunction:: histc
.. autofunction:: meshgrid
.. autofunction:: renorm
+.. autofunction:: repeat_interleave
.. autofunction:: roll
.. autofunction:: tensordot
.. autofunction:: trace
self.assertEqual(result.size(), target, 'Error in repeat using result and LongStorage')
self.assertEqual(result.mean(0).view(8, 4), tensor, 'Error in repeat (not equal)')
+ def test_repeat_interleave(self):
+ x = torch.tensor([0, 1, 2, 3])
+ expected = torch.tensor([1, 2, 2, 3, 3, 3])
+ self.assertEqual(torch.repeat_interleave(x), expected)
+
+ with self.assertRaises(RuntimeError):
+ torch.repeat_interleave(torch.arange(4).reshape(2, 2))
+
+ with self.assertRaises(RuntimeError):
+ torch.repeat_interleave(torch.arange(4.0))
+
+ with self.assertRaises(RuntimeError):
+ torch.repeat_interleave(torch.tensor([1, 2, -1, 3, 4]))
+
+ y = torch.tensor([[1, 2], [3, 4]])
+
+ y1_v1 = torch.repeat_interleave(y, 2)
+ y1_v2 = torch.repeat_interleave(y, torch.tensor(2))
+ y1_v3 = torch.repeat_interleave(y, torch.tensor([2]))
+ y1_expect = torch.tensor([1, 1, 2, 2, 3, 3, 4, 4])
+ self.assertEqual(y1_v1, y1_expect)
+ self.assertEqual(y1_v2, y1_expect)
+ self.assertEqual(y1_v3, y1_expect)
+
+ y2 = torch.repeat_interleave(y, 3, dim=1)
+ y2_expect = torch.tensor([[1, 1, 1, 2, 2, 2],
+ [3, 3, 3, 4, 4, 4]])
+ self.assertEqual(y2, y2_expect)
+
+ y3 = torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0)
+ y3_expect = torch.tensor([[1, 2],
+ [3, 4],
+ [3, 4]])
+ self.assertEqual(y3, y3_expect)
+
+ with self.assertRaises(RuntimeError):
+ torch.repeat_interleave(y, torch.tensor([1, 2, 3]), dim=0)
+
+ with self.assertRaises(RuntimeError):
+ torch.repeat_interleave(y, torch.arange(9).reshape(3, 3), dim=0)
+
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_repeat_tile(self):
`numpy.repeat <https://docs.scipy.org/doc/numpy/reference/generated/numpy.repeat.html>`_,
but is more similar to
`numpy.tile <https://docs.scipy.org/doc/numpy/reference/generated/numpy.tile.html>`_.
+ For the operator similar to `numpy.repeat`, see :func:`torch.repeat_interleave`.
Args:
sizes (torch.Size or int...): The number of times to repeat this tensor along each
torch.Size([4, 2, 3])
""")
+add_docstr_all('repeat_interleave',
+ r"""
+repeat_interleave(repeats, dim=None) -> Tensor
+
+See :func:`torch.repeat_interleave`.
+""")
+
add_docstr_all('requires_grad_',
r"""
requires_grad_(requires_grad=True) -> Tensor
[2, 3],
[3, 3]])
""")
+
+
+add_docstr(torch.repeat_interleave,
+ r"""
+.. function:: repeat_interleave(input, repeats, dim=None) -> Tensor
+
+Repeat elements of a tensor.
+
+.. warning::
+
+    This is different from :meth:`torch.Tensor.repeat` but similar to ``numpy.repeat``.
+
+Args:
+ input (Tensor): The input tensor
+ repeats (Tensor or int): The number of repetitions for each element.
+ repeats is broadcasted to fit the shape of the given axis.
+ dim (int, optional): The dimension along which to repeat values.
+ By default, use the flattened input array, and return a flat output
+ array.
+
+Returns:
+ Tensor: Repeated tensor which has the same shape as input, except along the
+ given axis.
+
+Example::
+
+ >>> x = torch.tensor([1, 2, 3])
+ >>> x.repeat_interleave(2)
+ tensor([1, 1, 2, 2, 3, 3])
+ >>> y = torch.tensor([[1, 2], [3, 4]])
+ >>> torch.repeat_interleave(y, 2)
+ tensor([1, 1, 2, 2, 3, 3, 4, 4])
+ >>> torch.repeat_interleave(y, 3, dim=1)
+ tensor([[1, 1, 1, 2, 2, 2],
+ [3, 3, 3, 4, 4, 4]])
+ >>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0)
+ tensor([[1, 2],
+ [3, 4],
+ [3, 4]])
+
+.. function:: repeat_interleave(repeats) -> Tensor
+
+If `repeats` is `tensor([n1, n2, n3, ...])`, then the output will be
+`tensor([0, 0, ..., 1, 1, ..., 2, 2, ..., ...])` where `0` appears `n1` times,
+`1` appears `n2` times, `2` appears `n3` times, etc.
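+
+Example::
+
+    >>> torch.repeat_interleave(torch.tensor([1, 2, 3]))
+    tensor([0, 1, 1, 2, 2, 2])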
+""")