Namedtuple return for symeig, eig, pstrf, qr, geqrf (#16950)
authorXiang Gao <qasdfgtyuiop@gmail.com>
Wed, 20 Feb 2019 21:47:50 +0000 (13:47 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 20 Feb 2019 22:01:19 +0000 (14:01 -0800)
Summary: More ops for https://github.com/pytorch/pytorch/issues/394

Differential Revision: D14118645

Pulled By: ezyang

fbshipit-source-id: a98646c3ddcbe4e34452aa044951286dcf9df778

aten/src/ATen/native/native_functions.yaml
aten/src/ATen/native_parse.py
test/test_namedtuple_return_api.py
test/test_torch.py
tools/autograd/derivatives.yaml
torch/_torch_docs.py

index 8ada134..a2f0984 100644 (file)
   matches_jit_signature: True
   variants: method, function
 
-- func: symeig(Tensor self, bool eigenvectors=False, bool upper=True, *, Tensor(a!) e, Tensor(b!) V) ->(Tensor(a!), Tensor(b!))
+- func: symeig(Tensor self, bool eigenvectors=False, bool upper=True, *, Tensor(a!) e, Tensor(b!) V) ->(Tensor(a!) eigenvalues, Tensor(b!) eigenvectors)
 
-- func: symeig(Tensor self, bool eigenvectors=False, bool upper=True) -> (Tensor, Tensor)
+- func: symeig(Tensor self, bool eigenvectors=False, bool upper=True) -> (Tensor eigenvalues, Tensor eigenvectors)
   matches_jit_signature: True
   variants: method, function
 
-- func: eig(Tensor self, bool eigenvectors=False, *, Tensor(a!) e, Tensor(b!) v) ->(Tensor(a!), Tensor(b!))
+- func: eig(Tensor self, bool eigenvectors=False, *, Tensor(a!) e, Tensor(b!) v) ->(Tensor(a!) eigenvalues, Tensor(b!) eigenvectors)
 
-- func: eig(Tensor self, bool eigenvectors=False) -> (Tensor, Tensor)
+- func: eig(Tensor self, bool eigenvectors=False) -> (Tensor eigenvalues, Tensor eigenvectors)
   matches_jit_signature: True
   variants: method, function
 
   matches_jit_signature: True
   variants: method, function
 
-- func: pstrf(Tensor self, bool upper=True, Scalar tol=-1, *, Tensor(a!) u, Tensor(b!) piv) ->(Tensor(a!), Tensor(b!))
+- func: pstrf(Tensor self, bool upper=True, Scalar tol=-1, *, Tensor(a!) u, Tensor(b!) pivot) ->(Tensor(a!) u, Tensor(b!) pivot)
 
-- func: pstrf(Tensor self, bool upper=True, Scalar tol=-1) -> (Tensor, Tensor)
+- func: pstrf(Tensor self, bool upper=True, Scalar tol=-1) -> (Tensor u, Tensor pivot)
   matches_jit_signature: True
   variants: method, function
 
-- func: qr(Tensor self, *, Tensor(a!) Q, Tensor(b!) R) ->(Tensor(a!), Tensor(b!))
+- func: qr(Tensor self, *, Tensor(a!) Q, Tensor(b!) R) ->(Tensor(a!) Q, Tensor(b!) R)
 
-- func: qr(Tensor self) -> (Tensor, Tensor)
+- func: qr(Tensor self) -> (Tensor Q, Tensor R)
   matches_jit_signature: True
   variants: method, function
 
-- func: geqrf(Tensor self, *, Tensor(a!) out0, Tensor(b!) out1) ->(Tensor(a!), Tensor(b!))
+- func: geqrf(Tensor self, *, Tensor(a!) a, Tensor(b!) tau) ->(Tensor(a!) a, Tensor(b!) tau)
 
-- func: geqrf(Tensor self) -> (Tensor, Tensor)
+- func: geqrf(Tensor self) -> (Tensor a, Tensor tau)
   matches_jit_signature: True
   variants: method, function
 
index 21555d7..5b31119 100644 (file)
@@ -209,7 +209,8 @@ def parse_arguments(args, func_variants, declaration, func_return):
                 assert argument['annotation'] == func_return[arg_idx]['annotation'], \
                     "Inplace function annotations of function {} need to match between " \
                     "input and correponding output.".format(name)
-                assert argument['name'] == func_return[arg_idx]['name']
+                assert argument['name'] == func_return[arg_idx]['name'] or \
+                    argument['name'] == func_return[arg_idx]['name'] + "_return"
                 assert argument['type'] == func_return[arg_idx]['type']
         assert found_self, "Inplace function \"{}\" needs Tensor argument named self.".format(name)
 
@@ -226,7 +227,12 @@ def parse_return_arguments(return_decl, inplace, func_decl):
 
     for arg_idx, arg in enumerate(return_decl.split(', ')):
         t, name, default, nullable, size, annotation = type_argument_translations(arg)
-        argument_dict = {'type': t, 'name': name, 'annotation': annotation}
+        # name of arguments and name of return sometimes have collision
+        # in this case, we rename the return name to <name>_return.
+        return_name = name
+        if name in func_decl['func'].split('->')[0]:
+            return_name = name + "_return"
+        argument_dict = {'type': t, 'name': return_name, 'annotation': annotation}
         if name:
             # See Note [field_name versus name]
             argument_dict['field_name'] = name
index 9ca22c8..d774b8b 100644 (file)
@@ -7,8 +7,8 @@ import unittest
 path = os.path.dirname(os.path.realpath(__file__))
 aten_native_yaml = os.path.join(path, '../aten/src/ATen/native/native_functions.yaml')
 whitelist = [
-    'max', 'max_out', 'min', 'min_out', 'median', 'median_out',
-    'mode', 'mode_out', 'kthvalue', 'kthvalue_out', 'svd', 'svd_out',
+    'max', 'min', 'median', 'mode', 'kthvalue', 'svd', 'symeig', 'eig',
+    'pstrf', 'qr', 'geqrf',
 ]
 
 
index 3d95d82..058941d 100644 (file)
@@ -7141,6 +7141,48 @@ class _TestTorchMixin(object):
         self.assertEqual(ret1.S, ret[1])
         self.assertEqual(ret1.V, ret[2])
 
+        # test symeig, eig
+        fn = ['symeig', 'eig']
+        for f in fn:
+            ret = getattr(torch, f)(a, eigenvectors=True)
+            self.assertEqual(ret.eigenvalues, ret[0])
+            self.assertEqual(ret.eigenvectors, ret[1])
+            ret1 = getattr(torch, f)(a, out=tuple(ret))
+            self.assertEqual(ret1.eigenvalues, ret[0])
+            self.assertEqual(ret1.eigenvectors, ret[1])
+            self.assertEqual(ret1.eigenvalues, ret1[0])
+            self.assertEqual(ret1.eigenvectors, ret1[1])
+
+        # test pstrf
+        b = torch.mm(a, a.t())
+        # add a small number to the diagonal to make the matrix numerically positive semidefinite
+        for i in range(a.size(0)):
+            b[i][i] = b[i][i] + 1e-7
+        ret = b.pstrf()
+        self.assertEqual(ret.u, ret[0])
+        self.assertEqual(ret.pivot, ret[1])
+        ret1 = torch.pstrf(b, out=tuple(ret))
+        self.assertEqual(ret1.u, ret1[0])
+        self.assertEqual(ret1.pivot, ret1[1])
+        self.assertEqual(ret1.u, ret[0])
+        self.assertEqual(ret1.pivot, ret[1])
+
+        # test qr
+        ret = a.qr()
+        self.assertEqual(ret.Q, ret[0])
+        self.assertEqual(ret.R, ret[1])
+        ret1 = torch.qr(a, out=tuple(ret))
+        self.assertEqual(ret1.Q, ret1[0])
+        self.assertEqual(ret1.R, ret1[1])
+
+        # test geqrf
+        ret = a.geqrf()
+        self.assertEqual(ret.a, ret[0])
+        self.assertEqual(ret.tau, ret[1])
+        ret1 = torch.geqrf(a, out=tuple(ret))
+        self.assertEqual(ret1.a, ret1[0])
+        self.assertEqual(ret1.tau, ret1[1])
+
     def test_hardshrink(self):
         data_original = torch.tensor([1, 0.5, 0.3, 0.6]).view(2, 2)
         float_types = [
index 107c0d7..d690119 100644 (file)
   self: svd_backward(grads, self, some, compute_uv, U, S, V)
 
 - name: symeig(Tensor self, bool eigenvectors, bool upper)
-  self: symeig_backward(grads, self, eigenvectors, upper, result0, result1)
+  self: symeig_backward(grads, self, eigenvectors, upper, eigenvalues, eigenvectors_return)
 
 - name: t(Tensor self)
   self: grad.t()
index cc30354..188f842 100644 (file)
@@ -1555,20 +1555,20 @@ Args:
     out (tuple, optional): the output tensors
 
 Returns:
-    (Tensor, Tensor): A tuple containing
+    (Tensor, Tensor): A namedtuple (eigenvalues, eigenvectors) containing
 
-        - **e** (*Tensor*): Shape :math:`(n \times 2)`. Each row is an eigenvalue of ``a``,
+        - **eigenvalues** (*Tensor*): Shape :math:`(n \times 2)`. Each row is an eigenvalue of ``a``,
           where the first element is the real part and the second element is the imaginary part.
           The eigenvalues are not necessarily ordered.
-        - **v** (*Tensor*): If ``eigenvectors=False``, it's an empty tensor.
+        - **eigenvectors** (*Tensor*): If ``eigenvectors=False``, it's an empty tensor.
           Otherwise, this tensor of shape :math:`(n \times n)` can be used to compute normalized (unit length)
-          eigenvectors of corresponding eigenvalues ``e`` as follows.
-          If the corresponding e[j] is a real number, column v[:, j] is the eigenvector corresponding to
-          eigenvalue e[j].
-          If the corresponding e[j] and e[j + 1] eigenvalues form a complex conjugate pair, then the true eigenvectors
-          can be computed as
-          :math:`\text{eigenvector}[j] = v[:, j] + i \times v[:, j + 1]`,
-          :math:`\text{eigenvector}[j + 1] = v[:, j] - i \times v[:, j + 1]`.
+          eigenvectors of corresponding eigenvalues as follows.
+          If the corresponding `eigenvalues[j]` is a real number, column `eigenvectors[:, j]` is the eigenvector
+          corresponding to `eigenvalues[j]`.
+          If the corresponding `eigenvalues[j]` and `eigenvalues[j + 1]` form a complex conjugate pair, then the
+          true eigenvectors can be computed as
+          :math:`\text{true eigenvector}[j] = eigenvectors[:, j] + i \times eigenvectors[:, j + 1]`,
+          :math:`\text{true eigenvector}[j + 1] = eigenvectors[:, j] - i \times eigenvectors[:, j + 1]`.
 """)
 
 add_docstr(torch.eq,
@@ -1969,7 +1969,8 @@ add_docstr(torch.geqrf,
            r"""
 geqrf(input, out=None) -> (Tensor, Tensor)
 
-This is a low-level function for calling LAPACK directly.
+This is a low-level function for calling LAPACK directly. This function
+returns a namedtuple (a, tau) as defined in `LAPACK documentation for geqrf`_ .
 
 You'll generally want to use :func:`torch.qr` instead.
 
@@ -3627,10 +3628,10 @@ add_docstr(torch.pstrf, r"""
 pstrf(a, upper=True, out=None) -> (Tensor, Tensor)
 
 Computes the pivoted Cholesky decomposition of a positive semidefinite
-matrix :attr:`a`. returns matrices `u` and `piv`.
+matrix :attr:`a`. returns a namedtuple (u, pivot) of matrice.
 
 If :attr:`upper` is ``True`` or not provided, `u` is upper triangular
-such that :math:`a = p^T u^T u p`, with `p` the permutation given by `piv`.
+such that :math:`a = p^T u^T u p`, with `p` the permutation given by `pivot`.
 
 If :attr:`upper` is ``False``, `u` is lower triangular such that
 :math:`a = p^T u u^T p`.
@@ -3638,7 +3639,7 @@ If :attr:`upper` is ``False``, `u` is lower triangular such that
 Args:
     a (Tensor): the input 2-D tensor
     upper (bool, optional): whether to return a upper (default) or lower triangular matrix
-    out (tuple, optional): tuple of `u` and `piv` tensors
+    out (tuple, optional): namedtuple of `u` and `pivot` tensors
 
 Example::
 
@@ -3666,8 +3667,8 @@ add_docstr(torch.qr,
            r"""
 qr(input, out=None) -> (Tensor, Tensor)
 
-Computes the QR decomposition of a matrix :attr:`input`, and returns matrices
-`Q` and `R` such that :math:`\text{input} = Q R`, with :math:`Q` being an
+Computes the QR decomposition of a matrix :attr:`input`, and returns a namedtuple
+(Q, R) of matrices such that :math:`\text{input} = Q R`, with :math:`Q` being an
 orthogonal matrix and :math:`R` being an upper triangular matrix.
 
 This returns the thin (reduced) QR factorization.
@@ -4633,7 +4634,8 @@ add_docstr(torch.symeig,
 symeig(input, eigenvectors=False, upper=True, out=None) -> (Tensor, Tensor)
 
 This function returns eigenvalues and eigenvectors
-of a real symmetric matrix :attr:`input`, represented by a tuple :math:`(e, V)`.
+of a real symmetric matrix :attr:`input`, represented by a namedtuple
+(eigenvalues, eigenvectors).
 
 :attr:`input` and :math:`V` are :math:`(m \times m)` matrices and :math:`e` is a
 :math:`m` dimensional vector.
@@ -4666,11 +4668,11 @@ Args:
     out (tuple, optional): the output tuple of (Tensor, Tensor)
 
 Returns:
-    (Tensor, Tensor): A tuple containing
+    (Tensor, Tensor): A namedtuple (eigenvalues, eigenvectors) containing
 
-        - **e** (*Tensor*): Shape :math:`(m)`. Each element is an eigenvalue of ``input``,
+        - **eigenvalues** (*Tensor*): Shape :math:`(m)`. Each element is an eigenvalue of ``input``,
           The eigenvalues are in ascending order.
-        - **V** (*Tensor*): Shape :math:`(m \times m)`.
+        - **eigenvectors** (*Tensor*): Shape :math:`(m \times m)`.
           If ``eigenvectors=False``, it's a tensor filled with zeros.
           Otherwise, this tensor contains the orthonormal eigenvectors of the ``input``.