Fix issue re: DDP and create_graph=True (#63831)
author     Rohan Varma <rvarm1@fb.com>
Thu, 26 Aug 2021 06:48:58 +0000 (23:48 -0700)
committer  Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 26 Aug 2021 06:50:25 +0000 (23:50 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63831

Closes https://github.com/pytorch/pytorch/issues/63812

`at::mul_out` is not supported when `grad` itself requires grad, which is the case when backward is run with `create_graph=True` to compute higher-order derivatives.

In that case, fall back to an out-of-place `at::mul` followed by a copy into the bucket view instead of `at::mul_out`.
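
For context, a minimal standalone Python sketch of the limitation and the fallback (the tensor names below are illustrative stand-ins for the reducer's gradient and bucket view, not actual DDP internals):

    import torch

    # A gradient produced by backward(create_graph=True) is a non-leaf tensor
    # that itself requires grad; simulate one here.
    grad = torch.randn(4, requires_grad=True) * 2.0
    bucket_view = torch.empty(4)
    scale = torch.tensor(0.5)

    try:
        # out= variants do not support autograd when an input requires grad.
        torch.mul(grad, scale, out=bucket_view)
    except RuntimeError as err:
        print(f"mul_out rejected: {err}")

    # Fallback taken by this patch: out-of-place mul, then copy into the view.
    bucket_view.copy_(torch.mul(grad, scale))

The reducer change below does the same thing in C++ inside `mark_variable_ready_dense`, and additionally warns that higher-order gradients are not synchronized across ranks.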
ghstack-source-id: 136614644

Test Plan: Unit test `test_ddp_create_graph` (added in this diff).

Reviewed By: SciPioneer

Differential Revision: D30505573

fbshipit-source-id: 83532b6207b3d80116fcc4dff0e5520d73b3454f

torch/csrc/distributed/c10d/reducer.cpp
torch/testing/_internal/distributed/distributed_test.py

index d91f191..eafc70c 100644 (file)
@@ -377,9 +377,25 @@ void Reducer::mark_variable_ready_dense(size_t variable_index) {
         if (comm_hook_ == nullptr) {
           auto wrapped =
               at::native::wrapped_scalar_tensor(double(1.) / div_factor_);
-          // Divides while copying into the bucket view to save one scan over
-          // all the input parameters.
-          at::mul_out(bucket_view, grad, wrapped);
+          if (!grad.requires_grad()) {
+            // Divides while copying into the bucket view to save one scan over
+            // all the input parameters.
+            at::mul_out(bucket_view, grad, wrapped);
+          } else {
+            // If DDP is running with create_graph=True, gradients require_grad
+            // themselves in order to compute higher order derivatives. However,
+            // DDP will not sync up these gradients currently (see
+            // https://github.com/pytorch/pytorch/issues/63812).
+            LOG(WARNING)
+                << "Using DistributedDataParallel with create_graph=True "
+                << "is not well-supported. The higher-order gradient will "
+                << "not be synchronized across ranks, and backpropagation "
+                << "through all_reduce operations will not occur. If you require "
+                << "DDP to work with higher-order gradients for your use case, "
+                << "please ping https://github.com/pytorch/pytorch/issues/63929";
+            auto div_result = at::mul(grad, wrapped);
+            bucket_view.copy_(div_result);
+          }
         } else {
           bucket_view.copy_(grad);
         }
index f4bc073..333458c 100644 (file)
@@ -3765,6 +3765,28 @@ class DistributedTest:
             "Only NCCL and GLOO backend support DistributedDataParallel",
         )
         @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
+        def test_ddp_create_graph(self):
+            rank = self.rank
+            torch.cuda.set_device(rank)
+            net = torch.nn.parallel.DistributedDataParallel(
+                torch.nn.Linear(1, 1, bias=False).cuda(rank),
+                device_ids=[rank]
+            )
+            inp = torch.randn((2, 1), device=rank)
+            for _ in range(6):
+                loss = net(inp).sum()
+                # Verify DDP works with create_graph=True
+                loss.backward(create_graph=True)
+                # Parameters should still require grad after backward.
+                self.assertTrue(
+                    all(param.requires_grad for param in net.parameters())
+                )
+
+        @sandcastle_skip_if(
+            BACKEND != "nccl" and BACKEND != "gloo",
+            "Only NCCL and GLOO backend support DistributedDataParallel",
+        )
+        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
         @skip_if_rocm
         def test_DistributedDataParallel_non_default_stream(self):
             stream = torch.cuda.Stream(self.rank)