Implements the orthogonal parametrization (#62089)
authorlezcano <lezcano-93@hotmail.com>
Mon, 30 Aug 2021 20:10:23 +0000 (13:10 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 30 Aug 2021 20:12:07 +0000 (13:12 -0700)
Summary:
Implements an orthogonal / unitary parametrisation.

It passes the tests and I have trained a couple of models with this implementation, so I believe it should be correct. That said, the implementation is rather subtle. I'm tagging nikitaved and IvanYashchuk as reviewers in case they have comments or see some room for optimisation of the code, in particular of the `forward` function.
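
As a quick sanity check, here is a minimal usage sketch mirroring the example added to the docstring (the printed value is illustrative; it should just be ~0 up to numerical error):

    import torch
    import torch.nn as nn
    from torch.nn.utils import parametrizations

    orth_linear = parametrizations.orthogonal(nn.Linear(20, 40))
    Q = orth_linear.weight  # shape (40, 20); the columns are orthonormal
    print(torch.dist(Q.T @ Q, torch.eye(20)))  # ~0, e.g. tensor(4.9332e-07)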

Fixes https://github.com/pytorch/pytorch/issues/42243

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62089

Reviewed By: ezyang

Differential Revision: D30639063

Pulled By: albanD

fbshipit-source-id: 988664f333ac7a75ce71ba44c8d77b986dff2fe6

docs/source/nn.rst
test/test_nn.py
torch/nn/utils/parametrizations.py
torch/nn/utils/parametrize.py

index 07ce4db..6eca9d4 100644 (file)
@@ -389,6 +389,7 @@ in :func:`torch.nn.utils.parameterize.register_parametrization`.
     :toctree: generated
     :nosignatures:
 
+    parametrizations.orthogonal
     parametrizations.spectral_norm
 
 Utility functions to parametrize Tensors on existing Modules.
@@ -396,7 +397,7 @@ Note that these functions can be used to parametrize a given Parameter
 or Buffer given a specific function that maps from an input space to the
 parametrized space. They are not parameterizations that would transform
 an object into a parameter. See the
-`Parametrizations <https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html>`__ tutorial
+`Parametrizations tutorial <https://pytorch.org/tutorials/intermediate/parametrizations.html>`_
 for more information on how to implement your own parametrizations.
 
 .. autosummary::
index c9815db..c6d0e78 100644 (file)
@@ -4518,6 +4518,139 @@ class TestNN(NNTestCase):
         m = pickle.loads(pickle.dumps(m))
         self.assertIsInstance(m, nn.Linear)
 
+    def test_orthogonal_parametrization(self):
+        # Orthogonal implements 6 algorithms (3 parametrizations times 2 options of use_trivialization)
+
+        def assert_is_orthogonal(X):
+            n, k = X.size(-2), X.size(-1)
+            if n < k:
+                X = X.transpose(-2, -1)
+                n, k = k, n
+            Id = torch.eye(k, dtype=X.dtype, device=X.device).expand(*(X.size()[:-2]), k, k)
+            eps = 10 * n * torch.finfo(X.dtype).eps
+            torch.testing.assert_allclose(X.transpose(-2, -1).conj() @ X, Id, atol=eps, rtol=0.)
+
+
+        def assert_weight_allclose_Q(weight, W):
+            # Test that weight is equal to the Q part of the QR decomposition of W
+            # (or of its transpose if the matrix is wide)
+            wide_matrix = W.size(-2) < W.size(-1)
+            if wide_matrix:
+                W = W.transpose(-2, -1)
+            Q, R = torch.linalg.qr(W)
+            Q *= R.diagonal(dim1=-2, dim2=-1).sgn().unsqueeze(-2)
+            if wide_matrix:
+                Q = Q.transpose(-2, -1)
+            torch.testing.assert_allclose(Q, weight, atol=1e-5, rtol=0.)
+
+
+        for shape, dtype, use_linear in product(((4, 4), (5, 3), (3, 5)),  # square / tall / wide
+                                                (torch.float32, torch.complex64),
+                                                (True, False)):
+            # Conv2d does not support complex yet
+            if not use_linear and dtype.is_complex:
+                continue
+
+            if use_linear:
+                input = torch.randn(3, shape[0], dtype=dtype)
+            else:
+                input = torch.randn(2, 2, shape[0] + 2, shape[1] + 1, dtype=dtype)
+
+            for parametrization, use_trivialization in product(("matrix_exp", "cayley", "householder"),
+                                                               (False, True)):
+                # right_inverse for Cayley and matrix_exp not implemented for use_trivialization=False
+                # See Note [right_inverse expm cayley]
+                can_initialize = use_trivialization or parametrization == "householder"
+
+                # We generate them every time to always start with fresh weights
+                if use_linear:
+                    m = nn.Linear(*shape, dtype=dtype)
+                else:
+                    m = nn.Conv2d(2, 3, shape, dtype=dtype)
+
+                # We do not support householder for complex inputs
+                # See Note [Householder complex]
+                w_init = m.weight.clone()
+                if parametrization == "householder" and m.weight.is_complex():
+                    msg = "householder parametrization does not support complex tensors"
+                    with self.assertRaisesRegex(ValueError, msg):
+                        torch.nn.utils.parametrizations.orthogonal(m,
+                                                                   "weight",
+                                                                   parametrization,
+                                                                   use_trivialization=use_trivialization)
+                    continue
+
+                wide_matrix = w_init.size(-2) < w_init.size(-1)
+                torch.nn.utils.parametrizations.orthogonal(m,
+                                                           "weight",
+                                                           parametrization,
+                                                           use_trivialization=use_trivialization)
+                # Forward works as expected
+                self.assertEqual(w_init.shape, m.weight.shape)
+                assert_is_orthogonal(m.weight)
+                if can_initialize:
+                    assert_weight_allclose_Q(m.weight, w_init)
+
+                # Initializing with a given orthogonal matrix works
+                X = torch.randn_like(m.weight)
+                if wide_matrix:
+                    X = X.transpose(-2, -1)
+                w_new = torch.linalg.qr(X).Q
+                if wide_matrix:
+                    w_new = w_new.transpose(-2, -1)
+                if can_initialize:
+                    m.weight = w_new
+                    torch.testing.assert_allclose(w_new, m.weight, atol=1e-5, rtol=0.)
+                else:
+                    msg = "assign to the matrix exponential or the Cayley parametrization"
+                    with self.assertRaisesRegex(NotImplementedError, msg):
+                        m.weight = w_new
+
+                # Initializing with a non-orthogonal matrix makes m.weight be the Q part of the given matrix
+                w_new = torch.randn_like(m.weight)
+                if can_initialize:
+                    m.weight = w_new
+                    assert_weight_allclose_Q(m.weight, w_new)
+                else:
+                    msg = "assign to the matrix exponential or the Cayley parametrization"
+                    with self.assertRaisesRegex(NotImplementedError, msg):
+                        m.weight = w_new
+
+                opt = torch.optim.SGD(m.parameters(), lr=0.1)
+                for _ in range(2):
+                    opt.zero_grad()
+                    m(input).norm().backward()
+                    grad = m.parametrizations.weight.original.grad
+                    self.assertIsNotNone(grad)
+                    # We do not update the strict upper triangular part if the matrix is tall, nor the strict lower triangular part if it is wide
+                    if grad.size(-2) >= grad.size(-1):
+                        zeros_grad = grad.triu(1)
+                    else:
+                        zeros_grad = grad.tril(-1)
+                    self.assertEqual(zeros_grad, torch.zeros_like(zeros_grad))
+                    # The gradient on the diagonal can only be purely imaginary, because a skew-Hermitian
+                    # matrix has a purely imaginary diagonal
+                    diag_grad = grad.diagonal(dim1=-2, dim2=-1)
+                    if grad.is_complex():
+                        diag_grad = diag_grad.real
+                    self.assertEqual(diag_grad, torch.zeros_like(diag_grad))
+                    opt.step()
+                    assert_is_orthogonal(m.weight)
+
+    def test_orthogonal_errors(self):
+        m = nn.Linear(3, 4)
+        with self.assertRaisesRegex(ValueError, "has to be one of"):
+            torch.nn.utils.parametrizations.orthogonal(m, "weight", "foo")
+
+        with self.assertRaisesRegex(ValueError, "Expected a matrix"):
+            torch.nn.utils.parametrizations.orthogonal(m, "bias")
+
+        torch.nn.utils.parametrizations.orthogonal(m, "weight")
+        with self.assertRaisesRegex(ValueError, "matrices of shape"):
+            m.weight = torch.randn(5, 5)
+        torch.nn.utils.parametrize.remove_parametrizations(m, "weight")
+
+
     def test_threshold_int(self):
         x = torch.tensor([-3, -2, -1, 0, 1, 2, 3])
         expected = torch.tensor([99, 99, 99, 99, 1, 2, 3])
index de3d5c7..de67aa8 100644 (file)
+from enum import Enum, auto
+
 import torch
+from torch import Tensor
 from ..utils import parametrize
 from ..modules import Module
 from .. import functional as F
 
 from typing import Optional
 
+
+def _is_orthogonal(Q, eps=None):
+    n, k = Q.size(-2), Q.size(-1)
+    Id = torch.eye(k, dtype=Q.dtype, device=Q.device)
+    # Use a reasonable default eps when none is given, but not too large
+    eps = 10. * n * torch.finfo(Q.dtype).eps if eps is None else eps
+    return torch.allclose(Q.transpose(-2, -1).conj() @ Q, Id, atol=eps)
+
+
+def _make_orthogonal(A):
+    """ Assume that A is a tall matrix.
+    Compute the Q factor s.t. A = QR (A may be complex) and diag(R) is real and non-negative
+    """
+    X, tau = torch.geqrf(A)
+    Q = torch.linalg.householder_product(X, tau)
+    # The diagonal of X is the diagonal of R (which is always real) so we normalise by its signs
+    Q *= X.diagonal(dim1=-2, dim2=-1).sgn().unsqueeze(-2)
+    return Q
+
+
+class _OrthMaps(Enum):
+    matrix_exp = auto()
+    cayley = auto()
+    householder = auto()
+
+
+class _Orthogonal(Module):
+    base: Tensor
+
+    def __init__(self,
+                 weight,
+                 orthogonal_map: _OrthMaps,
+                 *,
+                 use_trivialization=True) -> None:
+        super().__init__()
+
+        # Note [Householder complex]
+        # For complex tensors, it is not possible to compute the tensor `tau` necessary for
+        # linalg.householder_product from the reflectors.
+        # To see this, note that the reflectors have a shape like:
+        # 0 0 0
+        # * 0 0
+        # * * 0
+        # which, for complex matrices, give n(n-1) (real) parameters. Now, you need n^2 parameters
+        # to parametrize the unitary matrices. Saving tau on its own does not work either, because
+        # not every combination of `(A, tau)` gives a unitary matrix, meaning that if we optimise
+        # them as independent tensors we would not maintain the constraint.
+        # An equivalent reasoning holds for rectangular matrices.
+        if weight.is_complex() and orthogonal_map == _OrthMaps.householder:
+            raise ValueError("The householder parametrization does not support complex tensors.")
+
+        self.shape = weight.shape
+        self.orthogonal_map = orthogonal_map
+        if use_trivialization:
+            self.register_buffer("base", None)
+
+    def forward(self, X: torch.Tensor) -> torch.Tensor:
+        n, k = X.size(-2), X.size(-1)
+        transposed = n < k
+        if transposed:
+            X = X.transpose(-2, -1)
+            n, k = k, n
+        # Here n > k and X is a tall matrix
+        if self.orthogonal_map == _OrthMaps.matrix_exp or self.orthogonal_map == _OrthMaps.cayley:
+            # We just need n x k - k(k-1)/2 parameters
+            X = X.tril()
+            if n != k:
+                # Embed into a square matrix
+                X = torch.cat([X, X.new_zeros(n, n - k).expand(*X.shape[:-2], -1, -1)], dim=-1)
+            A = X - X.transpose(-2, -1).conj()
+            # A is skew-symmetric (or skew-hermitian)
+            if self.orthogonal_map == _OrthMaps.matrix_exp:
+                Q = torch.matrix_exp(A)
+            elif self.orthogonal_map == _OrthMaps.cayley:
+                # Computes the Cayley retraction (I+A/2)(I-A/2)^{-1}
+                Id = torch.eye(n, dtype=A.dtype, device=A.device)
+                Q = torch.linalg.solve(torch.add(Id, A, alpha=-0.5), torch.add(Id, A, alpha=0.5))
+            # Q is now orthogonal (or unitary) of size (..., n, n)
+            if n != k:
+                Q = Q[..., :k]
+            # Q is now the size of the X (albeit perhaps transposed)
+        else:
+            # X is real here, as we do not support householder with complex numbers
+            A = X.tril(diagonal=-1)
+            tau = 2. / (1. + (A * A).sum(dim=-2))
+            Q = torch.linalg.householder_product(A, tau)
+            # The diagonal of X is 1's and -1's
+            # We do not want to differentiate through this or update the diagonal of X, hence the casting
+            Q = Q * X.diagonal(dim1=-2, dim2=-1).int().unsqueeze(-2)
+
+        if hasattr(self, "base"):
+            Q = self.base @ Q
+        if transposed:
+            Q = Q.transpose(-2, -1)
+        return Q
+
+    @torch.autograd.no_grad()
+    def right_inverse(self, Q: torch.Tensor) -> torch.Tensor:
+        if Q.shape != self.shape:
+            raise ValueError(f"Expected a matrix or batch of matrices of shape {self.shape}. "
+                             f"Got a tensor of shape {Q.shape}.")
+
+        Q_init = Q
+        n, k = Q.size(-2), Q.size(-1)
+        transpose = n < k
+        if transpose:
+            Q = Q.transpose(-2, -1)
+            n, k = k, n
+
+        # We make sure to always copy Q in every path
+        if not hasattr(self, "base"):
+            # Note [right_inverse expm cayley]
+            # If we do not have use_trivialization=True, we just implement the inverse of the forward
+            # map for the Householder. To see why, think that for the Cayley map,
+            # we would need to find the matrix X \in R^{n x k} such that:
+            # Y = torch.cat([X.tril(), X.new_zeros(n, n - k).expand(*X.shape[:-2], -1, -1)], dim=-1)
+            # A = Y - Y.transpose(-2, -1).conj()
+            # cayley(A)[:, :k]
+            # gives the original tensor. It is not clear how to do this.
+            # Perhaps via some algebraic manipulation involving the QR like that of
+            # Corollary 2.2 in Edelman, Arias and Smith?
+            if self.orthogonal_map == _OrthMaps.cayley or self.orthogonal_map == _OrthMaps.matrix_exp:
+                raise NotImplementedError("It is not possible to assign to the matrix exponential "
+                                          "or the Cayley parametrizations when use_trivialization=False.")
+
+            # If parametrization == _OrthMaps.householder, make Q orthogonal via the QR decomposition.
+            # Here Q is always real because we do not support householder and complex matrices.
+            # See note [Householder complex]
+            A, tau = torch.geqrf(Q)
+            # We want to have a decomposition X = QR with diag(R) > 0, as otherwise we could
+            # decompose an orthogonal matrix Q as Q = (-Q)@(-Id), which is a valid QR decomposition
+            # The diagonal of A is the diagonal of R from the QR decomposition
+            A.diagonal(dim1=-2, dim2=-1).sign_()
+            # Equality with zero is ok because LAPACK returns exactly zero when it does not want
+            # to use a particular reflection
+            A.diagonal(dim1=-2, dim2=-1)[tau == 0.] *= -1
+            return A.transpose(-2, -1) if transpose else A
+        else:
+            if n == k:
+                # We check whether Q is orthogonal
+                if not _is_orthogonal(Q):
+                    Q = _make_orthogonal(Q)
+                else:  # Is orthogonal
+                    Q = Q.clone()
+            else:
+                # Complete Q into a full n x n orthogonal matrix
+                N = torch.randn(*(Q.size()[:-2] + (n, n - k)), dtype=Q.dtype, device=Q.device)
+                Q = torch.cat([Q, N], dim=-1)
+                Q = _make_orthogonal(Q)
+            self.base = Q
+
+            # It is necessary to return the -Id, as we use the diagonal for the
+            # Householder parametrization. Using -Id makes:
+            # householder(torch.zeros(m,n)) == torch.eye(m,n)
+            # Poor man's version of eye_like
+            neg_Id = torch.zeros_like(Q_init)
+            neg_Id.diagonal(dim1=-2, dim2=-1).fill_(-1.)
+            return neg_Id
+
+
+def orthogonal(module: Module,
+               name: str = 'weight',
+               orthogonal_map: Optional[str] = None,
+               *,
+               use_trivialization: bool = True) -> Module:
+    r"""Applies an orthogonal or unitary parametrization to a matrix or a batch of matrices.
+
+    Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, the parametrized
+    matrix :math:`Q \in \mathbb{K}^{m \times n}` is **orthogonal** as
+
+    .. math::
+
+        \begin{align*}
+            Q^{\text{H}}Q &= \mathrm{I}_n \mathrlap{\qquad \text{if }m \geq n}\\
+            QQ^{\text{H}} &= \mathrm{I}_m \mathrlap{\qquad \text{if }m < n}
+        \end{align*}
+
+    where :math:`Q^{\text{H}}` is the conjugate transpose when :math:`Q` is complex
+    and the transpose when :math:`Q` is real-valued, and
+    :math:`\mathrm{I}_n` is the `n`-dimensional identity matrix.
+    In plain words, :math:`Q` will have orthonormal columns whenever :math:`m \geq n`
+    and orthonormal rows otherwise.
+
+    If the tensor has more than two dimensions, we consider it as a batch of matrices of shape `(..., m, n)`.
+
+    The matrix :math:`Q` may be parametrized via three different ``orthogonal_map`` options in terms of the original tensor:
+
+    - ``"matrix_exp"``/``"cayley"``:
+      the :func:`~torch.matrix_exp` :math:`Q = \exp(A)` and the `Cayley map`_
+      :math:`Q = (\mathrm{I}_n + A/2)(\mathrm{I}_n - A/2)^{-1}` are applied to a skew-symmetric
+      :math:`A` to give an orthogonal matrix.
+    - ``"householder"``: computes a product of Householder reflectors
+      (:func:`~torch.linalg.householder_product`).
+
+    ``"matrix_exp"``/``"cayley"`` often make the parametrized weight converge faster than
+    ``"householder"``, but they are slower to compute for very thin or very wide matrices.
+
+    If ``use_trivialization=True`` (default), the parametrization implements the "Dynamic Trivialization Framework",
+    where an extra matrix :math:`B \in \mathbb{K}^{n \times n}` is stored under
+    ``module.parametrizations.weight[0].base``. This helps the
+    convergence of the parametrized layer at the expense of some extra memory use.
+    See `Trivializations for Gradient-Based Optimization on Manifolds`_ .
+
+    Initial value of :math:`Q`:
+    If the original tensor is not parametrized and ``use_trivialization=True`` (default), the initial value
+    of :math:`Q` is that of the original tensor if it is orthogonal (or unitary in the complex case)
+    and it is orthogonalized via the QR decomposition otherwise (see :func:`torch.linalg.qr`).
+    The same happens when it is not parametrized and ``orthogonal_map="householder"``, even when ``use_trivialization=False``.
+    Otherwise, the initial value is the result of the composition of all the registered
+    parametrizations applied to the original tensor.
+
+    .. note::
+        This function is implemented using the parametrization functionality
+        in :func:`~torch.nn.utils.parametrize.register_parametrization`.
+
+
+    .. _`Cayley map`: https://en.wikipedia.org/wiki/Cayley_transform#Matrix_map
+    .. _`Trivializations for Gradient-Based Optimization on Manifolds`: https://arxiv.org/abs/1909.09501
+
+    Args:
+        module (nn.Module): module on which to register the parametrization.
+        name (str, optional): name of the tensor to make orthogonal. Default: ``"weight"``.
+        orthogonal_map (str, optional): One of the following: ``"matrix_exp"``, ``"cayley"``, ``"householder"``.
+            Default: ``"matrix_exp"`` if the matrix is square or complex, ``"householder"`` otherwise.
+        use_trivialization (bool, optional): whether to use the dynamic trivialization framework.
+            Default: ``True``.
+
+    Returns:
+        The original module with an orthogonal parametrization registered to the specified
+        weight
+
+    Example::
+
+        >>> orth_linear = orthogonal(nn.Linear(20, 40))
+        >>> orth_linear
+        ParametrizedLinear(
+        in_features=20, out_features=40, bias=True
+        (parametrizations): ModuleDict(
+            (weight): ParametrizationList(
+            (0): _Orthogonal()
+            )
+        )
+        )
+        >>> Q = orth_linear.weight
+        >>> torch.dist(Q.T @ Q, torch.eye(20))
+        tensor(4.9332e-07)
+    """
+    weight = getattr(module, name, None)
+    if not isinstance(weight, Tensor):
+        raise ValueError(
+            "Module '{}' has no parameter or buffer with name '{}'".format(module, name)
+        )
+
+    # We could implement this for 1-dim tensors as the maps on the sphere
+    # but I believe it'd bite more people than it'd help
+    if weight.ndim < 2:
+        raise ValueError("Expected a matrix or batch of matrices. "
+                         f"Got a tensor of {weight.ndim} dimensions.")
+
+    if orthogonal_map is None:
+        orthogonal_map = "matrix_exp" if weight.size(-2) == weight.size(-1) or weight.is_complex() else "householder"
+
+    orth_enum = getattr(_OrthMaps, orthogonal_map, None)
+    if orth_enum is None:
+        raise ValueError('orthogonal_map has to be one of "matrix_exp", "cayley", "householder". '
+                         f'Got: {orthogonal_map}')
+    orth = _Orthogonal(weight,
+                       orth_enum,
+                       use_trivialization=use_trivialization)
+    parametrize.register_parametrization(module, name, orth, unsafe=True)
+    return module
+
+
 class _SpectralNorm(Module):
     def __init__(
         self,
@@ -147,8 +423,8 @@ def spectral_norm(module: Module,
     .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957
 
     .. note::
-        This function is implemented using the new parametrization functionality
-        in :func:`torch.nn.utils.parametrize.register_parametrization`. It is a
+        This function is implemented using the parametrization functionality
+        in :func:`~torch.nn.utils.parametrize.register_parametrization`. It is a
         reimplementation of :func:`torch.nn.utils.spectral_norm`.
 
     .. note::
@@ -165,13 +441,13 @@ def spectral_norm(module: Module,
 
     Args:
         module (nn.Module): containing module
-        name (str, optional): name of weight parameter
+        name (str, optional): name of weight parameter. Default: ``"weight"``.
         n_power_iterations (int, optional): number of power iterations to
-            calculate spectral norm
+            calculate spectral norm. Default: ``1``.
         eps (float, optional): epsilon for numerical stability in
-            calculating norms
-        dim (int, optional): dimension corresponding to number of outputs,
-            the default is ``0``, except for modules that are instances of
+            calculating norms. Default: ``1e-12``.
+        dim (int, optional): dimension corresponding to number of outputs.
+            Default: ``0``, except for modules that are instances of
             ConvTranspose{1,2,3}d, when it is ``1``
 
     Returns:
@@ -193,13 +469,11 @@ def spectral_norm(module: Module,
         >>> torch.linalg.matrix_norm(snm.weight, 2)
         tensor(1.0000, grad_fn=<CopyBackwards>)
     """
-    if not hasattr(module, name):
+    weight = getattr(module, name, None)
+    if not isinstance(weight, Tensor):
         raise ValueError(
-            "Module '{}' has no attribute with name '{}'".format(module, name)
+            "Module '{}' has no parameter or buffer with name '{}'".format(module, name)
         )
-    # getattr should get the correct parametrized weight if there
-    # is already an parametrization registered
-    weight = getattr(module, name)
 
     if dim is None:
         if isinstance(module, (torch.nn.ConvTranspose1d,
index 332fe76..d8f2a94 100644 (file)
@@ -129,8 +129,11 @@ class ParametrizationList(ModuleList):
             new = original
             for module in reversed(self):  # type: ignore[call-overload]
                 if hasattr(module, "right_inverse"):
-                    new = module.right_inverse(new)
-                # else, we assume that right_inverse is the identity
+                    try:
+                        new = module.right_inverse(new)
+                    except NotImplementedError:
+                        pass
+                # else, or if it throws, we assume that right_inverse is the identity
 
         if not isinstance(new, Tensor) and not isinstance(new, collections.abc.Sequence):
             raise ValueError("'right_inverse' must return a Tensor or a Sequence of tensors (list, tuple...). "
@@ -209,7 +212,9 @@ class ParametrizationList(ModuleList):
             for module in reversed(self):  # type: ignore[call-overload]
                 if hasattr(module, "right_inverse"):
                     value = module.right_inverse(value)
-                # else we assume that right_inverse is the identity
+                else:
+                    raise RuntimeError(f"parametrization {type(module).__name__} does not implement "
+                                       "right_inverse.")
             if self.is_tensor:
                 # These exceptions should only throw when a right_inverse function does not
                 # return the same dtype for every input, which should most likely be caused by a bug
@@ -372,16 +377,12 @@ def register_parametrization(
 
         def right_inverse(self, X: Tensor) -> Union[Tensor, Sequence[Tensor]]
 
-    If this method is not implemented, it defaults to the identity.
     This method is called on the unparametrized tensor when the first parametrization
-    is registered.
+    is registered to compute the initial value of the original tensor.
+    If this method is not implemented, the original tensor will just be the unparametrized tensor.
 
-    In most situations, ``right_inverse`` will be a function such that
-    ``forward(right_inverse(X)) == X`` (see
-    `right inverse <https://en.wikipedia.org/wiki/Inverse_function#Right_inverses>`_).
-    Sometimes, when the parametrization is not surjective, it may be reasonable
-    to relax this.
-    This may be used to initialize the tensor, as shown in the example below.
+    If all the parametrizations registered on a tensor implement ``right_inverse``, it is possible
+    to initialize a parametrized tensor by assigning to it, as shown in the example below.
 
     It is possible for the first parametrization to depend on several inputs.
     This may be implemented returning a tuple of tensors from ``right_inverse``
@@ -397,6 +398,14 @@ def register_parametrization(
         If unsafe=True, then right_inverse will be called if the tensor is not parametrized,
         and nothing will be called otherwise.
 
+    .. note::
+
+        In most situations, ``right_inverse`` will be a function such that
+        ``forward(right_inverse(X)) == X`` (see
+        `right inverse <https://en.wikipedia.org/wiki/Inverse_function#Right_inverses>`_).
+        Sometimes, when the parametrization is not surjective, it may be reasonable
+        to relax this.
+
     .. warning::
 
         If a parametrization depends on several inputs, :func:`~register_parametrization`
@@ -483,25 +492,29 @@ def register_parametrization(
                     f"parametrization(module.{tensor_name}).shape: {X.shape}"
                 )
             if hasattr(parametrization, "right_inverse"):
-                Z = parametrization.right_inverse(X)  # type: ignore[operator]
-                if not isinstance(Z, Tensor):
-                    raise ValueError(
-                        f"parametrization.right_inverse must return a tensor. Got: {type(Z).__name__}"
-                    )
-                if Z.dtype != Y.dtype:
-                    raise ValueError(
-                        "The tensor returned by parametrization.right_inverse must have the same dtype "
-                        f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n"
-                        f"module.{tensor_name}.dtype: {Y.dtype}\n"
-                        f"returned dtype: {Z.dtype}"
-                    )
-                if Z.shape != Y.shape:
-                    raise ValueError(
-                        "The tensor returned by parametrization.right_inverse must have the same shape "
-                        f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n"
-                        f"module.{tensor_name}.shape: {Y.shape}\n"
-                        f"returned shape: {Z.shape}"
-                    )
+                try:
+                    Z = parametrization.right_inverse(X)  # type: ignore[operator]
+                except NotImplementedError:
+                    pass
+                else:
+                    if not isinstance(Z, Tensor):
+                        raise ValueError(
+                            f"parametrization.right_inverse must return a tensor. Got: {type(Z).__name__}"
+                        )
+                    if Z.dtype != Y.dtype:
+                        raise ValueError(
+                            "The tensor returned by parametrization.right_inverse must have the same dtype "
+                            f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n"
+                            f"module.{tensor_name}.dtype: {Y.dtype}\n"
+                            f"returned dtype: {Z.dtype}"
+                        )
+                    if Z.shape != Y.shape:
+                        raise ValueError(
+                            "The tensor returned by parametrization.right_inverse must have the same shape "
+                            f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n"
+                            f"module.{tensor_name}.shape: {Y.shape}\n"
+                            f"returned shape: {Z.shape}"
+                        )
             # else right_inverse is assumed to be the identity
 
         # add the new parametrization to the parametrization list