Namedtuple return for solve, slogdet, sort, topk (#17093)
authorXiang Gao <qasdfgtyuiop@gmail.com>
Tue, 26 Mar 2019 19:33:09 +0000 (12:33 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 26 Mar 2019 19:39:08 +0000 (12:39 -0700)
Summary:
More ops for https://github.com/pytorch/pytorch/issues/394. ~~Also need to rebase after landing #16186, because we need to update the whitelist of the new unit test added in #16186.~~

cc: ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17093

Differential Revision: D14620068

Pulled By: ezyang

fbshipit-source-id: deec5ffc9bf7624e0350c85392ee59789bad4237

aten/src/ATen/native/native_functions.yaml
test/test_namedtuple_return_api.py
test/test_torch.py
tools/autograd/derivatives.yaml
torch/_torch_docs.py

index 0793434..0d4da52 100644 (file)
   variants: function, method
   device_guard: False
 
-- func: slogdet(Tensor self) -> (Tensor, Tensor)
+- func: slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet)
   matches_jit_signature: True
   variants: function, method
 
     CPU: _cholesky_solve_helper_cpu
     CUDA: _cholesky_solve_helper_cuda
 
-- func: solve(Tensor self, Tensor A) -> (Tensor, Tensor)
+- func: solve(Tensor self, Tensor A) -> (Tensor solution, Tensor LU)
   matches_jit_signature: True
   variants: function, method
 
-- func: solve(Tensor self, Tensor A, *, Tensor(a!) solution, Tensor(b!) lu) -> (Tensor(a!), Tensor(b!))
+- func: solve(Tensor self, Tensor A, *, Tensor(a!) solution, Tensor(b!) lu) -> (Tensor(a!) solution, Tensor(b!) LU)
   matches_jit_signature: True
 
 - func: _solve_helper(Tensor self, Tensor A) -> (Tensor, Tensor)
     CPU: median_cpu
     CUDA: median_cuda
 
-- func: sort(Tensor self, int dim=-1, bool descending=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!))
+- func: sort(Tensor self, int dim=-1, bool descending=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
   matches_jit_signature: True
 
-- func: sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor, Tensor)
+- func: sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)
   matches_jit_signature: True
   variants: method, function
 
   matches_jit_signature: True
   variants: method, function
 
-- func: topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!))
+- func: topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True, *, Tensor(a!) values, Tensor(b!) indices) ->(Tensor(a!) values, Tensor(b!) indices)
   matches_jit_signature: True
 
-- func: topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor, Tensor)
+- func: topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)
   matches_jit_signature: True
   variants: method, function
 
index d774b8b..c6a0c49 100644 (file)
@@ -8,7 +8,7 @@ path = os.path.dirname(os.path.realpath(__file__))
 aten_native_yaml = os.path.join(path, '../aten/src/ATen/native/native_functions.yaml')
 whitelist = [
     'max', 'min', 'median', 'mode', 'kthvalue', 'svd', 'symeig', 'eig',
-    'pstrf', 'qr', 'geqrf',
+    'pstrf', 'qr', 'geqrf', 'solve', 'slogdet', 'sort', 'topk',
 ]
 
 
index 5728dd8..4d9d619 100644 (file)
@@ -7411,6 +7411,41 @@ class _TestTorchMixin(object):
         self.assertEqual(ret1.S, ret[1])
         self.assertEqual(ret1.V, ret[2])
 
+        # test gesv
+        ret = a.gesv(a)
+        self.assertEqual(ret.solution, ret[0])
+        self.assertEqual(ret.LU, ret[1])
+        ret1 = torch.gesv(a, a, out=tuple(ret))
+        self.assertEqual(ret1.solution, ret1[0])
+        self.assertEqual(ret1.LU, ret1[1])
+        self.assertEqual(ret1.solution, ret[0])
+        self.assertEqual(ret1.LU, ret[1])
+
+        # test slogdet
+        ret = a.slogdet()
+        self.assertEqual(ret.sign, ret[0])
+        self.assertEqual(ret.logabsdet, ret[1])
+
+        # test sort
+        ret = a.sort(dim=0)
+        self.assertEqual(ret.values, ret[0])
+        self.assertEqual(ret.indices, ret[1])
+        ret1 = torch.sort(a, dim=0, out=tuple(ret))
+        self.assertEqual(ret1.values, ret1[0])
+        self.assertEqual(ret1.indices, ret1[1])
+        self.assertEqual(ret1.values, ret[0])
+        self.assertEqual(ret1.indices, ret[1])
+
+        # test topk
+        ret = a.topk(2)
+        self.assertEqual(ret.values, ret[0])
+        self.assertEqual(ret.indices, ret[1])
+        ret1 = torch.topk(a, 2, out=tuple(ret))
+        self.assertEqual(ret1.values, ret1[0])
+        self.assertEqual(ret1.indices, ret1[1])
+        self.assertEqual(ret1.values, ret[0])
+        self.assertEqual(ret1.indices, ret[1])
+
         # test symeig, eig
         fn = ['symeig', 'eig']
         for f in fn:
index b9a10d6..0a460a2 100644 (file)
   self: slice_backward(grad, self.sizes(), dim, start, end, step)
 
 - name: slogdet(Tensor self)
-  self: slogdet_backward(grad, self, result0, result1)
+  self: slogdet_backward(grad, self, sign, logabsdet)
   output_differentiability: [false, true]
 
 - name: solve(Tensor self, Tensor A)
   self: solve_backward_self(grad, self, A)
-  A: solve_backward_A(grad, self, A, result0)
+  A: solve_backward_A(grad, self, A, solution)
 
 - name: sort(Tensor self, int64_t dim, bool descending)
-  self: index_select_backward(grad, dim, result1, self.sizes(), true)
+  self: index_select_backward(grad, dim, indices, self.sizes(), true)
 
 - name: split(Tensor self, int64_t split_size, int64_t dim)
   self: split_backward(grads, split_size, dim, self.sizes(), self.type())
   self: tanh_backward(grad, result)
 
 - name: topk(Tensor self, int64_t k, int64_t dim, bool largest, bool sorted)
-  self: index_select_backward(grad, dim, result1, self.sizes(), true)
+  self: index_select_backward(grad, dim, indices, self.sizes(), true)
 
 - name: trace(Tensor self)
   self: trace_backward(grad, self.sizes())
index 0fd3d5e..3249126 100644 (file)
@@ -2024,18 +2024,18 @@ torch.solve(B, A, out=None) -> (Tensor, Tensor)
 
 This function returns the solution to the system of linear
 equations represented by :math:`AX = B` and the LU factorization of
-A, in order as a tuple `X, LU`.
+A, in order as a namedtuple `solution, LU`.
 
 `LU` contains `L` and `U` factors for LU factorization of `A`.
 
 `torch.solve(B, A)` can take in 2D inputs `B, A` or inputs that are
 batches of 2D matrices. If the inputs are batches, then returns
-batched outputs `X, LU`.
+batched outputs `solution, LU`.
 
 .. note::
 
     Irrespective of the original strides, the returned matrices
-    `X` and `LU` will be transposed, i.e. with strides like
+    `solution` and `LU` will be transposed, i.e. with strides like
     `B.contiguous().transpose(-1, -2).strides()` and
     `A.contiguous().transpose(-1, -2).strides()` respectively.
 
@@ -4353,8 +4353,9 @@ If :attr:`dim` is not given, the last dimension of the `input` is chosen.
 If :attr:`descending` is ``True`` then the elements are sorted in descending
 order by value.
 
-A tuple of (sorted_tensor, sorted_indices) is returned, where the
-sorted_indices are the indices of the elements in the original `input` tensor.
+A namedtuple of (values, indices) is returned, where the `values` are the
+sorted values and `indices` are the indices of the elements in the original
+`input` tensor.
 
 Args:
     input (Tensor): the input tensor
@@ -5018,7 +5019,7 @@ If :attr:`dim` is not given, the last dimension of the `input` is chosen.
 
 If :attr:`largest` is ``False`` then the `k` smallest elements are returned.
 
-A tuple of `(values, indices)` is returned, where the `indices` are the indices
+A namedtuple of `(values, indices)` is returned, where the `indices` are the indices
 of the elements in the original `input` tensor.
 
 The boolean option :attr:`sorted` if ``True``, will make sure that the returned
@@ -5041,7 +5042,7 @@ Example::
     >>> x
     tensor([ 1.,  2.,  3.,  4.,  5.])
     >>> torch.topk(x, 3)
-    (tensor([ 5.,  4.,  3.]), tensor([ 4,  3,  2]))
+    torch.return_types.topk(values=tensor([5., 4., 3.]), indices=tensor([4, 3, 2]))
 """)
 
 add_docstr(torch.trace,
@@ -5800,18 +5801,22 @@ Arguments:
     A (Tensor): The input 2D square tensor
 
 Returns:
-    A tuple containing the sign of the determinant, and the log value of the
-    absolute determinant.
+    A namedtuple (sign, logabsdet) containing the sign of the determinant, and the log
+    value of the absolute determinant.
 
 Example::
 
     >>> A = torch.randn(3, 3)
+    >>> A
+    tensor([[ 0.0032, -0.2239, -1.1219],
+            [-0.6690,  0.1161,  0.4053],
+            [-1.6218, -0.9273, -0.0082]])
     >>> torch.det(A)
-    tensor(-4.8215)
+    tensor(-0.7576)
     >>> torch.logdet(A)
     tensor(nan)
     >>> torch.slogdet(A)
-    (tensor(-1.), tensor(1.5731))
+    torch.return_types.slogdet(sign=tensor(-1.), logabsdet=tensor(-0.2776))
 """)
 
 add_docstr(torch.pinverse,