[CUDA graphs] hotfix for test_graph_ (#64339)
authorMichael Carilli <mcarilli@gmail.com>
Wed, 1 Sep 2021 04:43:25 +0000 (21:43 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 1 Sep 2021 05:34:10 +0000 (22:34 -0700)
Summary:
Graphed workloads that try to capture a full backward pass must do warmup on a non-default stream. If warmup happens on the default stream, AccumulateGrad functions might tag themselves to run on the default stream, and therefore won't be capturable.

ngimel and I suspect some test_cuda.py tests run with the default stream as the ambient stream, which breaks `test_graph_grad_scaling` because `test_graph_grad_scaling` does warmup on the ambient stream _assuming_ the ambient stream is a non-default stream.

This PR explicitly sets a side stream for the warmup in `test_graph_grad_scaling`, which is what I should have done all along because it's what the new documentation recommends.

I pushed the PR branch straight to the main pytorch repo because we need to run ci-all on it, and I'm not sure what the requirements are these days.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64339

Reviewed By: mruberry

Differential Revision: D30690711

Pulled By: ngimel

fbshipit-source-id: 91ad75f46a11f311e25bc468ea184e22acdcc25a

test/test_cuda.py

index 6f742ec..33dbade 100644 (file)
@@ -3683,8 +3683,13 @@ torch.cuda.synchronize()
         static_grad = torch.ones_like(weight)
 
         # warmup
-        loss = (weight.half() * static_input).sum()
-        scaler.scale(loss).backward()
+        s = torch.cuda.Stream()
+        s.wait_stream(torch.cuda.current_stream())
+        with torch.cuda.stream(s):
+            loss = (weight.half() * static_input).sum()
+            scaler.scale(loss).backward()
+        torch.cuda.current_stream().wait_stream(s)
+
         opt.zero_grad(set_to_none=True)
 
         # capture