Add `itertools.{prod, combinations, combinations_with_replacement}` like op to pytorc...
authorXiang Gao <qasdfgtyuiop@gmail.com>
Tue, 15 Jan 2019 16:24:27 +0000 (08:24 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 15 Jan 2019 16:31:22 +0000 (08:31 -0800)
Summary:
closes https://github.com/pytorch/pytorch/issues/7580
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9393

Differential Revision: D13659628

Pulled By: zou3519

fbshipit-source-id: 3a233befa785709395a793ba8833413be394a6fd

aten/src/ATen/core/aten_interned_strings.h
aten/src/ATen/native/Itertools.cpp [new file with mode: 0644]
aten/src/ATen/native/native_functions.yaml
test/test_torch.py
torch/_torch_docs.py
torch/functional.py

index 72f07ce..61399e7 100644 (file)
@@ -236,6 +236,7 @@ _(aten, broadcast_tensors) \
 _(aten, btrifact) \
 _(aten, btrifact_with_info) \
 _(aten, btrisolve) \
+_(aten, cartesian_prod) \
 _(aten, cat) \
 _(aten, cauchy) \
 _(aten, ceil) \
@@ -249,6 +250,7 @@ _(aten, clamp_max) \
 _(aten, clamp_min) \
 _(aten, clone) \
 _(aten, coalesce) \
+_(aten, combinations) \
 _(aten, constant_pad_nd) \
 _(aten, contiguous) \
 _(aten, conv1d) \
diff --git a/aten/src/ATen/native/Itertools.cpp b/aten/src/ATen/native/Itertools.cpp
new file mode 100644 (file)
index 0000000..839d69b
--- /dev/null
@@ -0,0 +1,60 @@
+#include "ATen/ATen.h"
+#include "ATen/Dispatch.h"
+
+#include <vector>
+
+namespace {
+
+using namespace at;
+
+Tensor _triu_mask(int64_t n, int64_t dims, bool diagonal, TensorOptions opt) {
+  // get a mask that has value 1 whose indices satisfies i < j < k < ...
+  // or i <= j <= k <= ... (depending on diagonal)
+  Tensor range = at::arange(n, opt.dtype(kLong));
+  std::vector<Tensor> index_grids = at::meshgrid(std::vector<Tensor>(dims, range));
+  Tensor mask = at::ones(index_grids[0].sizes(), opt.dtype(kByte));
+  if(diagonal) {
+    for(int64_t i = 0; i < dims - 1; i++) {
+      mask *= index_grids[i] <= index_grids[i+1];
+    }
+  } else {
+    for(int64_t i = 0; i < dims - 1; i++) {
+      mask *= index_grids[i] < index_grids[i+1];
+    }
+  }
+  return mask;
+}
+
+}  // namespace
+
+namespace at {
+namespace native{
+
+Tensor cartesian_prod(TensorList tensors) {
+  for(const Tensor &t : tensors) {
+    AT_CHECK(t.dim() == 1, "Expect a 1D vector, but got shape ", t.sizes());
+  }
+  if (tensors.size() == 1) {
+    return tensors[0];
+  }
+  std::vector<Tensor> grids = at::meshgrid(tensors);
+  for(Tensor &t : grids) {
+    t = t.flatten();
+  }
+  return at::stack(grids, 1);
+}
+
+Tensor combinations(const Tensor& self, int64_t r, bool with_replacement) {
+  AT_CHECK(self.dim() == 1, "Expect a 1D vector, but got shape ", self.sizes());
+  AT_CHECK(r > 0, "Expect a positive number, but got ", r);
+  int64_t num_elements = self.numel();
+  std::vector<Tensor> grids = at::meshgrid(std::vector<Tensor>(r, self));
+  Tensor mask = _triu_mask(num_elements, r, with_replacement, self.options());
+  for(Tensor &t : grids) {
+    t = t.masked_select(mask);
+  }
+  return at::stack(grids, 1);
+}
+
+}  // namespace native
+}  // namespace at
index 2144a88..7a87353 100644 (file)
 
 - func: meshgrid(TensorList tensors) -> TensorList
 
+- func: cartesian_prod(TensorList tensors) -> Tensor
+  variants: function
+
+- func: combinations(Tensor self, int64_t r=2, bool with_replacement=false) -> Tensor
+  variants: function
+
 - func: item(Tensor self) -> Scalar
   variants: method
 
index c1e388d..fbc2fef 100644 (file)
@@ -19,7 +19,7 @@ from torch._utils_internal import get_file_path, get_file_path_2
 from torch.utils.dlpack import from_dlpack, to_dlpack
 from torch._utils import _rebuild_tensor
 from torch._six import inf, nan, string_classes
-from itertools import product, combinations
+from itertools import product, combinations, combinations_with_replacement
 from functools import reduce
 from torch import multiprocessing as mp
 from common_methods_invocations import tri_tests_args, run_additional_tri_tests, \
@@ -9688,6 +9688,63 @@ tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         self.assertEqual(b.type(), b_copy.type())
         self.assertEqual(b.data.type(), b_copy.type())
 
+    def test_cartesian_prod(self):
+        a = torch.tensor([1])
+        b = torch.tensor([1, 2, 3])
+        c = torch.tensor([1, 2])
+        prod = torch.cartesian_prod(a, b, c)
+        expected = torch.tensor(list(product([a], b, c)))
+        self.assertEqual(expected, prod)
+
+        # test 0 size input
+        d = torch.empty(0, dtype=b.dtype)
+        prod = torch.cartesian_prod(a, b, c, d)
+        expected = torch.empty(0, 4, dtype=b.dtype)
+        self.assertEqual(expected, prod)
+
+        # test single input
+        prod = torch.cartesian_prod(b)
+        self.assertEqual(b, prod)
+
+    def test_combinations(self):
+        a = torch.tensor([1, 2, 3])
+
+        c = torch.combinations(a, r=1)
+        expected = torch.tensor(list(combinations(a, r=1)))
+        self.assertEqual(c, expected)
+
+        c = torch.combinations(a, r=1, with_replacement=True)
+        expected = torch.tensor(list(combinations_with_replacement(a, r=1)))
+        self.assertEqual(c, expected)
+
+        c = torch.combinations(a)
+        expected = torch.tensor(list(combinations(a, r=2)))
+        self.assertEqual(c, expected)
+
+        c = torch.combinations(a, with_replacement=True)
+        expected = torch.tensor(list(combinations_with_replacement(a, r=2)))
+        self.assertEqual(c, expected)
+
+        c = torch.combinations(a, r=3)
+        expected = torch.tensor(list(combinations(a, r=3)))
+        self.assertEqual(c, expected)
+
+        c = torch.combinations(a, r=4)
+        expected = torch.empty(0, 4, dtype=a.dtype)
+        self.assertEqual(c, expected)
+
+        c = torch.combinations(a, r=5)
+        expected = torch.empty(0, 5, dtype=a.dtype)
+        self.assertEqual(c, expected)
+
+        # test empty imput
+        a = torch.empty(0)
+        c1 = torch.combinations(a)
+        c2 = torch.combinations(a, with_replacement=True)
+        expected = torch.empty(0, 2, dtype=a.dtype)
+        self.assertEqual(c1, expected)
+        self.assertEqual(c2, expected)
+
     @unittest.skipIf(torch.cuda.device_count() < 2, 'only one GPU detected')
     def test_reverse_binary_ops_multiple_device(self):
         self.assertEqual(2 + torch.tensor(3), 2 + torch.tensor(3).to("cuda:1"))    # __radd__
index f1a475e..835958c 100644 (file)
@@ -6267,3 +6267,47 @@ Example::
     >>>                            [7, 8, 9]]))
     (tensor([1, 2, 3]), tensor([4, 5, 6]), tensor([7, 8, 9]))
 """)
+
+
+add_docstr(torch.combinations,
+           r"""
+combinations(tensor, r=2, with_replacement=False) -> seq
+
+Compute combinations of length :math:`r` of the given tensor. The behavior is similar to
+python's `itertools.combinations` when `with_replacement` is set to `False`, and
+`itertools.combinations_with_replacement` when `with_replacement` is set to `True`.
+
+Arguments:
+    tensor (Tensor): 1D vector.
+    r (int, optional): number of elements to combine
+    with_replacement (boolean, optional): whether to allow duplication in combination
+
+Returns:
+    Tensor: A tensor equivalent to converting all the input tensors into lists, do
+    `itertools.combinations` or `itertools.combinations_with_replacement` on these
+    lists, and finally convert the resulting list into tensor.
+
+Example::
+
+    >>> a = [1, 2, 3]
+    >>> list(itertools.combinations(a, r=2))
+    [(1, 2), (1, 3), (2, 3)]
+    >>> list(itertools.combinations(a, r=3))
+    [(1, 2, 3)]
+    >>> list(itertools.combinations_with_replacement(a, r=2))
+    [(1, 1), (1, 2), (1, 3), (2, 2), (2, 3), (3, 3)]
+    >>> tensor_a = torch.tensor(a)
+    >>> torch.combinations(tensor_a)
+    tensor([[1, 2],
+            [1, 3],
+            [2, 3]])
+    >>> torch.combinations(tensor_a, r=3)
+    tensor([[1, 2, 3]])
+    >>> torch.combinations(tensor_a, with_replacement=True)
+    tensor([[1, 1],
+            [1, 2],
+            [1, 3],
+            [2, 2],
+            [2, 3],
+            [3, 3]])
+""")
index 197509e..1ff28b1 100644 (file)
@@ -27,6 +27,7 @@ __all__ = [
     'stft',
     'tensordot',
     'unique',
+    'cartesian_prod',
 ]
 
 
@@ -601,6 +602,37 @@ def argsort(input, dim=None, descending=False):
     return torch.sort(input, dim, descending)[1]
 
 
+def cartesian_prod(*tensors):
+    """Do cartesian product of the given sequence of tensors. The behavior is similar to
+    python's `itertools.product`.
+
+    Arguments:
+        *tensors: any number of 1 dimensional tensors.
+
+    Returns:
+        Tensor: A tensor equivalent to converting all the input tensors into lists,
+            do `itertools.product` on these lists, and finally convert the resulting list
+            into tensor.
+
+    Example::
+
+        >>> a = [1, 2, 3]
+        >>> b = [4, 5]
+        >>> list(itertools.product(a, b))
+        [(1, 4), (1, 5), (2, 4), (2, 5), (3, 4), (3, 5)]
+        >>> tensor_a = torch.tensor(a)
+        >>> tensor_b = torch.tensor(b)
+        >>> torch.cartesian_prod(tensor_a, tensor_b)
+        tensor([[1, 4],
+                [1, 5],
+                [2, 4],
+                [2, 5],
+                [3, 4],
+                [3, 5]])
+    """
+    return torch._C._VariableFunctions.cartesian_prod(tensors)
+
+
 def norm(input, p="fro", dim=None, keepdim=False, out=None):
     r"""Returns the matrix norm or vector norm of a given tensor.