if (comm_hook_ == nullptr) {
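+ // Note (assumption for clarity): div_factor_ is typically the process-group
+ // size, so pre-dividing gradients here makes the allreduce sum come out as
+ // an average across ranks.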
auto wrapped =
at::native::wrapped_scalar_tensor(double(1.) / div_factor_);
- // Divides while copying into the bucket view to save one scan over
- // all the input parameters.
- at::mul_out(bucket_view, grad, wrapped);
+ if (!grad.requires_grad()) {
+ // Divides while copying into the bucket view to save one scan over
+ // all the input parameters.
+ at::mul_out(bucket_view, grad, wrapped);
+ } else {
+ // If DDP is running with create_graph=True, the gradients themselves
+ // require grad in order to compute higher-order derivatives. However,
+ // DDP does not currently sync these gradients (see
+ // https://github.com/pytorch/pytorch/issues/63812).
+ LOG(WARNING)
+ << "Using DistributedDataParallel with create_graph=True "
+ << "is not well-supported. The higher-order gradient will "
+ << "not be synchronized across ranks, and backpropagation "
+ << "through all_reduce operations will not occur. If you require "
+ << "DDP to work with higher-order gradients for your use case, "
+ << "please ping https://github.com/pytorch/pytorch/issues/63929";
+ auto div_result = at::mul(grad, wrapped);
+ bucket_view.copy_(div_result);
+ }
} else {
bucket_view.copy_(grad);
}
"Only NCCL and GLOO backend support DistributedDataParallel",
)
@skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
+ def test_ddp_create_graph(self):
+ rank = self.rank
+ torch.cuda.set_device(rank)
+ net = torch.nn.parallel.DistributedDataParallel(
+ torch.nn.Linear(1, 1, bias=False).cuda(rank),
+ device_ids=[rank]
+ )
+ inp = torch.randn((2, 1), device=rank)
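+ # Several iterations verify that repeated create_graph=True backwards
+ # through the DDP reducer keep succeeding.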
+ for _ in range(6):
+ loss = net(inp).sum()
+ # Verify DDP works with create_graph=True
+ loss.backward(create_graph=True)
+ # Parameters should still require grad after a backward with create_graph=True.
+ self.assertTrue(
+ all(param.requires_grad for param in net.parameters())
+ )
+
+ @sandcastle_skip_if(
+ BACKEND != "nccl" and BACKEND != "gloo",
+ "Only NCCL and GLOO backend support DistributedDataParallel",
+ )
+ @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
@skip_if_rocm
def test_DistributedDataParallel_non_default_stream(self):
stream = torch.cuda.Stream(self.rank)