From: Xiang Gao Date: Tue, 26 Mar 2019 19:33:09 +0000 (-0700) Subject: Namedtuple return for solve, slogdet, sort, topk (#17093) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~618 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=5bff395a821851a0b720f178165488d3a815311d;p=platform%2Fupstream%2Fpytorch.git Namedtuple return for solve, slogdet, sort, topk (#17093) Summary: More ops for https://github.com/pytorch/pytorch/issues/394. ~~Also need to rebase after landing #16186, because we need to update the whitelist of the new unit test added in #16186.~~ cc: ezyang Pull Request resolved: https://github.com/pytorch/pytorch/pull/17093 Differential Revision: D14620068 Pulled By: ezyang fbshipit-source-id: deec5ffc9bf7624e0350c85392ee59789bad4237 --- diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 0793434..0d4da52 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -1990,7 +1990,7 @@ variants: function, method device_guard: False -- func: slogdet(Tensor self) -> (Tensor, Tensor) +- func: slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet) matches_jit_signature: True variants: function, method @@ -3762,11 +3762,11 @@ CPU: _cholesky_solve_helper_cpu CUDA: _cholesky_solve_helper_cuda -- func: solve(Tensor self, Tensor A) -> (Tensor, Tensor) +- func: solve(Tensor self, Tensor A) -> (Tensor solution, Tensor LU) matches_jit_signature: True variants: function, method -- func: solve(Tensor self, Tensor A, *, Tensor(a!) solution, Tensor(b!) lu) -> (Tensor(a!), Tensor(b!)) +- func: solve(Tensor self, Tensor A, *, Tensor(a!) solution, Tensor(b!) lu) -> (Tensor(a!) solution, Tensor(b!) LU) matches_jit_signature: True - func: _solve_helper(Tensor self, Tensor A) -> (Tensor, Tensor) @@ -4016,10 +4016,10 @@ CPU: median_cpu CUDA: median_cuda -- func: sort(Tensor self, int dim=-1, bool descending=False, *, Tensor(a!) values, Tensor(b!) 
indices) -> (Tensor(a!), Tensor(b!))
+- func: sort(Tensor self, int dim=-1, bool descending=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
  matches_jit_signature: True

-- func: sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor, Tensor)
+- func: sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)
  matches_jit_signature: True
  variants: method, function

@@ -4027,10 +4027,10 @@
  matches_jit_signature: True
  variants: method, function

-- func: topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!))
+- func: topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
  matches_jit_signature: True

-- func: topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor, Tensor)
+- func: topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)
  matches_jit_signature: True
  variants: method, function

diff --git a/test/test_namedtuple_return_api.py b/test/test_namedtuple_return_api.py
index d774b8b..c6a0c49 100644
--- a/test/test_namedtuple_return_api.py
+++ b/test/test_namedtuple_return_api.py
@@ -8,7 +8,7 @@ path = os.path.dirname(os.path.realpath(__file__))
 aten_native_yaml = os.path.join(path, '../aten/src/ATen/native/native_functions.yaml')
 whitelist = [
     'max', 'min', 'median', 'mode', 'kthvalue', 'svd', 'symeig', 'eig',
-    'pstrf', 'qr', 'geqrf',
+    'pstrf', 'qr', 'geqrf', 'solve', 'slogdet', 'sort', 'topk',
 ]

diff --git a/test/test_torch.py b/test/test_torch.py
index 5728dd8..4d9d619 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -7411,6 +7411,41 @@ class _TestTorchMixin(object):
         self.assertEqual(ret1.S, ret[1])
         self.assertEqual(ret1.V, ret[2])

+        # test gesv
+        ret = a.gesv(a)
+        self.assertEqual(ret.solution, ret[0])
+        
self.assertEqual(ret.LU, ret[1]) + ret1 = torch.gesv(a, a, out=tuple(ret)) + self.assertEqual(ret1.solution, ret1[0]) + self.assertEqual(ret1.LU, ret1[1]) + self.assertEqual(ret1.solution, ret[0]) + self.assertEqual(ret1.LU, ret[1]) + + # test slogdet + ret = a.slogdet() + self.assertEqual(ret.sign, ret[0]) + self.assertEqual(ret.logabsdet, ret[1]) + + # test sort + ret = a.sort(dim=0) + self.assertEqual(ret.values, ret[0]) + self.assertEqual(ret.indices, ret[1]) + ret1 = torch.sort(a, dim=0, out=tuple(ret)) + self.assertEqual(ret1.values, ret1[0]) + self.assertEqual(ret1.indices, ret1[1]) + self.assertEqual(ret1.values, ret[0]) + self.assertEqual(ret1.indices, ret[1]) + + # test topk + ret = a.topk(2) + self.assertEqual(ret.values, ret[0]) + self.assertEqual(ret.indices, ret[1]) + ret1 = torch.topk(a, 2, out=tuple(ret)) + self.assertEqual(ret1.values, ret1[0]) + self.assertEqual(ret1.indices, ret1[1]) + self.assertEqual(ret1.values, ret[0]) + self.assertEqual(ret1.indices, ret[1]) + # test symeig, eig fn = ['symeig', 'eig'] for f in fn: diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index b9a10d6..0a460a2 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -745,15 +745,15 @@ self: slice_backward(grad, self.sizes(), dim, start, end, step) - name: slogdet(Tensor self) - self: slogdet_backward(grad, self, result0, result1) + self: slogdet_backward(grad, self, sign, logabsdet) output_differentiability: [false, true] - name: solve(Tensor self, Tensor A) self: solve_backward_self(grad, self, A) - A: solve_backward_A(grad, self, A, result0) + A: solve_backward_A(grad, self, A, solution) - name: sort(Tensor self, int64_t dim, bool descending) - self: index_select_backward(grad, dim, result1, self.sizes(), true) + self: index_select_backward(grad, dim, indices, self.sizes(), true) - name: split(Tensor self, int64_t split_size, int64_t dim) self: split_backward(grads, split_size, dim, self.sizes(), 
self.type()) @@ -839,7 +839,7 @@ self: tanh_backward(grad, result) - name: topk(Tensor self, int64_t k, int64_t dim, bool largest, bool sorted) - self: index_select_backward(grad, dim, result1, self.sizes(), true) + self: index_select_backward(grad, dim, indices, self.sizes(), true) - name: trace(Tensor self) self: trace_backward(grad, self.sizes()) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 0fd3d5e..3249126 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -2024,18 +2024,18 @@ torch.solve(B, A, out=None) -> (Tensor, Tensor) This function returns the solution to the system of linear equations represented by :math:`AX = B` and the LU factorization of -A, in order as a tuple `X, LU`. +A, in order as a namedtuple `solution, LU`. `LU` contains `L` and `U` factors for LU factorization of `A`. `torch.solve(B, A)` can take in 2D inputs `B, A` or inputs that are batches of 2D matrices. If the inputs are batches, then returns -batched outputs `X, LU`. +batched outputs `solution, LU`. .. note:: Irrespective of the original strides, the returned matrices - `X` and `LU` will be transposed, i.e. with strides like + `solution` and `LU` will be transposed, i.e. with strides like `B.contiguous().transpose(-1, -2).strides()` and `A.contiguous().transpose(-1, -2).strides()` respectively. @@ -4353,8 +4353,9 @@ If :attr:`dim` is not given, the last dimension of the `input` is chosen. If :attr:`descending` is ``True`` then the elements are sorted in descending order by value. -A tuple of (sorted_tensor, sorted_indices) is returned, where the -sorted_indices are the indices of the elements in the original `input` tensor. +A namedtuple of (values, indices) is returned, where the `values` are the +sorted values and `indices` are the indices of the elements in the original +`input` tensor. Args: input (Tensor): the input tensor @@ -5018,7 +5019,7 @@ If :attr:`dim` is not given, the last dimension of the `input` is chosen. 
If :attr:`largest` is ``False`` then the `k` smallest elements are returned. -A tuple of `(values, indices)` is returned, where the `indices` are the indices +A namedtuple of `(values, indices)` is returned, where the `indices` are the indices of the elements in the original `input` tensor. The boolean option :attr:`sorted` if ``True``, will make sure that the returned @@ -5041,7 +5042,7 @@ Example:: >>> x tensor([ 1., 2., 3., 4., 5.]) >>> torch.topk(x, 3) - (tensor([ 5., 4., 3.]), tensor([ 4, 3, 2])) + torch.return_types.topk(values=tensor([5., 4., 3.]), indices=tensor([4, 3, 2])) """) add_docstr(torch.trace, @@ -5800,18 +5801,22 @@ Arguments: A (Tensor): The input 2D square tensor Returns: - A tuple containing the sign of the determinant, and the log value of the - absolute determinant. + A namedtuple (sign, logabsdet) containing the sign of the determinant, and the log + value of the absolute determinant. Example:: >>> A = torch.randn(3, 3) + >>> A + tensor([[ 0.0032, -0.2239, -1.1219], + [-0.6690, 0.1161, 0.4053], + [-1.6218, -0.9273, -0.0082]]) >>> torch.det(A) - tensor(-4.8215) + tensor(-0.7576) >>> torch.logdet(A) tensor(nan) >>> torch.slogdet(A) - (tensor(-1.), tensor(1.5731)) + torch.return_types.slogdet(sign=tensor(-1.), logabsdet=tensor(-0.2776)) """) add_docstr(torch.pinverse,