Do not modify saved variables in-place for spectral norm during power iteration ...
author     soulitzer <soulitzer@gmail.com>
           Tue, 24 Aug 2021 20:02:27 +0000 (13:02 -0700)
committer  Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
           Tue, 24 Aug 2021 20:08:59 +0000 (13:08 -0700)
Summary:
Interestingly enough, the original code did have a mechanism that aims to prevent this very issue, but it performs the clone AFTER modifying u and v in-place.
That doesn't work because the cloned u and v are later used in operations that save them for backward, and on the next forward pass the power iteration modifies those same cloned tensors in-place, invalidating what was saved.
So, if the idea is to avoid modifying saved variables in-place, we should clone u and v BEFORE they can be modified in-place again, i.e. clone them in forward() right before computing sigma, so the tensors saved for backward are never touched by the next power iteration.
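
To illustrate the failure mode, here is a minimal sketch (not taken from the PR; the
tensor names are made up) of how autograd's saved-variable check reacts when a tensor
saved for backward is later mutated in-place, and why cloning before the mutation
avoids the error:

import torch

# A tensor saved for backward is mutated in-place afterwards: backward() fails.
w = torch.randn(3, requires_grad=True)
u = torch.randn(3)
out = torch.dot(u, w)        # dot() saves `u` to compute w.grad
u.copy_(torch.randn(3))      # in-place update bumps `u`'s version counter
# out.backward()             # RuntimeError: one of the variables needed for
                             # gradient computation has been modified by an
                             # inplace operation

# Cloning BEFORE the in-place update keeps the saved tensor intact.
w2 = torch.randn(3, requires_grad=True)
u2 = torch.randn(3)
u_clone = u2.clone()
out2 = torch.dot(u_clone, w2)  # the clone is what gets saved for backward
u2.copy_(torch.randn(3))       # mutating `u2` no longer touches the graph
out2.backward()                # works; w2.grad equals the unchanged clone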

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62293

Reviewed By: bdhirsh

Differential Revision: D30489750

Pulled By: soulitzer

fbshipit-source-id: cbe8dea885aef97adda8481f7a822e5bd91f7889

test/test_nn.py
torch/nn/utils/parametrizations.py

diff --git a/test/test_nn.py b/test/test_nn.py
index 07a2b48..43e105a 100644
@@ -4220,6 +4220,9 @@ class TestNN(NNTestCase):
                     out1 = wrapped_m(input)
                     return out0 + out1
 
+                # Make sure we can compute gradients wrt all the parameters in the case
+                # of double forward
+                fn(input.clone().requires_grad_()).sum().backward()
                 gradcheck(fn, (input.clone().requires_grad_(),), check_batched_grad=False)
 
                 # test removing
diff --git a/torch/nn/utils/parametrizations.py b/torch/nn/utils/parametrizations.py
index 7941f41..de3d5c7 100644
@@ -84,6 +84,7 @@ class _SpectralNorm(Module):
 
         # Precondition
         assert weight_mat.ndim > 1
+
         for _ in range(n_power_iterations):
             # Spectral norm of weight equals to `u^T W v`, where `u` and `v`
             # are the first left and right singular vectors.
@@ -92,9 +93,6 @@ class _SpectralNorm(Module):
                                   dim=0, eps=self.eps, out=self._u)   # type: ignore[has-type]
             self._v = F.normalize(torch.mv(weight_mat.t(), self._u),
                                   dim=0, eps=self.eps, out=self._v)   # type: ignore[has-type]
-        # See above on why we need to clone
-        self._u = self._u.clone(memory_format=torch.contiguous_format)
-        self._v = self._v.clone(memory_format=torch.contiguous_format)
 
     def forward(self, weight: torch.Tensor) -> torch.Tensor:
         if weight.ndim == 1:
@@ -104,10 +102,13 @@ class _SpectralNorm(Module):
             weight_mat = self._reshape_weight_to_matrix(weight)
             if self.training:
                 self._power_method(weight_mat, self.n_power_iterations)
+            # See above on why we need to clone
+            u = self._u.clone(memory_format=torch.contiguous_format)
+            v = self._v.clone(memory_format=torch.contiguous_format)
             # The proper way of computing this should be through F.bilinear, but
             # it seems to have some efficiency issues:
             # https://github.com/pytorch/pytorch/issues/58093
-            sigma = torch.dot(self._u, torch.mv(weight_mat, self._v))
+            sigma = torch.dot(u, torch.mv(weight_mat, v))
             return weight / sigma
 
     def right_inverse(self, value: torch.Tensor) -> torch.Tensor:
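
For context, the double-forward scenario exercised by the new test can be reproduced
standalone. A rough sketch (layer sizes and variable names here are illustrative, not
copied from test_nn.py):

import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import spectral_norm

# Each forward pass re-runs the power iteration, which previously mutated in-place
# the u/v tensors that the earlier forward pass had saved for backward.
m = spectral_norm(nn.Linear(5, 7))
x = torch.randn(3, 5, requires_grad=True)

out = m(x) + m(x)      # two forward passes before backward
out.sum().backward()   # with this change, no "modified by an inplace operation" error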