From: lezcano
Date: Mon, 30 Aug 2021 20:10:23 +0000 (-0700)
Subject: Implements the orthogonal parametrization (#62089)
X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~589
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=f3e329cbec5f4f32e195bbe3b8b5b4d2b1323128;p=platform%2Fupstream%2Fpytorch.git

Implements the orthogonal parametrization (#62089)

Summary:
Implements an orthogonal / unitary parametrisation. It passes the tests and I have trained a couple of models with this implementation, so I believe it should be somewhat correct.

Now, the implementation is very subtle. I'm tagging nikitaved and IvanYashchuk as reviewers in case they have comments / they see some room for optimisation of the code, in particular of the `forward` function.

Fixes https://github.com/pytorch/pytorch/issues/42243

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62089

Reviewed By: ezyang

Differential Revision: D30639063

Pulled By: albanD

fbshipit-source-id: 988664f333ac7a75ce71ba44c8d77b986dff2fe6
---

diff --git a/docs/source/nn.rst b/docs/source/nn.rst
index 07ce4db..6eca9d4 100644
--- a/docs/source/nn.rst
+++ b/docs/source/nn.rst
@@ -389,6 +389,7 @@ in :func:`torch.nn.utils.parameterize.register_parametrization`.
     :toctree: generated
     :nosignatures:
 
+    parametrizations.orthogonal
     parametrizations.spectral_norm
 
 Utility functions to parametrize Tensors on existing Modules.
@@ -396,7 +397,7 @@ Note that these functions can be used to parametrize a given Parameter
 or Buffer given a specific function that maps from an input space to the
 parametrized space. They are not parametrizations that would transform
 an object into a parameter. See the
-`Parametrizations `__ tutorial
+`Parametrizations tutorial `_
 for more information on how to implement your own parametrizations.
 
 .. autosummary::
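For reference, the API this new docs entry exposes is used as follows (a minimal sketch based on the docstring added in torch/nn/utils/parametrizations.py later in this patch; the layer sizes are arbitrary):

    import torch
    import torch.nn as nn
    from torch.nn.utils.parametrizations import orthogonal

    # Registering the parametrization constrains the weight to have
    # orthonormal columns (the weight of Linear(20, 40) is 40 x 20, i.e. tall).
    layer = orthogonal(nn.Linear(20, 40))
    Q = layer.weight
    print(torch.dist(Q.T @ Q, torch.eye(20)))  # ~1e-7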
diff --git a/test/test_nn.py b/test/test_nn.py
index c9815db..c6d0e78 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -4518,6 +4518,139 @@ class TestNN(NNTestCase):
         m = pickle.loads(pickle.dumps(m))
         self.assertIsInstance(m, nn.Linear)
 
+    def test_orthogonal_parametrization(self):
+        # Orthogonal implements 6 algorithms (3 parametrizations times 2 options of use_trivialization)
+
+        def assert_is_orthogonal(X):
+            n, k = X.size(-2), X.size(-1)
+            if n < k:
+                X = X.transpose(-2, -1)
+                n, k = k, n
+            Id = torch.eye(k, dtype=X.dtype, device=X.device).expand(*(X.size()[:-2]), k, k)
+            eps = 10 * n * torch.finfo(X.dtype).eps
+            torch.testing.assert_allclose(X.transpose(-2, -1).conj() @ X, Id, atol=eps, rtol=0.)
+
+
+        def assert_weight_allclose_Q(weight, W):
+            # Test that weight is equal to the Q part of the QR decomposition of W
+            # (or of its transpose if the matrix is wide)
+            wide_matrix = W.size(-2) < W.size(-1)
+            if wide_matrix:
+                W = W.transpose(-2, -1)
+            Q, R = torch.linalg.qr(W)
+            Q *= R.diagonal(dim1=-2, dim2=-1).sgn().unsqueeze(-2)
+            if wide_matrix:
+                Q = Q.transpose(-2, -1)
+            torch.testing.assert_allclose(Q, weight, atol=1e-5, rtol=0.)
+
+
+        for shape, dtype, use_linear in product(((4, 4), (5, 3), (3, 5)),  # square / tall / wide
+                                                (torch.float32, torch.complex64),
+                                                (True, False)):
+            # Conv2d does not support complex yet
+            if not use_linear and dtype.is_complex:
+                continue
+
+            if use_linear:
+                input = torch.randn(3, shape[0], dtype=dtype)
+            else:
+                input = torch.randn(2, 2, shape[0] + 2, shape[1] + 1, dtype=dtype)
+
+            for parametrization, use_trivialization in product(("matrix_exp", "cayley", "householder"),
+                                                               (False, True)):
+                # right_inverse for Cayley and matrix_exp not implemented for use_trivialization=False
+                # See Note [right_inverse expm cayley]
+                can_initialize = use_trivialization or parametrization == "householder"
+
+                # We generate them every time to always start with fresh weights
+                if use_linear:
+                    m = nn.Linear(*shape, dtype=dtype)
+                else:
+                    m = nn.Conv2d(2, 3, shape, dtype=dtype)
+
+                # We do not support householder for complex inputs
+                # See Note [Householder complex]
+                w_init = m.weight.clone()
+                if parametrization == "householder" and m.weight.is_complex():
+                    msg = "householder parametrization does not support complex tensors"
+                    with self.assertRaisesRegex(ValueError, msg):
+                        torch.nn.utils.parametrizations.orthogonal(m,
+                                                                   "weight",
+                                                                   parametrization,
+                                                                   use_trivialization=use_trivialization)
+                    continue
+
+                wide_matrix = w_init.size(-2) < w_init.size(-1)
+                torch.nn.utils.parametrizations.orthogonal(m,
+                                                           "weight",
+                                                           parametrization,
+                                                           use_trivialization=use_trivialization)
+                # Forwards works as expected
+                self.assertEqual(w_init.shape, m.weight.shape)
+                assert_is_orthogonal(m.weight)
+                if can_initialize:
+                    assert_weight_allclose_Q(m.weight, w_init)
+
+                # Initializing with a given orthogonal matrix works
+                X = torch.randn_like(m.weight)
+                if wide_matrix:
+                    X = X.transpose(-2, -1)
+                w_new = torch.linalg.qr(X).Q
+                if wide_matrix:
+                    w_new = w_new.transpose(-2, -1)
+                if can_initialize:
+                    m.weight = w_new
+                    torch.testing.assert_allclose(w_new, m.weight, atol=1e-5, rtol=0.)
+                else:
+                    msg = "assign to the matrix exponential or the Cayley parametrization"
+                    with self.assertRaisesRegex(NotImplementedError, msg):
+                        m.weight = w_new
+
+                # Initializing with a non-orthogonal matrix makes m.weight be the Q part of the given matrix
+                w_new = torch.randn_like(m.weight)
+                if can_initialize:
+                    m.weight = w_new
+                    assert_weight_allclose_Q(m.weight, w_new)
+                else:
+                    msg = "assign to the matrix exponential or the Cayley parametrization"
+                    with self.assertRaisesRegex(NotImplementedError, msg):
+                        m.weight = w_new
+
+                opt = torch.optim.SGD(m.parameters(), lr=0.1)
+                for _ in range(2):
+                    opt.zero_grad()
+                    m(input).norm().backward()
+                    grad = m.parametrizations.weight.original.grad
+                    self.assertIsNotNone(grad)
+                    # We do not update the upper triangular part if the matrix is tall,
+                    # nor the lower triangular part if it is wide
+                    if grad.size(-2) >= grad.size(-1):
+                        zeros_grad = grad.triu(1)
+                    else:
+                        zeros_grad = grad.tril(-1)
+                    self.assertEqual(zeros_grad, torch.zeros_like(zeros_grad))
+                    # The gradient in the diagonal can only be imaginary because a skew-Hermitian
+                    # matrix has an imaginary diagonal
+                    diag_grad = grad.diagonal(dim1=-2, dim2=-1)
+                    if grad.is_complex():
+                        diag_grad = diag_grad.real
+                    self.assertEqual(diag_grad, torch.zeros_like(diag_grad))
+                    opt.step()
+                    assert_is_orthogonal(m.weight)
+
+    def test_orthogonal_errors(self):
+        m = nn.Linear(3, 4)
+        with self.assertRaisesRegex(ValueError, "has to be one of"):
+            torch.nn.utils.parametrizations.orthogonal(m, "weight", "foo")
+
+        with self.assertRaisesRegex(ValueError, "Expected a matrix"):
+            torch.nn.utils.parametrizations.orthogonal(m, "bias")
+
+        torch.nn.utils.parametrizations.orthogonal(m, "weight")
+        with self.assertRaisesRegex(ValueError, "matrices of shape"):
+            m.weight = torch.randn(5, 5)
+        torch.nn.utils.parametrize.remove_parametrizations(m, "weight")
+
     def test_threshold_int(self):
         x = torch.tensor([-3, -2, -1, 0, 1, 2, 3])
         expected = torch.tensor([99, 99, 99, 99, 1, 2, 3])
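The gradient assertions in the test above hold because, for "matrix_exp" and "cayley", the forward map only reads the lower triangle of the unconstrained tensor and the skew-symmetrisation cancels the diagonal, so autograd returns exact zeros everywhere else. A standalone sketch of this effect (assuming the patch is applied; sizes arbitrary):

    import torch
    import torch.nn as nn
    from torch.nn.utils import parametrizations

    m = parametrizations.orthogonal(nn.Linear(3, 5), "weight", "cayley")
    m(torch.randn(2, 3)).norm().backward()
    g = m.parametrizations.weight.original.grad
    # The weight is 5 x 3 (tall): the diagonal and everything above it get no gradient.
    assert torch.equal(g.triu(), torch.zeros_like(g.triu()))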
diff --git a/torch/nn/utils/parametrizations.py b/torch/nn/utils/parametrizations.py
index de3d5c7..de67aa8 100644
--- a/torch/nn/utils/parametrizations.py
+++ b/torch/nn/utils/parametrizations.py
@@ -1,10 +1,286 @@
+from enum import Enum, auto
+
 import torch
+from torch import Tensor
 from ..utils import parametrize
 from ..modules import Module
 from .. import functional as F
 
 from typing import Optional
 
+
+def _is_orthogonal(Q, eps=None):
+    n, k = Q.size(-2), Q.size(-1)
+    Id = torch.eye(k, dtype=Q.dtype, device=Q.device)
+    # A reasonable eps, but not too large
+    eps = 10. * n * torch.finfo(Q.dtype).eps
+    return torch.allclose(Q.transpose(-2, -1).conj() @ Q, Id, atol=eps)
+
+
+def _make_orthogonal(A):
+    """ Assume that A is a tall matrix.
+    Compute the Q factor s.t. A = QR (A may be complex) and diag(R) is real and non-negative
+    """
+    X, tau = torch.geqrf(A)
+    Q = torch.linalg.householder_product(X, tau)
+    # The diagonal of X is the diagonal of R (which is always real) so we normalise by its signs
+    Q *= X.diagonal(dim1=-2, dim2=-1).sgn().unsqueeze(-2)
+    return Q
+
+
+class _OrthMaps(Enum):
+    matrix_exp = auto()
+    cayley = auto()
+    householder = auto()
+
+
+class _Orthogonal(Module):
+    base: Tensor
+
+    def __init__(self,
+                 weight,
+                 orthogonal_map: _OrthMaps,
+                 *,
+                 use_trivialization=True) -> None:
+        super().__init__()
+
+        # Note [Householder complex]
+        # For complex tensors, it is not possible to compute the tensor `tau` necessary for
+        # linalg.householder_product from the reflectors.
+        # To see this, note that the reflectors have a shape like:
+        # 0 0 0
+        # * 0 0
+        # * * 0
+        # which, for complex matrices, gives n(n-1) (real) parameters. Now, you need n^2 parameters
+        # to parametrize the unitary matrices. Saving tau on its own does not work either, because
+        # not every combination of `(A, tau)` gives a unitary matrix, meaning that if we optimised
+        # them as independent tensors we would not maintain the constraint.
+        # An equivalent reasoning holds for rectangular matrices.
+        if weight.is_complex() and orthogonal_map == _OrthMaps.householder:
+            raise ValueError("The householder parametrization does not support complex tensors.")
+
+        self.shape = weight.shape
+        self.orthogonal_map = orthogonal_map
+        if use_trivialization:
+            self.register_buffer("base", None)
+
+    def forward(self, X: torch.Tensor) -> torch.Tensor:
+        n, k = X.size(-2), X.size(-1)
+        transposed = n < k
+        if transposed:
+            X = X.transpose(-2, -1)
+            n, k = k, n
+        # Here n >= k and X is a tall matrix
+        if self.orthogonal_map == _OrthMaps.matrix_exp or self.orthogonal_map == _OrthMaps.cayley:
+            # We just need n x k - k(k-1)/2 parameters
+            X = X.tril()
+            if n != k:
+                # Embed into a square matrix
+                X = torch.cat([X, X.new_zeros(n, n - k).expand(*X.shape[:-2], -1, -1)], dim=-1)
+            A = X - X.transpose(-2, -1).conj()
+            # A is skew-symmetric (or skew-Hermitian)
+            if self.orthogonal_map == _OrthMaps.matrix_exp:
+                Q = torch.matrix_exp(A)
+            elif self.orthogonal_map == _OrthMaps.cayley:
+                # Computes the Cayley retraction (I+A/2)(I-A/2)^{-1}
+                Id = torch.eye(n, dtype=A.dtype, device=A.device)
+                Q = torch.linalg.solve(torch.add(Id, A, alpha=-0.5), torch.add(Id, A, alpha=0.5))
+            # Q is now orthogonal (or unitary) of size (..., n, n)
+            if n != k:
+                Q = Q[..., :k]
+            # Q is now the size of the X (albeit perhaps transposed)
+        else:
+            # X is real here, as we do not support householder with complex numbers
+            A = X.tril(diagonal=-1)
+            tau = 2. / (1. + (A * A).sum(dim=-2))
+            Q = torch.linalg.householder_product(A, tau)
+            # The diagonal of X is 1's and -1's
+            # We do not want to differentiate through this or update the diagonal of X, hence the casting
+            Q = Q * X.diagonal(dim1=-2, dim2=-1).int().unsqueeze(-2)
+
+        if hasattr(self, "base"):
+            Q = self.base @ Q
+        if transposed:
+            Q = Q.transpose(-2, -1)
+        return Q
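As a sanity check of the branch above (standalone, not part of the patch): the Cayley retraction of any skew-symmetric matrix is orthogonal, since A^T = -A implies Q^T = Q^{-1}:

    import torch

    X = torch.randn(5, 5)
    L = X.tril(-1)
    A = L - L.T                                       # skew-symmetric: A^T = -A
    Id = torch.eye(5)
    # (I - A/2)^{-1} (I + A/2); equal to (I + A/2)(I - A/2)^{-1} since the factors commute
    Q = torch.linalg.solve(Id - A / 2, Id + A / 2)
    print(torch.allclose(Q.T @ Q, Id, atol=1e-6))     # True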
+
+    @torch.autograd.no_grad()
+    def right_inverse(self, Q: torch.Tensor) -> torch.Tensor:
+        if Q.shape != self.shape:
+            raise ValueError(f"Expected a matrix or batch of matrices of shape {self.shape}. "
+                             f"Got a tensor of shape {Q.shape}.")
+
+        Q_init = Q
+        n, k = Q.size(-2), Q.size(-1)
+        transpose = n < k
+        if transpose:
+            Q = Q.transpose(-2, -1)
+            n, k = k, n
+
+        # We make sure to always copy Q in every path
+        if not hasattr(self, "base"):
+            # Note [right_inverse expm cayley]
+            # If we do not have use_trivialization=True, we just implement the inverse of the forward
+            # map for the Householder. To see why, consider that for the Cayley map,
+            # we would need to find the matrix X \in R^{n x k} such that:
+            # Y = torch.cat([X.tril(), X.new_zeros(n, n - k).expand(*X.shape[:-2], -1, -1)], dim=-1)
+            # A = Y - Y.transpose(-2, -1).conj()
+            # cayley(A)[:, :k]
+            # gives the original tensor. It is not clear how to do this.
+            # Perhaps via some algebraic manipulation involving the QR like that of
+            # Corollary 2.2 in Edelman, Arias and Smith?
+            if self.orthogonal_map == _OrthMaps.cayley or self.orthogonal_map == _OrthMaps.matrix_exp:
+                raise NotImplementedError("It is not possible to assign to the matrix exponential "
+                                          "or the Cayley parametrizations when use_trivialization=False.")
+
+            # If parametrization == _OrthMaps.householder, make Q orthogonal via the QR decomposition.
+            # Here Q is always real because we do not support householder with complex matrices.
+            # See Note [Householder complex]
+            A, tau = torch.geqrf(Q)
+            # We want to have a decomposition X = QR with diag(R) > 0, as otherwise we could
+            # decompose an orthogonal matrix Q as Q = (-Q)@(-Id), which is a valid QR decomposition
+            # The diagonal of A is the diagonal of R from the QR decomposition
+            A.diagonal(dim1=-2, dim2=-1).sign_()
+            # Equality with zero is ok because LAPACK returns exactly zero when it does not want
+            # to use a particular reflection
+            A.diagonal(dim1=-2, dim2=-1)[tau == 0.] *= -1
+            return A.transpose(-2, -1) if transpose else A
+        else:
+            if n == k:
+                # We check whether Q is orthogonal
+                if not _is_orthogonal(Q):
+                    Q = _make_orthogonal(Q)
+                else:  # Is orthogonal
+                    Q = Q.clone()
+            else:
+                # Complete Q into a full n x n orthogonal matrix
+                N = torch.randn(*(Q.size()[:-2] + (n, n - k)), dtype=Q.dtype, device=Q.device)
+                Q = torch.cat([Q, N], dim=-1)
+                Q = _make_orthogonal(Q)
+            self.base = Q
+
+            # It is necessary to return the -Id, as we use the diagonal for the
+            # Householder parametrization. Using -Id makes:
+            # householder(torch.zeros(m, n)) == torch.eye(m, n)
+            # Poor man's version of eye_like
+            neg_Id = torch.zeros_like(Q_init)
+            neg_Id.diagonal(dim1=-2, dim2=-1).fill_(-1.)
+            return neg_Id
+
+
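The sign normalisation used by _make_orthogonal and right_inverse pins down the otherwise sign-ambiguous QR decomposition. A standalone illustration (not part of the patch):

    import torch

    W = torch.randn(5, 3)
    Q, R = torch.linalg.qr(W)
    # Flipping the signs of the columns of Q (and rows of R) so that
    # diag(R) >= 0 preserves W == Q @ R and makes the factorisation unique.
    s = R.diagonal(dim1=-2, dim2=-1).sgn()
    Q = Q * s.unsqueeze(-2)
    R = R * s.unsqueeze(-1)
    print(torch.allclose(Q @ R, W, atol=1e-6))        # True
    print((R.diagonal(dim1=-2, dim2=-1) >= 0).all())  # tensor(True)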
+def orthogonal(module: Module,
+               name: str = 'weight',
+               orthogonal_map: Optional[str] = None,
+               *,
+               use_trivialization: bool = True) -> Module:
+    r"""Applies an orthogonal or unitary parametrization to a matrix or a batch of matrices.
+
+    Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, the parametrized
+    matrix :math:`Q \in \mathbb{K}^{m \times n}` is **orthogonal**, in the sense that
+
+    .. math::
+
+        \begin{align*}
+            Q^{\text{H}}Q &= \mathrm{I}_n \mathrlap{\qquad \text{if }m \geq n}\\
+            QQ^{\text{H}} &= \mathrm{I}_m \mathrlap{\qquad \text{if }m < n}
+        \end{align*}
+
+    where :math:`Q^{\text{H}}` is the conjugate transpose when :math:`Q` is complex
+    and the transpose when :math:`Q` is real-valued, and
+    :math:`\mathrm{I}_n` is the `n`-dimensional identity matrix.
+    In plain words, :math:`Q` will have orthonormal columns whenever :math:`m \geq n`
+    and orthonormal rows otherwise.
+
+    If the tensor has more than two dimensions, we consider it as a batch of matrices of shape `(..., m, n)`.
+
+    The matrix :math:`Q` may be parametrized via three different ``orthogonal_map`` options in terms of the original tensor:
+
+    - ``"matrix_exp"``/``"cayley"``:
+      the :func:`~torch.matrix_exp` :math:`Q = \exp(A)` and the `Cayley map`_
+      :math:`Q = (\mathrm{I}_n + A/2)(\mathrm{I}_n - A/2)^{-1}` are applied to a skew-symmetric
+      :math:`A` to give an orthogonal matrix.
+    - ``"householder"``: computes a product of Householder reflectors
+      (:func:`~torch.linalg.householder_product`).
+
+    ``"matrix_exp"``/``"cayley"`` often make the parametrized weight converge faster than
+    ``"householder"``, but they are slower to compute for very thin or very wide matrices.
+
+    If ``use_trivialization=True`` (default), the parametrization implements the "Dynamic Trivialization Framework",
+    where an extra matrix :math:`B \in \mathbb{K}^{n \times n}` is stored under
+    ``module.parametrizations.weight[0].base``. This helps the
+    convergence of the parametrized layer at the expense of some extra memory use.
+    See `Trivializations for Gradient-Based Optimization on Manifolds`_ .
+
+    Initial value of :math:`Q`:
+    If the original tensor is not parametrized and ``use_trivialization=True`` (default), the initial value
+    of :math:`Q` is that of the original tensor if it is orthogonal (or unitary in the complex case)
+    and it is orthogonalized via the QR decomposition otherwise (see :func:`torch.linalg.qr`).
+    The same happens when it is not parametrized and ``orthogonal_map="householder"``, even when ``use_trivialization=False``.
+    Otherwise, the initial value is the result of the composition of all the registered
+    parametrizations applied to the original tensor.
+
+    .. note::
+        This function is implemented using the parametrization functionality
+        in :func:`~torch.nn.utils.parametrize.register_parametrization`.
+
+
+    .. _`Cayley map`: https://en.wikipedia.org/wiki/Cayley_transform#Matrix_map
+    .. _`Trivializations for Gradient-Based Optimization on Manifolds`: https://arxiv.org/abs/1909.09501
+
+    Args:
+        module (nn.Module): module on which to register the parametrization.
+        name (str, optional): name of the tensor to make orthogonal. Default: ``"weight"``.
+        orthogonal_map (str, optional): One of the following: ``"matrix_exp"``, ``"cayley"``, ``"householder"``.
+            Default: ``"matrix_exp"`` if the matrix is square or complex, ``"householder"`` otherwise.
+        use_trivialization (bool, optional): whether to use the dynamic trivialization framework.
+            Default: ``True``.
+
+    Returns:
+        The original module with an orthogonal parametrization registered to the specified
+        weight
+
+    Example::
+
+        >>> orth_linear = orthogonal(nn.Linear(20, 40))
+        >>> orth_linear
+        ParametrizedLinear(
+          in_features=20, out_features=40, bias=True
+          (parametrizations): ModuleDict(
+            (weight): ParametrizationList(
+              (0): _Orthogonal()
+            )
+          )
+        )
+        >>> Q = orth_linear.weight
+        >>> torch.dist(Q.T @ Q, torch.eye(20))
+        tensor(4.9332e-07)
+    """
+    weight = getattr(module, name, None)
+    if not isinstance(weight, Tensor):
+        raise ValueError(
+            "Module '{}' has no parameter or buffer with name '{}'".format(module, name)
+        )
+
+    # We could implement this for 1-dim tensors as the maps on the sphere
+    # but I believe it'd bite more people than it'd help
+    if weight.ndim < 2:
+        raise ValueError("Expected a matrix or batch of matrices. "
+                         f"Got a tensor of {weight.ndim} dimensions.")
+
+    if orthogonal_map is None:
+        orthogonal_map = "matrix_exp" if weight.size(-2) == weight.size(-1) or weight.is_complex() else "householder"
+
+    orth_enum = getattr(_OrthMaps, orthogonal_map, None)
+    if orth_enum is None:
+        raise ValueError('orthogonal_map has to be one of "matrix_exp", "cayley", "householder". '
+                         f'Got: {orthogonal_map}')
+    orth = _Orthogonal(weight,
+                       orth_enum,
+                       use_trivialization=use_trivialization)
+    parametrize.register_parametrization(module, name, orth, unsafe=True)
+    return module
+
+
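The ``base`` buffer mentioned in the docstring can be inspected directly; a small sketch assuming this patch is applied (with ``use_trivialization=True``, ``right_inverse`` orthogonalises the initial weight and stores it there):

    import torch
    import torch.nn as nn
    from torch.nn.utils.parametrizations import orthogonal

    m = orthogonal(nn.Linear(3, 3), orthogonal_map="cayley")  # use_trivialization=True by default
    B = m.parametrizations.weight[0].base
    print(B.shape)                                            # torch.Size([3, 3])
    print(torch.allclose(B.T @ B, torch.eye(3), atol=1e-5))   # True: B itself is orthogonal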
 class _SpectralNorm(Module):
     def __init__(
         self,
@@ -147,8 +423,8 @@ def spectral_norm(module: Module,
 
     .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957
 
     .. note::
-        This function is implemented using the new parametrization functionality
-        in :func:`torch.nn.utils.parametrize.register_parametrization`. It is a
+        This function is implemented using the parametrization functionality
+        in :func:`~torch.nn.utils.parametrize.register_parametrization`. It is a
         reimplementation of :func:`torch.nn.utils.spectral_norm`.
 
     .. note::
@@ -165,13 +441,13 @@ def spectral_norm(module: Module,
 
     Args:
         module (nn.Module): containing module
-        name (str, optional): name of weight parameter
+        name (str, optional): name of weight parameter. Default: ``"weight"``.
         n_power_iterations (int, optional): number of power iterations to
-            calculate spectral norm
+            calculate spectral norm. Default: ``1``.
         eps (float, optional): epsilon for numerical stability in
-            calculating norms
-        dim (int, optional): dimension corresponding to number of outputs,
-            the default is ``0``, except for modules that are instances of
+            calculating norms. Default: ``1e-12``.
+        dim (int, optional): dimension corresponding to number of outputs.
+            Default: ``0``, except for modules that are instances of
             ConvTranspose{1,2,3}d, when it is ``1``
 
     Returns:
@@ -193,13 +469,11 @@ def spectral_norm(module: Module,
         >>> torch.linalg.matrix_norm(snm.weight, 2)
         tensor(1.0000, grad_fn=)
     """
-    if not hasattr(module, name):
+    weight = getattr(module, name, None)
+    if not isinstance(weight, Tensor):
         raise ValueError(
-            "Module '{}' has no attribute with name '{}'".format(module, name)
+            "Module '{}' has no parameter or buffer with name '{}'".format(module, name)
        )
-    # getattr should get the correct parametrized weight if there
-    # is already an parametrization registered
-    weight = getattr(module, name)
 
     if dim is None:
         if isinstance(module, (torch.nn.ConvTranspose1d,
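Together with the parametrize.py changes below, right_inverse is what makes initialisation by assignment work; a minimal sketch assuming this patch is applied:

    import torch
    import torch.nn as nn
    from torch.nn.utils.parametrizations import orthogonal

    m = orthogonal(nn.Linear(4, 4), orthogonal_map="householder")
    Q = torch.linalg.qr(torch.randn(4, 4)).Q       # some orthogonal matrix
    m.weight = Q                                   # routed through right_inverse
    print(torch.allclose(m.weight, Q, atol=1e-5))  # True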
diff --git a/torch/nn/utils/parametrize.py b/torch/nn/utils/parametrize.py
index 332fe76..d8f2a94 100644
--- a/torch/nn/utils/parametrize.py
+++ b/torch/nn/utils/parametrize.py
@@ -129,8 +129,11 @@ class ParametrizationList(ModuleList):
         new = original
         for module in reversed(self):  # type: ignore[call-overload]
             if hasattr(module, "right_inverse"):
-                new = module.right_inverse(new)
-            # else, we assume that right_inverse is the identity
+                try:
+                    new = module.right_inverse(new)
+                except NotImplementedError:
+                    pass
+            # else, or if it throws, we assume that right_inverse is the identity
 
         if not isinstance(new, Tensor) and not isinstance(new, collections.abc.Sequence):
             raise ValueError("'right_inverse' must return a Tensor or a Sequence of tensors (list, tuple...). "
@@ -209,7 +212,9 @@ class ParametrizationList(ModuleList):
         for module in reversed(self):  # type: ignore[call-overload]
             if hasattr(module, "right_inverse"):
                 value = module.right_inverse(value)
-            # else we assume that right_inverse is the identity
+            else:
+                raise RuntimeError(f"parametrization {type(module).__name__} does not implement "
+                                   "right_inverse.")
         if self.is_tensor:
             # These exceptions should only throw when a right_inverse function does not
             # return the same dtype for every input, which should most likely be caused by a bug
@@ -372,16 +377,12 @@ def register_parametrization(
 
             def right_inverse(self, X: Tensor) -> Union[Tensor, Sequence[Tensor]]
 
-    If this method is not implemented, it defaults to the identity.
     This method is called on the unparametrized tensor when the first parametrization
-    is registered.
+    is registered to compute the initial value of the original tensor.
+    If this method is not implemented, the original tensor will just be the unparametrized tensor.
 
-    In most situations, ``right_inverse`` will be a function such that
-    ``forward(right_inverse(X)) == X`` (see
-    `right inverse `_).
-    Sometimes, when the parametrization is not surjective, it may be reasonable
-    to relax this.
-    This may be used to initialize the tensor, as shown in the example below.
+    If all the parametrizations registered on a tensor implement ``right_inverse``, it is possible
+    to initialize a parametrized tensor by assigning to it, as shown in the example below.
 
     It is possible for the first parametrization to depend on several inputs.
     This may be implemented returning a tuple of tensors from ``right_inverse``
@@ -397,6 +398,14 @@ def register_parametrization(
         If unsafe=True, then right_inverse will be called if the tensor is not parametrized,
         and nothing will be called otherwise.
 
+    .. note::
+
+        In most situations, ``right_inverse`` will be a function such that
+        ``forward(right_inverse(X)) == X`` (see
+        `right inverse `_).
+        Sometimes, when the parametrization is not surjective, it may be reasonable
+        to relax this.
+
     .. warning::
 
         If a parametrization depends on several inputs, :func:`~register_parametrization`
@@ -483,25 +492,29 @@ def register_parametrization(
                     f"parametrization(module.{tensor_name}).shape: {X.shape}"
                 )
             if hasattr(parametrization, "right_inverse"):
-                Z = parametrization.right_inverse(X)  # type: ignore[operator]
-                if not isinstance(Z, Tensor):
-                    raise ValueError(
-                        f"parametrization.right_inverse must return a tensor. Got: {type(Z).__name__}"
-                    )
-                if Z.dtype != Y.dtype:
-                    raise ValueError(
-                        "The tensor returned by parametrization.right_inverse must have the same dtype "
-                        f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n"
-                        f"module.{tensor_name}.dtype: {Y.dtype}\n"
-                        f"returned dtype: {Z.dtype}"
-                    )
-                if Z.shape != Y.shape:
-                    raise ValueError(
-                        "The tensor returned by parametrization.right_inverse must have the same shape "
-                        f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n"
-                        f"module.{tensor_name}.shape: {Y.shape}\n"
-                        f"returned shape: {Z.shape}"
-                    )
+                try:
+                    Z = parametrization.right_inverse(X)  # type: ignore[operator]
+                except NotImplementedError:
+                    pass
+                else:
+                    if not isinstance(Z, Tensor):
+                        raise ValueError(
+                            f"parametrization.right_inverse must return a tensor. Got: {type(Z).__name__}"
+                        )
+                    if Z.dtype != Y.dtype:
+                        raise ValueError(
+                            "The tensor returned by parametrization.right_inverse must have the same dtype "
+                            f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n"
+                            f"module.{tensor_name}.dtype: {Y.dtype}\n"
+                            f"returned dtype: {Z.dtype}"
+                        )
+                    if Z.shape != Y.shape:
+                        raise ValueError(
+                            "The tensor returned by parametrization.right_inverse must have the same shape "
+                            f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n"
+                            f"module.{tensor_name}.shape: {Y.shape}\n"
+                            f"returned shape: {Z.shape}"
+                        )
             # else right_inverse is assumed to be the identity
 
             # add the new parametrization to the parametrization list
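A sketch of the behaviour these changes enable (the Symmetric parametrization here is hypothetical, for illustration only): a right_inverse that raises NotImplementedError no longer aborts registration; it is treated as the identity at registration time, while assigning to the parametrized tensor afterwards still raises:

    import torch
    import torch.nn as nn
    import torch.nn.utils.parametrize as parametrize

    class Symmetric(nn.Module):
        def forward(self, X):
            # Mirror the upper triangle to produce a symmetric matrix
            return X.triu() + X.triu(1).transpose(-2, -1)

        def right_inverse(self, A):
            raise NotImplementedError

    m = nn.Linear(3, 3)
    # Registration succeeds: the NotImplementedError is swallowed and the
    # original tensor is used as-is.
    parametrize.register_parametrization(m, "weight", Symmetric())
    print(m.weight)            # a symmetric matrix
    # m.weight = torch.eye(3)  # would raise NotImplementedError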