Move argsort to C++
authorGao, Xiang <qasdfgtyuiop@gmail.com>
Thu, 21 Feb 2019 15:50:27 +0000 (07:50 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 21 Feb 2019 15:59:27 +0000 (07:59 -0800)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17099

Differential Revision: D14165671

Pulled By: ezyang

fbshipit-source-id: 3871de6874fe09871ebd9b8943c13c9af325bf33

aten/src/ATen/core/Tensor.h
aten/src/ATen/core/TensorMethods.h
aten/src/ATen/core/Type.h
aten/src/ATen/native/LegacyDefinitions.cpp
aten/src/ATen/native/native_functions.yaml
torch/__init__.pyi.in
torch/_tensor_docs.py
torch/_torch_docs.py
torch/functional.py
torch/tensor.py

index 6c337fd..0682f69 100644 (file)
@@ -706,6 +706,7 @@ class CAFFE2_API Tensor {
   Tensor max() const;
   Tensor median() const;
   std::tuple<Tensor,Tensor> sort(int64_t dim=-1, bool descending=false) const;
+  Tensor argsort(int64_t dim=-1, bool descending=false) const;
   std::tuple<Tensor,Tensor> topk(int64_t k, int64_t dim=-1, bool largest=true, bool sorted=true) const;
   Tensor all() const;
   Tensor any() const;
index 973d905..026b2c3 100644 (file)
@@ -1252,6 +1252,9 @@ inline Tensor Tensor::median() const {
 inline std::tuple<Tensor,Tensor> Tensor::sort(int64_t dim, bool descending) const {
     return type().sort(*this, dim, descending);
 }
+inline Tensor Tensor::argsort(int64_t dim, bool descending) const {
+    return type().argsort(*this, dim, descending);
+}
 inline std::tuple<Tensor,Tensor> Tensor::topk(int64_t k, int64_t dim, bool largest, bool sorted) const {
     return type().topk(*this, k, dim, largest, sorted);
 }
index 1b300f1..825081f 100644 (file)
@@ -590,6 +590,7 @@ struct CAFFE2_API Type {
   virtual Tensor max(const Tensor & self) const = 0;
   virtual Tensor median(const Tensor & self) const = 0;
   virtual std::tuple<Tensor,Tensor> sort(const Tensor & self, int64_t dim, bool descending) const = 0;
+  virtual Tensor argsort(const Tensor & self, int64_t dim, bool descending) const = 0;
   virtual std::tuple<Tensor,Tensor> topk(const Tensor & self, int64_t k, int64_t dim, bool largest, bool sorted) const = 0;
   virtual Tensor all(const Tensor & self) const = 0;
   virtual Tensor any(const Tensor & self) const = 0;
index b2cd6db..329680d 100644 (file)
@@ -697,6 +697,11 @@ std::tuple<Tensor &,Tensor &> sort_out(Tensor & values, Tensor & indices, const
 std::tuple<Tensor,Tensor> sort(const Tensor & self, int64_t dim, bool descending) {
   return at::legacy::th::_th_sort(self, dim, descending);
 }
+
+Tensor argsort(const Tensor & self, int64_t dim, bool descending) {
+  return std::get<1>(at::legacy::th::_th_sort(self, dim, descending));
+}
+
 std::tuple<Tensor &,Tensor &> topk_out(Tensor & values, Tensor & indices, const Tensor & self, int64_t k, int64_t dim, bool largest, bool sorted) {
   return at::legacy::th::_th_topk_out(values, indices, self, k, dim, largest, sorted);
 }
index 1780ea4..b79ff46 100644 (file)
   matches_jit_signature: True
   variants: method, function
 
+- func: argsort(Tensor self, int dim=-1, bool descending=False) -> Tensor
+  matches_jit_signature: True
+  variants: method, function
+
 - func: topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True, *, Tensor(a!) values, Tensor(b!) indices) ->(Tensor(a!), Tensor(b!))
 
 - func: topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor, Tensor)
index aa3c1f0..d38fd74 100644 (file)
@@ -75,7 +75,6 @@ class Tensor:
     # way to not have to write these out again...
     def argmax(self, dim=None, keepdim=False): ...
     def argmin(self, dim=None, keepdim=False): ...
-    def argsort(self, dim=None, descending=False): ...
     def norm(self, p="fro", dim=None, keepdim=False): ...
     def stft(self, n_fft, hop_length=None, win_length=None, window=None,
              center=True, pad_mode='reflect', normalized=False, onesided=True): ...
index bd6a009..4870343 100644 (file)
@@ -2206,11 +2206,18 @@ Example::
 
 add_docstr_all('sort',
                r"""
-sort(dim=None, descending=False) -> (Tensor, LongTensor)
+sort(dim=-1, descending=False) -> (Tensor, LongTensor)
 
 See :func:`torch.sort`
 """)
 
+add_docstr_all('argsort',
+               r"""
+argsort(dim=-1, descending=False) -> LongTensor
+
+See :func: `torch.argsort`
+""")
+
 add_docstr_all('sparse_dim',
                r"""
 sparse_dim() -> int
index 188f842..f5906c1 100644 (file)
@@ -4269,7 +4269,7 @@ Example::
 
 add_docstr(torch.sort,
            r"""
-sort(input, dim=None, descending=False, out=None) -> (Tensor, LongTensor)
+sort(input, dim=-1, descending=False, out=None) -> (Tensor, LongTensor)
 
 Sorts the elements of the :attr:`input` tensor along a given dimension
 in ascending order by value.
@@ -4313,6 +4313,38 @@ Example::
             [ 1,  2,  2,  0]])
 """)
 
+add_docstr(torch.argsort,
+           r"""
+argsort(input, dim=-1, descending=False, out=None) -> LongTensor
+
+Returns the indices that sort a tensor along a given dimension in ascending
+order by value.
+
+This is the second value returned by :meth:`torch.sort`.  See its documentation
+for the exact semantics of this method.
+
+Args:
+    input (Tensor): the input tensor
+    dim (int, optional): the dimension to sort along
+    descending (bool, optional): controls the sorting order (ascending or descending)
+
+Example::
+
+    >>> a = torch.randn(4, 4)
+    >>> a
+    tensor([[ 0.0785,  1.5267, -0.8521,  0.4065],
+            [ 0.1598,  0.0788, -0.0745, -1.2700],
+            [ 1.2208,  1.0722, -0.7064,  1.2564],
+            [ 0.0669, -0.2318, -0.8229, -0.9280]])
+
+
+    >>> torch.argsort(a, dim=1)
+    tensor([[2, 0, 3, 1],
+            [3, 2, 1, 0],
+            [2, 1, 0, 3],
+            [3, 2, 1, 0]])
+""")
+
 add_docstr(torch.sparse_coo_tensor,
            r"""
 sparse_coo_tensor(indices, values, size=None, dtype=None, device=None, requires_grad=False) -> Tensor
index 8942cee..f66c389 100644 (file)
@@ -13,7 +13,6 @@ import warnings
 __all__ = [
     'argmax',
     'argmin',
-    'argsort',
     'btriunpack',
     'chain_matmul',
     'einsum',
@@ -570,39 +569,6 @@ def tensordot(a, b, dims=2):
     return torch._C._VariableFunctions.tensordot(a, b, dims_a, dims_b)
 
 
-def argsort(input, dim=None, descending=False):
-    r"""Returns the indices that sort a tensor along a given dimension in ascending
-    order by value.
-
-    This is the second value returned by :meth:`torch.sort`.  See its documentation
-    for the exact semantics of this method.
-
-    Args:
-        input (Tensor): the input tensor
-        dim (int, optional): the dimension to sort along
-        descending (bool, optional): controls the sorting order (ascending or descending)
-
-    Example::
-
-        >>> a = torch.randn(4, 4)
-        >>> a
-        tensor([[ 0.0785,  1.5267, -0.8521,  0.4065],
-                [ 0.1598,  0.0788, -0.0745, -1.2700],
-                [ 1.2208,  1.0722, -0.7064,  1.2564],
-                [ 0.0669, -0.2318, -0.8229, -0.9280]])
-
-
-        >>> torch.argsort(a, dim=1)
-        tensor([[2, 0, 3, 1],
-                [3, 2, 1, 0],
-                [2, 1, 0, 3],
-                [3, 2, 1, 0]])
-    """
-    if dim is None:
-        return torch.sort(input, -1, descending)[1]
-    return torch.sort(input, dim, descending)[1]
-
-
 def cartesian_prod(*tensors):
     """Do cartesian product of the given sequence of tensors. The behavior is similar to
     python's `itertools.product`.
index 2e1322c..bc4be2b 100644 (file)
@@ -255,10 +255,6 @@ class Tensor(torch._C._TensorBase):
         r"""See :func:`torch.argmin`"""
         return torch.argmin(self, dim, keepdim)
 
-    def argsort(self, dim=None, descending=False):
-        r"""See :func:`torch.argsort`"""
-        return torch.argsort(self, dim, descending)
-
     def norm(self, p="fro", dim=None, keepdim=False, dtype=None):
         r"""See :func:`torch.norm`"""
         return torch.norm(self, p, dim, keepdim, dtype=dtype)