Fixes error when too many parameters are passed to fused cuda kernel (#18063)
authorRoy Ju <rju@nvidia.com>
Wed, 10 Apr 2019 05:29:33 +0000 (22:29 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 10 Apr 2019 05:37:09 +0000 (22:37 -0700)
Summary:
Bug fix for https://github.com/pytorch/pytorch/issues/15043, where a large fusion in JIT with a large number of kernel arguments, which exceeds the limit allowed by nvrtc on a cuda device.
  The fix is to check the number of arguments before a cuda kernel is generated. If the number exceeds the limit, take the runFallBack() path.
  Add a reduced test from the original issue to keep the test time low. The test would fail without this fix.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18063

Differential Revision: D14691401

Pulled By: soumith

fbshipit-source-id: b98829bc89ed7724e91eda82ae3a5a1151af721a

test/test_jit.py
torch/csrc/jit/fuser/compiler.cpp
torch/csrc/jit/passes/graph_fuser.cpp
torch/csrc/jit/passes/graph_fuser.h

index 59e360d..f2b5d44 100644 (file)
@@ -621,6 +621,38 @@ class FooToPickle(torch.nn.Module):
 
 
 class TestJit(JitTestCase):
+    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
+    def test_large_nbr_kernel_args(self):
+        class Recurrence(nn.Module):
+            def __init__(self, seq_len):
+                super(Recurrence, self).__init__()
+                self.seq_len = seq_len
+
+            def forward(self, input):
+                input = input.transpose(0, 1)
+
+                # Main loop
+                output = []
+                for i in range(self.seq_len):
+                    b = input[i] * 2
+                    output.append(b)
+
+                output = torch.cat(output, 0).view(input.size(0), *output[0].size())
+                output = output.transpose(0, 1)
+                return output
+
+        input_size = 8
+        batch_size = 2
+        seq_len = 130
+
+        rec = Recurrence(seq_len)
+        input = torch.rand(batch_size, seq_len, input_size)
+
+        torch.cuda.set_device(0)
+        rec = rec.cuda()
+        input = input.cuda()
+
+        traced_rec = torch.jit.trace(rec, (input))
 
     @unittest.skip("Requires a lot of RAM")
     def test_big(self):
index ebfa45e..7f0eed5 100644 (file)
@@ -332,8 +332,12 @@ std::shared_ptr<FusedKernel> compileKernel(
     }
   }
 
-  const std::string name = "kernel_" + std::to_string(next_kernel_id++);
+  // Have checked the limit at graph_fuser. Assert nothing else changing that.
+  AT_ASSERT((flat_inputs.size() + flat_outputs.size()) <=
+            fusion_kernel_args_limit);
+
   const bool use_cuda = device.is_cuda();
+  const std::string name = "kernel_" + std::to_string(next_kernel_id++);
   std::string code =
       generateKernel(name, *graph, flat_inputs, flat_outputs, use_cuda);
   const FusedKernelConstructor& kernel_ctor =
index ad1a835..a8ae6d3 100644 (file)
@@ -211,7 +211,12 @@ struct GraphFuser {
       return false;
     if (!node->is_constant(attr::dim))
       return false;
+
     auto tensors_node = node->namedInput(attr::tensors)->node();
+    if( (tensors_node->inputs().size() + node->outputs().size()) >
+        fusion_kernel_args_limit ) {
+      return false;
+    }
     if (tensors_node->kind() != prim::ListConstruct)
       return false;
     // NB: Note that technically other uses of the list aren't a big problem for
@@ -473,6 +478,13 @@ struct GraphFuser {
     if (!shouldFuse) {
       return at::nullopt;
     }
+
+    if( (consumer->inputs().size() + consumer->outputs().size() +
+         producer->node()->inputs().size() + producer->node()->outputs().size()) >
+        fusion_kernel_args_limit ) {
+        return at::nullopt;
+    }
+
     if (producer->node()->kind() == aten::_grad_sum_to_size &&
         consumer->kind() == prim::FusionGroup) {
       // check that we will be able to move the _grad_sum_to_size to be fused
@@ -1062,6 +1074,16 @@ struct GraphFuser {
             producer->node(), before_check)) {
       return false;
     }
+
+    // If the number of kernel args could exceed the limit, skip.
+    if ((before_check->inputs().size() +
+         before_check->outputs().size() +
+         producer->node()->inputs().size() +
+         producer->node()->outputs().size())
+        > fusion_kernel_args_limit) {
+      return false;
+    }
+
     // Fusion groups can be merged with concat's group if and only if
     // - the value they produce isn't already coming from a concat and
     // - the fusion group does not contain GradSumToSize
index bab187a..7c66896 100644 (file)
@@ -5,6 +5,15 @@
 namespace torch {
 namespace jit {
 
+// nvrtc has a limit on the number of arguments allowed in a CUDA kernel.
+// The specific limit is a function of constant memory size, amount available
+// to pass arguments, and some implementation dependence. Select a safe
+// limit here.
+//   This limit is also applied to other devices in the fuser, because we
+// don't consider a kernel with such a large number of arguments would be
+// profitable.
+constexpr size_t fusion_kernel_args_limit = 128;
+
 // NB: Be sure to run DCE before fusion, because dead instructions
 // can prevent fusion opportunities from being exploited.
 // On Windows will noop, NYI