out1 = wrapped_m(input)
return out0 + out1
+ # Make sure we can compute gradients wrt all the parameters in the case
+ # of double forward
+ fn(input.clone().requires_grad_()).sum().backward()
gradcheck(fn, (input.clone().requires_grad_(),), check_batched_grad=False)
# test removing
# Precondition
assert weight_mat.ndim > 1
+
for _ in range(n_power_iterations):
# Spectral norm of weight equals `u^T W v`, where `u` and `v`
# are the first left and right singular vectors.
dim=0, eps=self.eps, out=self._u) # type: ignore[has-type]
self._v = F.normalize(torch.mv(weight_mat.t(), self._u),
dim=0, eps=self.eps, out=self._v) # type: ignore[has-type]
- # See above on why we need to clone
- self._u = self._u.clone(memory_format=torch.contiguous_format)
- self._v = self._v.clone(memory_format=torch.contiguous_format)
def forward(self, weight: torch.Tensor) -> torch.Tensor:
if weight.ndim == 1:
weight_mat = self._reshape_weight_to_matrix(weight)
if self.training:
self._power_method(weight_mat, self.n_power_iterations)
+ # See above on why we need to clone
+ u = self._u.clone(memory_format=torch.contiguous_format)
+ v = self._v.clone(memory_format=torch.contiguous_format)
# The proper way of computing this should be through F.bilinear, but
# it seems to have some efficiency issues:
# https://github.com/pytorch/pytorch/issues/58093
- sigma = torch.dot(self._u, torch.mv(weight_mat, self._v))
+ sigma = torch.dot(u, torch.mv(weight_mat, v))
return weight / sigma
def right_inverse(self, value: torch.Tensor) -> torch.Tensor: