From: Natalia Gimelshein
Date: Tue, 19 Feb 2019 18:56:44 +0000 (-0800)
Subject: reenable rand_like fusion when there is no broadcast (#16087)
X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~1215
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=19117f6a0a7e30390e2b0efbba1dbc0f7018ff26;p=platform%2Fupstream%2Fpytorch.git

reenable rand_like fusion when there is no broadcast (#16087)

Summary:
Reenables rand_like fusion if no tensor is broadcast in the fusion group. This is a sufficient but not necessary condition for fused rand_like to produce correct results. It has the unpleasant side effect of falling back to the non-fused path whenever rand_like was optimistically included in a fusion group that also contains a broadcast, even a broadcast unrelated to rand_like. E.g. before this PR, if the network had (biasAdd -> relu -> dropout), the fuser could fuse biasAdd and relu; now it will try to fuse the whole thing (if dropout is expressed via rand_like) and fall back every time.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/16087

Differential Revision: D13720232

Pulled By: zou3519

fbshipit-source-id: 1e19203bec4a59257bfc7078b054a19f00fab4ad
---
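Reviewer note (illustrative, not part of the commit): a minimal, hypothetical sketch of the (biasAdd -> relu -> dropout) pattern the summary describes, with dropout written via rand_like. The function name, shapes, and dropout formulation below are made up for illustration, and a CUDA build is assumed since the fuser paths touched here are CUDA-only. Because `bias` broadcasts against `x`, a fusion group that pulls in this rand_like now hits the runtime broadcast check and falls back to the unfused path, as the summary notes.

```python
# Illustrative sketch only; not part of this patch. Assumes a CUDA build.
import torch

@torch.jit.script
def bias_relu_dropout(x, bias, p: float):
    y = torch.relu(x + bias)                      # bias broadcasts against x
    mask = (torch.rand_like(y) > p).to(y.dtype)   # dropout expressed via rand_like
    return y * mask / (1.0 - p)

x = torch.randn(8, 16, device='cuda')
bias = torch.randn(16, device='cuda')
out = bias_relu_dropout(x, bias, 0.5)
```

Previously rand_like was excluded from fusion entirely, so biasAdd and relu could still fuse on their own; with this change the whole body becomes a fusion candidate and the decision to run it fused is made at runtime by the broadcast check added in executor.cpp.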
diff --git a/test/test_jit.py b/test/test_jit.py
index fbb60fa..6e6b4f5 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -11665,8 +11665,6 @@ class TestFuser(JitTestCase):
         (hy + cy).sum().backward()
         self.assertExpectedGraph(backward_graph(module), subname='backward')
 
-    # TODO: At some point we supported fusion of torch.rand_like but not anymore
-    @unittest.expectedFailure
     @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     @skipIfRocm
@@ -11706,20 +11704,41 @@ class TestFuser(JitTestCase):
         ge = self.checkTrace(self.fn_test_relu, (x, y))
         self.assertAllFused(ge.graph_for(x, y))
 
-    @staticmethod
-    def fn_test_erf(x):
-        return F.relu(torch.erf(x) - torch.erfc(x))
-
     @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     @skipIfRocm
     def test_erf_cuda(self):
+        def fn_test_erf(x):
+            return F.relu(torch.erf(x) - torch.erfc(x))
+
         x = torch.randn(4, 4, dtype=torch.float, device='cuda')
-        ge = self.checkTrace(self.fn_test_erf, (x,))
+        ge = self.checkTrace(fn_test_erf, (x,))
         self.assertAllFused(ge.graph_for(x))
         x.requires_grad_(True)
         self.assertAllFused(ge.graph_for(x), except_for=("aten::size", "prim::BroadcastSizes"))
 
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @skipIfRocm
+    def test_rand_broadcast_cuda(self):
+        def fn_test_rand(x, y):
+            r = torch.rand_like(y)
+            return r * x + x
+
+        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
+        y = torch.randn(4, 4, dtype=torch.float, device='cuda')
+        script_f = torch.jit.script(fn_test_rand, (x, y))
+        out = script_f(x, y)
+        self.assertAllFused(script_f.graph_for(x, y))
+        x.requires_grad_(True)
+        out = script_f(x, y)
+        self.assertAllFused(script_f.graph_for(x, y), except_for=("aten::size", "prim::BroadcastSizes"))
+        # test that broadcasting random produces correct results
+        x = torch.ones(4, 4, dtype=torch.float, device='cuda')
+        y = torch.ones(4, dtype=torch.float, device='cuda')
+        out = script_f(x, y)
+        self.assertEqual(out[0], out[1])
+
     @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
     @enable_cpu_fuser
     def test_scalar(self):
diff --git a/torch/csrc/jit/autodiff.cpp b/torch/csrc/jit/autodiff.cpp
index b63f069..5e248f8 100644
--- a/torch/csrc/jit/autodiff.cpp
+++ b/torch/csrc/jit/autodiff.cpp
@@ -96,6 +96,7 @@ bool isDifferentiable(Node* n) {
       "aten::log10(Tensor self) -> Tensor",
       "aten::log1p(Tensor self) -> Tensor",
       "aten::log2(Tensor self) -> Tensor",
+      "aten::rand_like(Tensor self) -> Tensor",
      "aten::reciprocal(Tensor self) -> Tensor",
      "aten::remainder(Tensor self, Scalar other) -> Tensor",
      "aten::round(Tensor self) -> Tensor",
@@ -533,6 +534,9 @@ class GradientHelper {
               "aten::type_as(Tensor self, Tensor other) -> Tensor")) {
       return {grads.at(0).type_as(inputs.at(0)), nullptr};
 
+    } else if (node->matches("aten::rand_like(Tensor self) -> Tensor")) {
+      return {nullptr};
+
     } else if (node->matches(
               "aten::unsqueeze(Tensor self, int dim) -> Tensor")) {
       return {grads.at(0).squeeze(node->namedInput(attr::dim)), nullptr};
diff --git a/torch/csrc/jit/fuser/executor.cpp b/torch/csrc/jit/fuser/executor.cpp
index 2d3cd36..7c153ef 100644
--- a/torch/csrc/jit/fuser/executor.cpp
+++ b/torch/csrc/jit/fuser/executor.cpp
@@ -97,25 +97,44 @@ static c10::optional<std::vector<int64_t>> canRunKernel(
 // (see above).
 // Note: Arguments are mutated by this call, although map_size is restored
 // to its original value.
-static void expandArgs(
+static bool expandArgs(
     const KernelSpec& spec,
     std::vector<at::Tensor>& args,
-    std::vector<int64_t>& map_size) {
+    std::vector<int64_t>& map_size, bool dry_run) {
+  bool has_broadcast = false;
   for (size_t i = 0; i < args.size(); ++i) {
     auto& arg = args[i];
     const auto& pdesc = spec.inputChunks()[i];
     if (pdesc.nSubTensors() == 1) {
       if (arg.sizes().equals(map_size))
         continue;
-      arg = arg.expand(map_size);
+      if (!dry_run) {
+        arg = arg.expand(map_size);
+        has_broadcast = true;
+      } else {
+        return true;
+      }
     } else {
       map_size.at(pdesc.dim()) *= pdesc.nSubTensors();
       if (!arg.sizes().equals(map_size)) {
-        arg = arg.expand(map_size);
+        if (!dry_run) {
+          arg = arg.expand(map_size);
+          has_broadcast = true;
+        } else {
+          return true;
+        }
       }
       map_size.at(pdesc.dim()) /= pdesc.nSubTensors();
     }
   }
+  return has_broadcast;
+}
+
+static bool shouldExpandArgs(
+    const KernelSpec& spec,
+    std::vector<at::Tensor>& args,
+    std::vector<int64_t>& map_size) {
+  return expandArgs(spec, args, map_size, /*dry_run=*/true);
 }
 
 // Note: assumes that inputs are 32-bit addressable
@@ -326,7 +345,11 @@ bool runFusion(const int64_t key, Stack& stack) {
   // Tries to run fallback if map size can't be computed
   if (!maybe_map_size)
     return false;
-  expandArgs(spec, inputs, *maybe_map_size);
+  if (spec.hasRandom()) {
+    bool hasBroadcast = shouldExpandArgs(spec,inputs, *maybe_map_size);
+    if (hasBroadcast) return false;
+  }
+  expandArgs(spec, inputs, *maybe_map_size, /*dry_run=*/false);
 
   // Retrieves the kernel, compiling (and caching) if necessary
   ArgSpec arg_spec{inputs, device.index()};
diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp
index 5ebb169..01b7eff 100644
--- a/torch/csrc/jit/passes/graph_fuser.cpp
+++ b/torch/csrc/jit/passes/graph_fuser.cpp
@@ -65,9 +65,7 @@ bool isSimpleMap(Node* node) {
       "aten::neg(Tensor self) -> Tensor",
       "aten::pow(Tensor self, Tensor exponent) -> Tensor",
       "aten::pow(Tensor self, Scalar exponent) -> Tensor",
-      // See https://github.com/pytorch/pytorch/issues/14674 and make sure you
-      // won't make the same mistake before you reenable this.
-      //"aten::rand_like(Tensor self) -> Tensor",
+      "aten::rand_like(Tensor self) -> Tensor",
       "aten::reciprocal(Tensor self) -> Tensor",
       "aten::relu(Tensor self) -> Tensor",
       "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor",
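Testing note (illustrative, not part of the patch): one way to check from Python whether such a graph was actually fused is to look for a prim::FusionGroup node in the specialized graph, which is essentially what assertAllFused does in the test above. The function name below is hypothetical, a CUDA build is assumed, and depending on the PyTorch version and active fuser the fusion node kind may differ.

```python
# Hypothetical check, assuming a CUDA build; mirrors the shapes used in
# test_rand_broadcast_cuda above (equal shapes, so no broadcast is involved).
import torch

@torch.jit.script
def scaled_noise(x, y):
    return torch.rand_like(y) * x + x

x = torch.randn(4, 4, device='cuda')
y = torch.randn(4, 4, device='cuda')
scaled_noise(x, y)  # run once so the JIT specializes the graph for these inputs
graph = scaled_noise.graph_for(x, y)
print(any(node.kind() == 'prim::FusionGroup' for node in graph.nodes()))
```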