From a50ba7e2384b96e4b52fbec39f644e29a76a0a4f Mon Sep 17 00:00:00 2001
From: Ailing Zhang
Date: Tue, 19 Mar 2019 10:20:06 -0700
Subject: [PATCH] specialized CUDA impl for dropout in AD (#17756)

Summary:
In ATen we have a _fused_dropout implementation for the CUDA case. As ngimel
pointed out, discarding it in JIT AD hurts performance. It doesn't seem ideal
to include a backend-specific implementation in AD, but it helps prevent a
performance regression for now.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/17756

Differential Revision: D14368999

Pulled By: ailzhang

fbshipit-source-id: 9a371c5020f630e8f6e496849ec9772b6f196169
---
 test/test_jit.py                         | 21 ++++++++++++++
 torch/csrc/jit/passes/shape_analysis.cpp | 50 +++++++++++++++++++++++++-------
 torch/csrc/jit/symbolic_script.cpp       | 27 ++++++++++++++---
 3 files changed, 83 insertions(+), 15 deletions(-)

diff --git a/test/test_jit.py b/test/test_jit.py
index d55c66a..d314ca1 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -1341,6 +1341,27 @@ class TestJit(JitTestCase):
         m = self.createScriptModuleFromGraph(trace)
         self.assertEqual(outputs, m(*inputs))
 
+    @unittest.skipIf(not RUN_CUDA, "test_dropout_cuda require CUDA")
+    def test_dropout_cuda(self):
+        # Dropout AD is dispatched to _fused_dropout in CUDA case,
+        # which is not included in TestJitGeneratedFunctional
+        x = torch.ones(4, 4).cuda().requires_grad_()
+
+        @torch.jit.script
+        def func(x):
+            return torch.nn.functional.dropout(x)
+
+        with freeze_rng_state():
+            out_ref = torch.nn.functional.dropout(x)
+            grad_ref = torch.autograd.grad(out_ref.sum(), x)
+
+        with freeze_rng_state():
+            out = func(x)
+            grad = torch.autograd.grad(out.sum(), x)
+
+        self.assertEqual(out, out_ref)
+        self.assertEqual(grad, grad_ref)
+
     def test_conv(self):
         x = torch.ones(20, 16, 50, 40)
         trace, outputs, inputs = torch.jit.get_trace_graph(nn.Conv2d(16, 13, 3, bias=False), x, return_inputs=True)
diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp
index fb4931f..6d771b4 100644
--- a/torch/csrc/jit/passes/shape_analysis.cpp
+++ b/torch/csrc/jit/passes/shape_analysis.cpp
@@ -48,6 +48,19 @@ bool isValidReturnForRunning(Value* v) {
       v->type()->isSubtypeOf(NumberType::get());
 }
 
+bool containsTensorType(const TypePtr& t) {
+  auto n_contained = t->containedTypes().size();
+  if (n_contained == 1) {
+    return t->containedTypes().at(0)->isSubtypeOf(TensorType::get());
+  } else if (n_contained > 1) {
+    return std::any_of(
+        t->containedTypes().begin(),
+        t->containedTypes().end(),
+        containsTensorType);
+  }
+  return false;
+}
+
 class ShapePropagator {
  public:
   explicit ShapePropagator(std::shared_ptr<Graph> graph) : aliasDb_(graph) {
@@ -298,6 +311,18 @@ class ShapePropagator {
     return true;
   }
 
+  // If there's no Tensor in outputs, e.g float / float,
+  // we don't need to propagate shape.
+  bool DoesntRefineOutputs(Node* node) {
+    auto outputs = node->outputs();
+    for (auto& out : outputs) {
+      if (containsTensorType(out->type())) {
+        return false;
+      }
+    }
+    return true;
+  }
+
   bool PropagateShapeOnNodeByRunningIt(Node* node) {
     if (!canPropagateShapeByRunningIt(node))
       return false;
@@ -534,6 +559,10 @@ class ShapePropagator {
       return;
     }
 
+    if (DoesntRefineOutputs(node)) {
+      return;
+    }
+
     if (PropagateShapeOnNodeByRunningIt(node)) {
       return;
     }
@@ -1074,26 +1103,25 @@ class ShapePropagator {
       at::optional<IValue> maybe_layout_option = node->get(attr::layout);
       if (!maybe_layout_option)
         return {};
-      auto layout = (maybe_layout_option->isNone()
-                         ? at::kStrided
-                         : maybe_layout_option->toLayout());
+      auto layout =
+          (maybe_layout_option->isNone() ? at::kStrided
+                                         : maybe_layout_option->toLayout());
 
       at::optional<IValue> maybe_device_option = node->get(attr::device);
       if (!maybe_device_option)
         return {};
-      auto device = (maybe_device_option->isNone()
-                         ? at::kCPU
-                         : maybe_device_option->toDevice());
+      auto device =
+          (maybe_device_option->isNone() ? at::kCPU
+                                         : maybe_device_option->toDevice());
 
       at::optional<IValue> maybe_dtype_option = node->get(attr::dtype);
       if (!maybe_dtype_option)
         return {};
-      auto dtype = (maybe_dtype_option->isNone()
-                        ? at::kFloat
-                        : maybe_dtype_option->toScalarType());
+      auto dtype =
+          (maybe_dtype_option->isNone() ? at::kFloat
+                                        : maybe_dtype_option->toScalarType());
 
-      return {DimensionedTensorType::create(
-          dtype, device, dim)};
+      return {DimensionedTensorType::create(dtype, device, dim)};
     };
 
     // Requirements:
diff --git a/torch/csrc/jit/symbolic_script.cpp b/torch/csrc/jit/symbolic_script.cpp
index 3254945..98e5a74 100644
--- a/torch/csrc/jit/symbolic_script.cpp
+++ b/torch/csrc/jit/symbolic_script.cpp
@@ -691,15 +691,34 @@ const std::vector<std::string> functions = {
 
             return output, backward
 
+        def AD_fused_dropout_backward(grad,
+                                      mask,
+                                      p1m: float):
+            p1r = 1. / p1m
+            if grad.requires_grad:
+                grad_input = grad * (mask.type_as(grad) * p1r)
+            else:
+                grad_input = torch._masked_scale(grad, mask, p1r)
+            return grad_input
+
         def dropout(input,
                     p: float,
                     train: bool):
-            mask = torch.empty_like(input)
-            mask.bernoulli_(1 - p)
-            res = mask * input / (1.0 - p)
+            use_cuda = input.is_cuda
+            # CUDA has a fused dropout implementation
+            p1m = 1. - p
+            if use_cuda:
+                res, mask = torch._fused_dropout(input, p1m)
+            else:
+                mask = torch.empty_like(input)
+                mask.bernoulli_(p1m)
+                res = mask * input / p1m
 
             def backward(grad_output):
-                grad_input = grad_output * mask / (1.0 - p)
+                if use_cuda:
+                    grad_input = AD_fused_dropout_backward(grad_output, mask, p1m)
+                else:
+                    grad_input = grad_output * mask / p1m
                 return grad_input, None, None
 
             return res, backward
-- 
2.7.4
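
Note (illustrative, not part of the patch): the symbolic script above relies on
torch._fused_dropout returning both the scaled output and the sampled mask, and
on torch._masked_scale computing grad * mask * scale in a single kernel. Below
is a minimal standalone sketch of that equivalence, assuming a CUDA build of
this PyTorch version where these internal ops are exposed; the helper name
check_fused_dropout_backward is made up for the example.

import torch

# Illustrative sketch only (assumes a CUDA build of this PyTorch version,
# where the internal _fused_dropout / _masked_scale ops used above exist).
def check_fused_dropout_backward(p=0.5):
    x = torch.randn(4, 4, device="cuda")
    p1m = 1. - p    # keep probability, as in the symbolic script
    p1r = 1. / p1m  # rescale factor applied in forward and backward
    # Forward: the fused kernel returns the scaled output and the mask it sampled.
    out, mask = torch._fused_dropout(x, p1m)
    grad_out = torch.randn_like(out)
    # Backward written two ways: explicit mask scaling (the requires_grad branch) ...
    manual = grad_out * (mask.type_as(grad_out) * p1r)
    # ... and the fused single-kernel form (the else branch).
    fused = torch._masked_scale(grad_out, mask, p1r)
    assert torch.allclose(manual, fused)
    # The forward output itself is mask * x / (1 - p), matching the CPU path.
    assert torch.allclose(out, mask.type_as(x) * x * p1r)

if __name__ == "__main__":
    check_fused_dropout_backward()
    print("fused and manual dropout backward agree")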