reenable rand_like fusion when there is no broadcast (#16087)
authorNatalia Gimelshein <ngimelshein@nvidia.com>
Tue, 19 Feb 2019 18:56:44 +0000 (10:56 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 19 Feb 2019 19:12:25 +0000 (11:12 -0800)
Summary:
Reenables rand_like fusion if no tensor is broadcasted in the fusion group. This is a sufficient but not necessary condition for fused rand_like to produce correct results, and it has an unpleasant side effect of falling back to non-fused path if rand_like was optimistically included in the fusion group, but there is a broadcast in the fusion group not necessarily related to rand_like. E.g. before this PR, if the network had (biasAdd -> relu -> dropout), fuser could fuse biasAdd and relu, now it will try fusing the whole thing (if dropout is expressed via rand_like) and fall back every time.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16087

Differential Revision: D13720232

Pulled By: zou3519

fbshipit-source-id: 1e19203bec4a59257bfc7078b054a19f00fab4ad

test/test_jit.py
torch/csrc/jit/autodiff.cpp
torch/csrc/jit/fuser/executor.cpp
torch/csrc/jit/passes/graph_fuser.cpp

index fbb60fa..6e6b4f5 100644 (file)
@@ -11665,8 +11665,6 @@ class TestFuser(JitTestCase):
         (hy + cy).sum().backward()
         self.assertExpectedGraph(backward_graph(module), subname='backward')
 
-    # TODO: At some point we supported fusion of torch.rand_like but not anymore
-    @unittest.expectedFailure
     @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     @skipIfRocm
@@ -11706,20 +11704,41 @@ class TestFuser(JitTestCase):
         ge = self.checkTrace(self.fn_test_relu, (x, y))
         self.assertAllFused(ge.graph_for(x, y))
 
-    @staticmethod
-    def fn_test_erf(x):
-        return F.relu(torch.erf(x) - torch.erfc(x))
-
     @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     @skipIfRocm
     def test_erf_cuda(self):
+        def fn_test_erf(x):
+            return F.relu(torch.erf(x) - torch.erfc(x))
+
         x = torch.randn(4, 4, dtype=torch.float, device='cuda')
-        ge = self.checkTrace(self.fn_test_erf, (x,))
+        ge = self.checkTrace(fn_test_erf, (x,))
         self.assertAllFused(ge.graph_for(x))
         x.requires_grad_(True)
         self.assertAllFused(ge.graph_for(x), except_for=("aten::size", "prim::BroadcastSizes"))
 
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @skipIfRocm
+    def test_rand_broadcast_cuda(self):
+        def fn_test_rand(x, y):
+            r = torch.rand_like(y)
+            return r * x + x
+
+        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
+        y = torch.randn(4, 4, dtype=torch.float, device='cuda')
+        script_f = torch.jit.script(fn_test_rand, (x, y))
+        out = script_f(x, y)
+        self.assertAllFused(script_f.graph_for(x, y))
+        x.requires_grad_(True)
+        out = script_f(x, y)
+        self.assertAllFused(script_f.graph_for(x, y), except_for=("aten::size", "prim::BroadcastSizes"))
+        # test that broadcasting random produces correct results
+        x = torch.ones(4, 4, dtype=torch.float, device='cuda')
+        y = torch.ones(4, dtype=torch.float, device='cuda')
+        out = script_f(x, y)
+        self.assertEqual(out[0], out[1])
+
     @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
     @enable_cpu_fuser
     def test_scalar(self):
index b63f069..5e248f8 100644 (file)
@@ -96,6 +96,7 @@ bool isDifferentiable(Node* n) {
       "aten::log10(Tensor self) -> Tensor",
       "aten::log1p(Tensor self) -> Tensor",
       "aten::log2(Tensor self) -> Tensor",
+      "aten::rand_like(Tensor self) -> Tensor",
       "aten::reciprocal(Tensor self) -> Tensor",
       "aten::remainder(Tensor self, Scalar other) -> Tensor",
       "aten::round(Tensor self) -> Tensor",
@@ -533,6 +534,9 @@ class GradientHelper {
                    "aten::type_as(Tensor self, Tensor other) -> Tensor")) {
       return {grads.at(0).type_as(inputs.at(0)), nullptr};
 
+    } else if (node->matches("aten::rand_like(Tensor self) -> Tensor")) {
+      return {nullptr};
+
     } else if (node->matches(
                    "aten::unsqueeze(Tensor self, int dim) -> Tensor")) {
       return {grads.at(0).squeeze(node->namedInput(attr::dim)), nullptr};
index 2d3cd36..7c153ef 100644 (file)
@@ -97,25 +97,44 @@ static c10::optional<std::vector<int64_t>> canRunKernel(
 // (see above).
 // Note: Arguments are mutated by this call, although map_size is restored
 // to its original value.
-static void expandArgs(
+static bool expandArgs(
     const KernelSpec& spec,
     std::vector<at::Tensor>& args,
-    std::vector<int64_t>& map_size) {
+    std::vector<int64_t>& map_size, bool dry_run) {
+  bool has_broadcast = false;
   for (size_t i = 0; i < args.size(); ++i) {
     auto& arg = args[i];
     const auto& pdesc = spec.inputChunks()[i];
     if (pdesc.nSubTensors() == 1) {
       if (arg.sizes().equals(map_size))
         continue;
-      arg = arg.expand(map_size);
+      if (!dry_run) {
+        arg = arg.expand(map_size);
+        has_broadcast = true;
+      } else {
+        return true;
+      }
     } else {
       map_size.at(pdesc.dim()) *= pdesc.nSubTensors();
       if (!arg.sizes().equals(map_size)) {
-        arg = arg.expand(map_size);
+        if (!dry_run) {
+          arg = arg.expand(map_size);
+          has_broadcast = true;
+        } else {
+          return true;
+        }
       }
       map_size.at(pdesc.dim()) /= pdesc.nSubTensors();
     }
   }
+  return has_broadcast;
+}
+
+static bool shouldExpandArgs(
+    const KernelSpec& spec,
+    std::vector<at::Tensor>& args,
+    std::vector<int64_t>& map_size) {  
+  return expandArgs(spec, args, map_size, /*dry_run=*/true);
 }
 
 // Note: assumes that inputs are 32-bit addressable
@@ -326,7 +345,11 @@ bool runFusion(const int64_t key, Stack& stack) {
   // Tries to run fallback if map size can't be computed
   if (!maybe_map_size)
     return false;
-  expandArgs(spec, inputs, *maybe_map_size);
+  if (spec.hasRandom()) {
+      bool hasBroadcast = shouldExpandArgs(spec,inputs, *maybe_map_size);
+      if (hasBroadcast) return false;
+  }
+  expandArgs(spec, inputs, *maybe_map_size, /*dry_run=*/false);
 
   // Retrieves the kernel, compiling (and caching) if necessary
   ArgSpec arg_spec{inputs, device.index()};
index 5ebb169..01b7eff 100644 (file)
@@ -65,9 +65,7 @@ bool isSimpleMap(Node* node) {
       "aten::neg(Tensor self) -> Tensor",
       "aten::pow(Tensor self, Tensor exponent) -> Tensor",
       "aten::pow(Tensor self, Scalar exponent) -> Tensor",
-      // See https://github.com/pytorch/pytorch/issues/14674 and make sure you
-      // won't make the same mistake before you reenable this.
-      //"aten::rand_like(Tensor self) -> Tensor",
+      "aten::rand_like(Tensor self) -> Tensor",
       "aten::reciprocal(Tensor self) -> Tensor",
       "aten::relu(Tensor self) -> Tensor",
       "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor",