From: Gao, Xiang
Date: Thu, 21 Feb 2019 15:50:27 +0000 (-0800)
Subject: Move argsort to C++
X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~1166
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=722cbe3064d4a4d13f104c4fee9ca11e18f6f21c;p=platform%2Fupstream%2Fpytorch.git

Move argsort to C++

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17099

Differential Revision: D14165671

Pulled By: ezyang

fbshipit-source-id: 3871de6874fe09871ebd9b8943c13c9af325bf33
---

diff --git a/aten/src/ATen/core/Tensor.h b/aten/src/ATen/core/Tensor.h
index 6c337fd..0682f69 100644
--- a/aten/src/ATen/core/Tensor.h
+++ b/aten/src/ATen/core/Tensor.h
@@ -706,6 +706,7 @@ class CAFFE2_API Tensor {
   Tensor max() const;
   Tensor median() const;
   std::tuple<Tensor,Tensor> sort(int64_t dim=-1, bool descending=false) const;
+  Tensor argsort(int64_t dim=-1, bool descending=false) const;
   std::tuple<Tensor,Tensor> topk(int64_t k, int64_t dim=-1, bool largest=true, bool sorted=true) const;
   Tensor all() const;
   Tensor any() const;
diff --git a/aten/src/ATen/core/TensorMethods.h b/aten/src/ATen/core/TensorMethods.h
index 973d905..026b2c3 100644
--- a/aten/src/ATen/core/TensorMethods.h
+++ b/aten/src/ATen/core/TensorMethods.h
@@ -1252,6 +1252,9 @@ inline Tensor Tensor::median() const {
 inline std::tuple<Tensor,Tensor> Tensor::sort(int64_t dim, bool descending) const {
     return type().sort(*this, dim, descending);
 }
+inline Tensor Tensor::argsort(int64_t dim, bool descending) const {
+    return type().argsort(*this, dim, descending);
+}
 inline std::tuple<Tensor,Tensor> Tensor::topk(int64_t k, int64_t dim, bool largest, bool sorted) const {
     return type().topk(*this, k, dim, largest, sorted);
 }
diff --git a/aten/src/ATen/core/Type.h b/aten/src/ATen/core/Type.h
index 1b300f1..825081f 100644
--- a/aten/src/ATen/core/Type.h
+++ b/aten/src/ATen/core/Type.h
@@ -590,6 +590,7 @@ struct CAFFE2_API Type {
   virtual Tensor max(const Tensor & self) const = 0;
   virtual Tensor median(const Tensor & self) const = 0;
   virtual std::tuple<Tensor,Tensor> sort(const Tensor & self, int64_t dim, bool descending) const = 0;
+  virtual Tensor argsort(const Tensor & self, int64_t dim, bool descending) const = 0;
   virtual std::tuple<Tensor,Tensor> topk(const Tensor & self, int64_t k, int64_t dim, bool largest, bool sorted) const = 0;
   virtual Tensor all(const Tensor & self) const = 0;
   virtual Tensor any(const Tensor & self) const = 0;
diff --git a/aten/src/ATen/native/LegacyDefinitions.cpp b/aten/src/ATen/native/LegacyDefinitions.cpp
index b2cd6db..329680d 100644
--- a/aten/src/ATen/native/LegacyDefinitions.cpp
+++ b/aten/src/ATen/native/LegacyDefinitions.cpp
@@ -697,6 +697,11 @@ std::tuple<Tensor &,Tensor &> sort_out(Tensor & values, Tensor & indices, const
 std::tuple<Tensor,Tensor> sort(const Tensor & self, int64_t dim, bool descending) {
   return at::legacy::th::_th_sort(self, dim, descending);
 }
+
+Tensor argsort(const Tensor & self, int64_t dim, bool descending) {
+  return std::get<1>(at::legacy::th::_th_sort(self, dim, descending));
+}
+
 std::tuple<Tensor &,Tensor &> topk_out(Tensor & values, Tensor & indices, const Tensor & self, int64_t k, int64_t dim, bool largest, bool sorted) {
   return at::legacy::th::_th_topk_out(values, indices, self, k, dim, largest, sorted);
 }
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 1780ea4..b79ff46 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3856,6 +3856,10 @@
   matches_jit_signature: True
   variants: method, function
 
+- func: argsort(Tensor self, int dim=-1, bool descending=False) -> Tensor
+  matches_jit_signature: True
+  variants: method, function
+
 - func: topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!))
 
 - func: topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor, Tensor)
diff --git a/torch/__init__.pyi.in b/torch/__init__.pyi.in
index aa3c1f0..d38fd74 100644
--- a/torch/__init__.pyi.in
+++ b/torch/__init__.pyi.in
@@ -75,7 +75,6 @@ class Tensor:
     # way to not have to write these out again...
     def argmax(self, dim=None, keepdim=False): ...
     def argmin(self, dim=None, keepdim=False): ...
-    def argsort(self, dim=None, descending=False): ...
     def norm(self, p="fro", dim=None, keepdim=False): ...
     def stft(self, n_fft, hop_length=None, win_length=None, window=None, center=True,
              pad_mode='reflect', normalized=False, onesided=True): ...
diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py
index bd6a009..4870343 100644
--- a/torch/_tensor_docs.py
+++ b/torch/_tensor_docs.py
@@ -2206,11 +2206,18 @@ Example::
 
 add_docstr_all('sort',
                r"""
-sort(dim=None, descending=False) -> (Tensor, LongTensor)
+sort(dim=-1, descending=False) -> (Tensor, LongTensor)
 
 See :func:`torch.sort`
 """)
 
+add_docstr_all('argsort',
+               r"""
+argsort(dim=-1, descending=False) -> LongTensor
+
+See :func:`torch.argsort`
+""")
+
 add_docstr_all('sparse_dim',
                r"""
 sparse_dim() -> int
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 188f842..f5906c1 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -4269,7 +4269,7 @@ Example::
 
 add_docstr(torch.sort,
            r"""
-sort(input, dim=None, descending=False, out=None) -> (Tensor, LongTensor)
+sort(input, dim=-1, descending=False, out=None) -> (Tensor, LongTensor)
 
 Sorts the elements of the :attr:`input` tensor along a given dimension in ascending
 order by value.
@@ -4313,6 +4313,38 @@ Example::
             [ 1,  2,  2,  0]])
 """)
 
+add_docstr(torch.argsort,
+           r"""
+argsort(input, dim=-1, descending=False, out=None) -> LongTensor
+
+Returns the indices that sort a tensor along a given dimension in ascending
+order by value.
+
+This is the second value returned by :meth:`torch.sort`. See its documentation
+for the exact semantics of this method.
+
+Args:
+    input (Tensor): the input tensor
+    dim (int, optional): the dimension to sort along
+    descending (bool, optional): controls the sorting order (ascending or descending)
+
+Example::
+
+    >>> a = torch.randn(4, 4)
+    >>> a
+    tensor([[ 0.0785,  1.5267, -0.8521,  0.4065],
+            [ 0.1598,  0.0788, -0.0745, -1.2700],
+            [ 1.2208,  1.0722, -0.7064,  1.2564],
+            [ 0.0669, -0.2318, -0.8229, -0.9280]])
+
+
+    >>> torch.argsort(a, dim=1)
+    tensor([[2, 0, 3, 1],
+            [3, 2, 1, 0],
+            [2, 1, 0, 3],
+            [3, 2, 1, 0]])
+""")
+
 add_docstr(torch.sparse_coo_tensor,
            r"""
 sparse_coo_tensor(indices, values, size=None, dtype=None, device=None, requires_grad=False) -> Tensor
diff --git a/torch/functional.py b/torch/functional.py
index 8942cee..f66c389 100644
--- a/torch/functional.py
+++ b/torch/functional.py
@@ -13,7 +13,6 @@ import warnings
 __all__ = [
     'argmax',
     'argmin',
-    'argsort',
     'btriunpack',
     'chain_matmul',
     'einsum',
@@ -570,39 +569,6 @@ def tensordot(a, b, dims=2):
     return torch._C._VariableFunctions.tensordot(a, b, dims_a, dims_b)
 
 
-def argsort(input, dim=None, descending=False):
-    r"""Returns the indices that sort a tensor along a given dimension in ascending
-    order by value.
-
-    This is the second value returned by :meth:`torch.sort`. See its documentation
-    for the exact semantics of this method.
-
-    Args:
-        input (Tensor): the input tensor
-        dim (int, optional): the dimension to sort along
-        descending (bool, optional): controls the sorting order (ascending or descending)
-
-    Example::
-
-        >>> a = torch.randn(4, 4)
-        >>> a
-        tensor([[ 0.0785,  1.5267, -0.8521,  0.4065],
-                [ 0.1598,  0.0788, -0.0745, -1.2700],
-                [ 1.2208,  1.0722, -0.7064,  1.2564],
-                [ 0.0669, -0.2318, -0.8229, -0.9280]])
-
-
-        >>> torch.argsort(a, dim=1)
-        tensor([[2, 0, 3, 1],
-                [3, 2, 1, 0],
-                [2, 1, 0, 3],
-                [3, 2, 1, 0]])
-    """
-    if dim is None:
-        return torch.sort(input, -1, descending)[1]
-    return torch.sort(input, dim, descending)[1]
-
-
 def cartesian_prod(*tensors):
     """Do cartesian product of the given sequence of tensors. The behavior is similar
     to python's `itertools.product`.
diff --git a/torch/tensor.py b/torch/tensor.py
index 2e1322c..bc4be2b 100644
--- a/torch/tensor.py
+++ b/torch/tensor.py
@@ -255,10 +255,6 @@ class Tensor(torch._C._TensorBase):
         r"""See :func:`torch.argmin`"""
         return torch.argmin(self, dim, keepdim)
 
-    def argsort(self, dim=None, descending=False):
-        r"""See :func:`torch.argsort`"""
-        return torch.argsort(self, dim, descending)
-
     def norm(self, p="fro", dim=None, keepdim=False, dtype=None):
         r"""See :func:`torch.norm`"""
         return torch.norm(self, p, dim, keepdim, dtype=dtype)
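For reference, a minimal usage sketch of the relocated argsort (illustrative only, not part of the patch): the call now flows through Tensor::argsort and Type::argsort to the native kernel in LegacyDefinitions.cpp, which returns std::get<1>(_th_sort(...)), so the result should match the second value returned by torch.sort, which is exactly what the removed Python wrapper in torch/functional.py computed. The only user-visible signature difference is that the default dim is -1 in the native schema instead of None.

    import torch

    x = torch.randn(4, 4)

    # Function form: indices should match the second value returned by torch.sort.
    assert torch.equal(torch.argsort(x, dim=1), torch.sort(x, dim=1)[1])

    # Method form with descending order, available because the yaml entry
    # declares "variants: method, function".
    assert torch.equal(x.argsort(dim=0, descending=True),
                       torch.sort(x, dim=0, descending=True)[1])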