m = pickle.loads(pickle.dumps(m))
self.assertIsInstance(m, nn.Linear)
+ def test_orthogonal_parametrization(self):
+ # Orthogonal implements 6 variants (3 parametrizations x 2 options for use_trivialization)
+
+ def assert_is_orthogonal(X):
+ n, k = X.size(-2), X.size(-1)
+ if n < k:
+ X = X.transpose(-2, -1)
+ n, k = k, n
+ Id = torch.eye(k, dtype=X.dtype, device=X.device).expand(*(X.size()[:-2]), k, k)
+ eps = 10 * n * torch.finfo(X.dtype).eps
+ torch.testing.assert_allclose(X.transpose(-2, -1).conj() @ X, Id, atol=eps, rtol=0.)
+
+
+ def assert_weight_allclose_Q(weight, W):
+ # Test that weight is equal to the Q part of the QR decomposition of W
+ # (or of its transpose if the matrix is wide)
+ wide_matrix = W.size(-2) < W.size(-1)
+ if wide_matrix:
+ W = W.transpose(-2, -1)
+ Q, R = torch.linalg.qr(W)
+ Q *= R.diagonal(dim1=-2, dim2=-1).sgn().unsqueeze(-2)
+ if wide_matrix:
+ Q = Q.transpose(-2, -1)
+ torch.testing.assert_allclose(Q, weight, atol=1e-5, rtol=0.)
+
+
+ for shape, dtype, use_linear in product(((4, 4), (5, 3), (3, 5)),  # square / tall / wide
+ (torch.float32, torch.complex64),
+ (True, False)):
+ # Conv2d does not support complex yet
+ if not use_linear and dtype.is_complex:
+ continue
+
+ if use_linear:
+ input = torch.randn(3, shape[0], dtype=dtype)
+ else:
+ input = torch.randn(2, 2, shape[0] + 2, shape[1] + 1, dtype=dtype)
+
+ for parametrization, use_trivialization in product(("matrix_exp", "cayley", "householder"),
+ (False, True)):
+ # right_inverse for Cayley and matrix_exp not implemented for use_trivialization=False
+ # See Note [right_inverse expm cayley]
+ can_initialize = use_trivialization or parametrization == "householder"
+
+ # We re-create the module every time so that we always start from fresh weights
+ if use_linear:
+ m = nn.Linear(*shape, dtype=dtype)
+ else:
+ m = nn.Conv2d(2, 3, shape, dtype=dtype)
+
+ # We do not support householder for complex inputs
+ # See Note [Householder complex]
+ w_init = m.weight.clone()
+ if parametrization == "householder" and m.weight.is_complex():
+ msg = "householder parametrization does not support complex tensors"
+ with self.assertRaisesRegex(ValueError, msg):
+ torch.nn.utils.parametrizations.orthogonal(m,
+ "weight",
+ parametrization,
+ use_trivialization=use_trivialization)
+ continue
+
+ wide_matrix = w_init.size(-2) < w_init.size(-1)
+ torch.nn.utils.parametrizations.orthogonal(m,
+ "weight",
+ parametrization,
+ use_trivialization=use_trivialization)
+ # Forwards works as expected
+ self.assertEqual(w_init.shape, m.weight.shape)
+ assert_is_orthogonal(m.weight)
+ if can_initialize:
+ assert_weight_allclose_Q(m.weight, w_init)
+
+ # Initializing with a given orthogonal matrix works
+ X = torch.randn_like(m.weight)
+ if wide_matrix:
+ X = X.transpose(-2, -1)
+ w_new = torch.linalg.qr(X).Q
+ if wide_matrix:
+ w_new = w_new.transpose(-2, -1)
+ if can_initialize:
+ m.weight = w_new
+ torch.testing.assert_allclose(w_new, m.weight, atol=1e-5, rtol=0.)
+ else:
+ msg = "assign to the matrix exponential or the Cayley parametrization"
+ with self.assertRaisesRegex(NotImplementedError, msg):
+ m.weight = w_new
+
+ # Initializing with a non-orthogonal matrix sets m.weight to the Q factor of the given matrix
+ w_new = torch.randn_like(m.weight)
+ if can_initialize:
+ m.weight = w_new
+ assert_weight_allclose_Q(m.weight, w_new)
+ else:
+ msg = "assign to the matrix exponential or the Cayley parametrization"
+ with self.assertRaisesRegex(NotImplementedError, msg):
+ m.weight = w_new
+
+ opt = torch.optim.SGD(m.parameters(), lr=0.1)
+ for _ in range(2):
+ opt.zero_grad()
+ m(input).norm().backward()
+ grad = m.parametrizations.weight.original.grad
+ self.assertIsNotNone(grad)
+ # The forward map only reads the lower triangular part when the matrix is tall
+ # (the upper triangular part when it is wide), so the gradient of the complementary triangle is zero
+ if grad.size(-2) >= grad.size(-1):
+ zeros_grad = grad.triu(1)
+ else:
+ zeros_grad = grad.tril(-1)
+ self.assertEqual(zeros_grad, torch.zeros_like(zeros_grad))
+ # The gradient on the diagonal can only be imaginary, because a skew-Hermitian
+ # matrix has a purely imaginary diagonal
+ diag_grad = grad.diagonal(dim1=-2, dim2=-1)
+ if grad.is_complex():
+ diag_grad = diag_grad.real
+ self.assertEqual(diag_grad, torch.zeros_like(diag_grad))
+ opt.step()
+ assert_is_orthogonal(m.weight)
+
+ def test_orthogonal_errors(self):
+ m = nn.Linear(3, 4)
+ with self.assertRaisesRegex(ValueError, "has to be one of"):
+ torch.nn.utils.parametrizations.orthogonal(m, "weight", "foo")
+
+ with self.assertRaisesRegex(ValueError, "Expected a matrix"):
+ torch.nn.utils.parametrizations.orthogonal(m, "bias")
+
+ torch.nn.utils.parametrizations.orthogonal(m, "weight")
+ with self.assertRaisesRegex(ValueError, "matrices of shape"):
+ m.weight = torch.randn(5, 5)
+ torch.nn.utils.parametrize.remove_parametrizations(m, "weight")
+
+
def test_threshold_int(self):
x = torch.tensor([-3, -2, -1, 0, 1, 2, 3])
expected = torch.tensor([99, 99, 99, 99, 1, 2, 3])
+from enum import Enum, auto
+
import torch
+from torch import Tensor
from ..utils import parametrize
from ..modules import Module
from .. import functional as F
from typing import Optional
+
+def _is_orthogonal(Q, eps=None):
+ n, k = Q.size(-2), Q.size(-1)
+ Id = torch.eye(k, dtype=Q.dtype, device=Q.device)
+ if eps is None:
+ # A reasonable default eps, but not too large
+ eps = 10. * n * torch.finfo(Q.dtype).eps
+ return torch.allclose(Q.transpose(-2, -1).conj() @ Q, Id, atol=eps)
+
+
+def _make_orthogonal(A):
+ """ Assume that A is a tall matrix.
+ Compute the Q factor s.t. A = QR (A may be complex) and diag(R) is real and non-negative
+ """
+ X, tau = torch.geqrf(A)
+ Q = torch.linalg.householder_product(X, tau)
+ # The diagonal of X is the diagonal of R (which is always real) so we normalise by its signs
+ Q *= X.diagonal(dim1=-2, dim2=-1).sgn().unsqueeze(-2)
+ return Q
+
+
+class _OrthMaps(Enum):
+ matrix_exp = auto()
+ cayley = auto()
+ householder = auto()
+
+
+class _Orthogonal(Module):
+ base: Tensor
+
+ def __init__(self,
+ weight,
+ orthogonal_map: _OrthMaps,
+ *,
+ use_trivialization=True) -> None:
+ super().__init__()
+
+ # Note [Householder complex]
+ # For complex tensors, it is not possible to compute the tensor `tau` necessary for
+ # linalg.householder_product from the reflectors.
+ # To see this, note that the reflectors have a shape like:
+ # 0 0 0
+ # * 0 0
+ # * * 0
+ # which, for an n x n complex matrix, gives n(n-1)/2 complex, i.e. n(n-1) real, parameters.
+ # However, one needs n^2 real parameters to parametrize the unitary matrices.
+ # Saving tau on its own does not work either, because not every combination of `(A, tau)`
+ # gives a unitary matrix, meaning that if we optimised them as independent tensors
+ # we would not maintain the constraint.
+ # An equivalent reasoning holds for rectangular matrices.
+ if weight.is_complex() and orthogonal_map == _OrthMaps.householder:
+ raise ValueError("The householder parametrization does not support complex tensors.")
+
+ self.shape = weight.shape
+ self.orthogonal_map = orthogonal_map
+ if use_trivialization:
+ self.register_buffer("base", None)
+
+ def forward(self, X: torch.Tensor) -> torch.Tensor:
+ n, k = X.size(-2), X.size(-1)
+ transposed = n < k
+ if transposed:
+ X = X.transpose(-2, -1)
+ n, k = k, n
+ # Here n > k and X is a tall matrix
+ if self.orthogonal_map == _OrthMaps.matrix_exp or self.orthogonal_map == _OrthMaps.cayley:
+ # We just need n x k - k(k-1)/2 parameters
+ X = X.tril()
+ if n != k:
+ # Embed into a square matrix
+ X = torch.cat([X, X.new_zeros(n, n - k).expand(*X.shape[:-2], -1, -1)], dim=-1)
+ A = X - X.transpose(-2, -1).conj()
+ # A is skew-symmetric (or skew-hermitian)
+ if self.orthogonal_map == _OrthMaps.matrix_exp:
+ Q = torch.matrix_exp(A)
+ elif self.orthogonal_map == _OrthMaps.cayley:
+ # Computes the Cayley retraction (I+A/2)(I-A/2)^{-1}
+ Id = torch.eye(n, dtype=A.dtype, device=A.device)
+ Q = torch.linalg.solve(torch.add(Id, A, alpha=-0.5), torch.add(Id, A, alpha=0.5))
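+ # linalg.solve returns (Id - A/2)^{-1} @ (Id + A/2). This equals
+ # (Id + A/2)(Id - A/2)^{-1} because both factors are rational functions
+ # of A and therefore commute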
+ # Q is now orthogonal (or unitary) of size (..., n, n)
+ if n != k:
+ Q = Q[..., :k]
+ # Q is now the size of the X (albeit perhaps transposed)
+ else:
+ # X is real here, as we do not support householder with complex numbers
+ A = X.tril(diagonal=-1)
+ tau = 2. / (1. + (A * A).sum(dim=-2))
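+ # Each reflector is a column of A with an implicit 1 on the diagonal, so
+ # tau = 2 / ||v||^2 = 2 / (1 + sum of squares of the strictly lower entries)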
+ Q = torch.linalg.householder_product(A, tau)
+ # The diagonal of X consists of 1's and -1's.
+ # We do not want to differentiate through it or update the diagonal of X,
+ # hence the cast to int
+ Q = Q * X.diagonal(dim1=-2, dim2=-1).int().unsqueeze(-2)
+
+ if hasattr(self, "base"):
+ Q = self.base @ Q
+ if transposed:
+ Q = Q.transpose(-2, -1)
+ return Q
+
+ @torch.autograd.no_grad()
+ def right_inverse(self, Q: torch.Tensor) -> torch.Tensor:
+ if Q.shape != self.shape:
+ raise ValueError(f"Expected a matrix or batch of matrices of shape {self.shape}. "
+ f"Got a tensor of shape {Q.shape}.")
+
+ Q_init = Q
+ n, k = Q.size(-2), Q.size(-1)
+ transpose = n < k
+ if transpose:
+ Q = Q.transpose(-2, -1)
+ n, k = k, n
+
+ # We make sure to always copy Q in every path
+ if not hasattr(self, "base"):
+ # Note [right_inverse expm cayley]
+ # If we do not have use_trivialization=True, we just implement the inverse of the forward
+ # map for the Householder. To see why, think that for the Cayley map,
+ # we would need to find the matrix X \in R^{n x k} such that:
+ # Y = torch.cat([X.tril(), X.new_zeros(n, n - k).expand(*X.shape[:-2], -1, -1)], dim=-1)
+ # A = Y - Y.transpose(-2, -1).conj()
+ # cayley(A)[:, :k]
+ # gives the original tensor. It is not clear how to do this.
+ # Perhaps via some algebraic manipulation involving the QR like that of
+ # Corollary 2.2 in Edelman, Arias and Smith?
+ if self.orthogonal_map == _OrthMaps.cayley or self.orthogonal_map == _OrthMaps.matrix_exp:
+ raise NotImplementedError("It is not possible to assign to the matrix exponential "
+ "or the Cayley parametrizations when use_trivialization=False.")
+
+ # If parametrization == _OrthMaps.householder, make Q orthogonal via the QR decomposition.
+ # Here Q is always real because we do not support householder and complex matrices.
+ # See note [Householder complex]
+ A, tau = torch.geqrf(Q)
+ # We want the decomposition to satisfy diag(R) > 0, as otherwise we could
+ # decompose an orthogonal matrix Q as Q = (-Q)@(-Id), which is also a valid QR decomposition
+ # The diagonal of A holds the diagonal of R from the QR decomposition
+ A.diagonal(dim1=-2, dim2=-1).sign_()
+ # Equality with zero is ok because LAPACK returns exactly zero when it does not want
+ # to use a particular reflection
+ A.diagonal(dim1=-2, dim2=-1)[tau == 0.] *= -1
+ return A.transpose(-2, -1) if transpose else A
+ else:
+ if n == k:
+ # We check whether Q is orthogonal
+ if not _is_orthogonal(Q):
+ Q = _make_orthogonal(Q)
+ else: # Is orthogonal
+ Q = Q.clone()
+ else:
+ # Complete Q into a full n x n orthogonal matrix
+ N = torch.randn(*(Q.size()[:-2] + (n, n - k)), dtype=Q.dtype, device=Q.device)
+ Q = torch.cat([Q, N], dim=-1)
+ Q = _make_orthogonal(Q)
+ self.base = Q
+
+ # It is necessary to return -Id rather than zeros, as the Householder
+ # parametrization uses the diagonal of its input. With -Id, the orthogonal map
+ # evaluates to torch.eye(m, n), so the initial weight equals `base`
+ # Poor man's version of eye_like
+ neg_Id = torch.zeros_like(Q_init)
+ neg_Id.diagonal(dim1=-2, dim2=-1).fill_(-1.)
+ return neg_Id
+
+
+def orthogonal(module: Module,
+ name: str = 'weight',
+ orthogonal_map: Optional[str] = None,
+ *,
+ use_trivialization: bool = True) -> Module:
+ r"""Applies an orthogonal or unitary parametrization to a matrix or a batch of matrices.
+
+ Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, the parametrized
+ matrix :math:`Q \in \mathbb{K}^{m \times n}` is **orthogonal** as
+
+ .. math::
+
+ \begin{align*}
+ Q^{\text{H}}Q &= \mathrm{I}_n \mathrlap{\qquad \text{if }m \geq n}\\
+ QQ^{\text{H}} &= \mathrm{I}_m \mathrlap{\qquad \text{if }m < n}
+ \end{align*}
+
+ where :math:`Q^{\text{H}}` is the conjugate transpose when :math:`Q` is complex
+ and the transpose when :math:`Q` is real-valued, and
+ :math:`\mathrm{I}_n` is the `n`-dimensional identity matrix.
+ In plain words, :math:`Q` will have orthonormal columns whenever :math:`m \geq n`
+ and orthonormal rows otherwise.
+
+ If the tensor has more than two dimensions, we consider it as a batch of matrices of shape `(..., m, n)`.
+
+ The matrix :math:`Q` may be parametrized via three different ``orthogonal_map`` options in terms of the original tensor:
+
+ - ``"matrix_exp"``/``"cayley"``:
+ the :func:`~torch.matrix_exp` :math:`Q = \exp(A)` and the `Cayley map`_
+ :math:`Q = (\mathrm{I}_n + A/2)(\mathrm{I}_n - A/2)^{-1}` are applied to a skew-symmetric
+ :math:`A` to give an orthogonal matrix.
+ - ``"householder"``: computes a product of Householder reflectors
+ (:func:`~torch.linalg.householder_product`).
+
+ ``"matrix_exp"``/``"cayley"`` often make the parametrized weight converge faster than
+ ``"householder"``, but they are slower to compute for very thin or very wide matrices.
+
+ If ``use_trivialization=True`` (default), the parametrization implements the "Dynamic Trivialization Framework",
+ where an extra matrix :math:`B \in \mathbb{K}^{n \times n}` is stored under
+ ``module.parametrizations.weight[0].base``. This helps the
+ convergence of the parametrized layer at the expense of some extra memory use.
+ See `Trivializations for Gradient-Based Optimization on Manifolds`_ .
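+
+ A sketch of inspecting this extra buffer on a parametrized layer::
+
+ >>> layer = orthogonal(nn.Linear(5, 5))
+ >>> layer.parametrizations.weight[0].base.shape
+ torch.Size([5, 5])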
+
+ Initial value of :math:`Q`:
+ If the original tensor is not parametrized and ``use_trivialization=True`` (default), the initial value
+ of :math:`Q` is that of the original tensor if it is orthogonal (or unitary in the complex case)
+ and it is orthogonalized via the QR decomposition otherwise (see :func:`torch.linalg.qr`).
+ The same happens when it is not parametrized and ``orthogonal_map="householder"``, even when ``use_trivialization=False``.
+ Otherwise, the initial value is the result of the composition of all the registered
+ parametrizations applied to the original tensor.
+
+ .. note::
+ This function is implemented using the parametrization functionality
+ in :func:`~torch.nn.utils.parametrize.register_parametrization`.
+
+
+ .. _`Cayley map`: https://en.wikipedia.org/wiki/Cayley_transform#Matrix_map
+ .. _`Trivializations for Gradient-Based Optimization on Manifolds`: https://arxiv.org/abs/1909.09501
+
+ Args:
+ module (nn.Module): module on which to register the parametrization.
+ name (str, optional): name of the tensor to make orthogonal. Default: ``"weight"``.
+ orthogonal_map (str, optional): One of the following: ``"matrix_exp"``, ``"cayley"``, ``"householder"``.
+ Default: ``"matrix_exp"`` if the matrix is square or complex, ``"householder"`` otherwise.
+ use_trivialization (bool, optional): whether to use the dynamic trivialization framework.
+ Default: ``True``.
+
+ Returns:
+ The original module with an orthogonal parametrization registered to the specified
+ weight.
+
+ Example::
+
+ >>> orth_linear = orthogonal(nn.Linear(20, 40))
+ >>> orth_linear
+ ParametrizedLinear(
+ in_features=20, out_features=40, bias=True
+ (parametrizations): ModuleDict(
+ (weight): ParametrizationList(
+ (0): _Orthogonal()
+ )
+ )
+ )
+ >>> Q = orth_linear.weight
+ >>> torch.dist(Q.T @ Q, torch.eye(20))
+ tensor(4.9332e-07)
+ """
+ weight = getattr(module, name, None)
+ if not isinstance(weight, Tensor):
+ raise ValueError(
+ "Module '{}' has no parameter ot buffer with name '{}'".format(module, name)
+ )
+
+ # We could implement this for 1-dim tensors as the maps on the sphere
+ # but I believe it'd bite more people than it'd help
+ if weight.ndim < 2:
+ raise ValueError("Expected a matrix or batch of matrices. "
+ f"Got a tensor of {weight.ndim} dimensions.")
+
+ if orthogonal_map is None:
+ orthogonal_map = "matrix_exp" if weight.size(-2) == weight.size(-1) or weight.is_complex() else "householder"
+
+ orth_enum = getattr(_OrthMaps, orthogonal_map, None)
+ if orth_enum is None:
+ raise ValueError('orthogonal_map has to be one of "matrix_exp", "cayley", "householder". '
+ f'Got: {orthogonal_map}')
+ orth = _Orthogonal(weight,
+ orth_enum,
+ use_trivialization=use_trivialization)
+ parametrize.register_parametrization(module, name, orth, unsafe=True)
+ return module
+
+
class _SpectralNorm(Module):
def __init__(
self,
.. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957
.. note::
- This function is implemented using the new parametrization functionality
- in :func:`torch.nn.utils.parametrize.register_parametrization`. It is a
+ This function is implemented using the parametrization functionality
+ in :func:`~torch.nn.utils.parametrize.register_parametrization`. It is a
reimplementation of :func:`torch.nn.utils.spectral_norm`.
.. note::
Args:
module (nn.Module): containing module
- name (str, optional): name of weight parameter
+ name (str, optional): name of weight parameter. Default: ``"weight"``.
n_power_iterations (int, optional): number of power iterations to
- calculate spectral norm
+ calculate spectral norm. Default: ``1``.
eps (float, optional): epsilon for numerical stability in
- calculating norms
- dim (int, optional): dimension corresponding to number of outputs,
- the default is ``0``, except for modules that are instances of
+ calculating norms. Default: ``1e-12``.
+ dim (int, optional): dimension corresponding to number of outputs.
+ Default: ``0``, except for modules that are instances of
ConvTranspose{1,2,3}d, when it is ``1``
Returns:
>>> torch.linalg.matrix_norm(snm.weight, 2)
tensor(1.0000, grad_fn=<CopyBackwards>)
"""
- if not hasattr(module, name):
+ weight = getattr(module, name, None)
+ if not isinstance(weight, Tensor):
raise ValueError(
- "Module '{}' has no attribute with name '{}'".format(module, name)
+ "Module '{}' has no parameter or buffer with name '{}'".format(module, name)
)
- # getattr should get the correct parametrized weight if there
- # is already an parametrization registered
- weight = getattr(module, name)
if dim is None:
if isinstance(module, (torch.nn.ConvTranspose1d,
new = original
for module in reversed(self): # type: ignore[call-overload]
if hasattr(module, "right_inverse"):
- new = module.right_inverse(new)
- # else, we assume that right_inverse is the identity
+ try:
+ new = module.right_inverse(new)
+ except NotImplementedError:
+ pass
+ # otherwise, or if it raises NotImplementedError, we assume that right_inverse is the identity
if not isinstance(new, Tensor) and not isinstance(new, collections.abc.Sequence):
raise ValueError("'right_inverse' must return a Tensor or a Sequence of tensors (list, tuple...). "
for module in reversed(self): # type: ignore[call-overload]
if hasattr(module, "right_inverse"):
value = module.right_inverse(value)
- # else we assume that right_inverse is the identity
+ else:
+ raise RuntimeError(f"parametrization {type(module).__name__} does not implement "
+ "right_inverse.")
if self.is_tensor:
# These exceptions should only throw when a right_inverse function does not
# return the same dtype for every input, which should most likely be caused by a bug
def right_inverse(self, X: Tensor) -> Union[Tensor, Sequence[Tensor]]
- If this method is not implemented, it defaults to the identity.
This method is called on the unparametrized tensor when the first parametrization
- is registered.
+ is registered to compute the initial value of the original tensor.
+ If this method is not implemented, the original tensor will just be the unparametrized tensor.
- In most situations, ``right_inverse`` will be a function such that
- ``forward(right_inverse(X)) == X`` (see
- `right inverse <https://en.wikipedia.org/wiki/Inverse_function#Right_inverses>`_).
- Sometimes, when the parametrization is not surjective, it may be reasonable
- to relax this.
- This may be used to initialize the tensor, as shown in the example below.
+ If all the parametrizations registered on a tensor implement ``right_inverse``, it is possible
+ to initialize a parametrized tensor by assigning to it, as shown in the example below.
It is possible for the first parametrization to depend on several inputs.
This may be implemented returning a tuple of tensors from ``right_inverse``
If unsafe=True, then right_inverse will be called if the tensor is not parametrized,
and nothing will be called otherwise.
+ .. note::
+
+ In most situations, ``right_inverse`` will be a function such that
+ ``forward(right_inverse(X)) == X`` (see
+ `right inverse <https://en.wikipedia.org/wiki/Inverse_function#Right_inverses>`_).
+ Sometimes, when the parametrization is not surjective, it may be reasonable
+ to relax this.
+
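+ A minimal sketch of a parametrization where ``forward(right_inverse(X)) == X``
+ for symmetric inputs (the ``Symmetric`` module below is illustrative)::
+
+ >>> class Symmetric(nn.Module):
+ ...     def forward(self, X):
+ ...         return X.triu() + X.triu(1).transpose(-1, -2)
+ ...     def right_inverse(self, A):
+ ...         return A.triu()
+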
.. warning::
If a parametrization depends on several inputs, :func:`~register_parametrization`
f"parametrization(module.{tensor_name}).shape: {X.shape}"
)
if hasattr(parametrization, "right_inverse"):
- Z = parametrization.right_inverse(X) # type: ignore[operator]
- if not isinstance(Z, Tensor):
- raise ValueError(
- f"parametrization.right_inverse must return a tensor. Got: {type(Z).__name__}"
- )
- if Z.dtype != Y.dtype:
- raise ValueError(
- "The tensor returned by parametrization.right_inverse must have the same dtype "
- f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n"
- f"module.{tensor_name}.dtype: {Y.dtype}\n"
- f"returned dtype: {Z.dtype}"
- )
- if Z.shape != Y.shape:
- raise ValueError(
- "The tensor returned by parametrization.right_inverse must have the same shape "
- f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n"
- f"module.{tensor_name}.shape: {Y.shape}\n"
- f"returned shape: {Z.shape}"
- )
+ try:
+ Z = parametrization.right_inverse(X) # type: ignore[operator]
+ except NotImplementedError:
+ pass
+ else:
+ if not isinstance(Z, Tensor):
+ raise ValueError(
+ f"parametrization.right_inverse must return a tensor. Got: {type(Z).__name__}"
+ )
+ if Z.dtype != Y.dtype:
+ raise ValueError(
+ "The tensor returned by parametrization.right_inverse must have the same dtype "
+ f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n"
+ f"module.{tensor_name}.dtype: {Y.dtype}\n"
+ f"returned dtype: {Z.dtype}"
+ )
+ if Z.shape != Y.shape:
+ raise ValueError(
+ "The tensor returned by parametrization.right_inverse must have the same shape "
+ f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n"
+ f"module.{tensor_name}.shape: {Y.shape}\n"
+ f"returned shape: {Z.shape}"
+ )
# else right_inverse is assumed to be the identity
# add the new parametrization to the parametrization list