Summary:
Bug fix for https://github.com/pytorch/pytorch/issues/15043, where a large fusion in the JIT produces a kernel whose number of arguments exceeds the limit allowed by nvrtc on a CUDA device.
The fix is to check the number of arguments before a CUDA kernel is generated; if the number exceeds the limit, take the runFallBack() path instead.
A reduced version of the test from the original issue is added to keep test time low; it fails without this fix.
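
A minimal sketch of how the fix can be exercised outside the test suite (it mirrors the new test below; the module name and the eager-vs-traced comparison are illustrative only and not part of this change):

    import torch
    import torch.nn as nn

    # Tracing unrolls the Python loop, so the fuser tries to build one kernel
    # with one argument per iteration feeding a single concat -- past the
    # 128-argument limit. Previously that kernel failed to compile in nvrtc;
    # with this fix the fuser falls back and the result still matches eager mode.
    class ManyOutputs(nn.Module):  # hypothetical name
        def forward(self, x):
            return torch.cat([x[:, i] * 2 for i in range(x.size(1))], 0)

    if torch.cuda.is_available():
        m = ManyOutputs().cuda()
        inp = torch.rand(2, 130, 8, device="cuda")
        traced = torch.jit.trace(m, (inp,))
        assert torch.allclose(traced(inp), m(inp))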
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18063
Differential Revision: D14691401
Pulled By: soumith
fbshipit-source-id: b98829bc89ed7724e91eda82ae3a5a1151af721a
class TestJit(JitTestCase):
+    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
+    def test_large_nbr_kernel_args(self):
+        class Recurrence(nn.Module):
+            def __init__(self, seq_len):
+                super(Recurrence, self).__init__()
+                self.seq_len = seq_len
+
+            def forward(self, input):
+                input = input.transpose(0, 1)
+
+                # Main loop
+                output = []
+                for i in range(self.seq_len):
+                    b = input[i] * 2
+                    output.append(b)
+
+                output = torch.cat(output, 0).view(input.size(0), *output[0].size())
+                output = output.transpose(0, 1)
+                return output
+
+        input_size = 8
+        batch_size = 2
+        seq_len = 130
+
+        rec = Recurrence(seq_len)
+        input = torch.rand(batch_size, seq_len, input_size)
+
+        torch.cuda.set_device(0)
+        rec = rec.cuda()
+        input = input.cuda()
+
+        traced_rec = torch.jit.trace(rec, (input))
    @unittest.skip("Requires a lot of RAM")
    def test_big(self):
}
}
- const std::string name = "kernel_" + std::to_string(next_kernel_id++);
+ // The limit was already checked in graph_fuser; assert that nothing else
+ // has changed it since then.
+ AT_ASSERT((flat_inputs.size() + flat_outputs.size()) <=
+     fusion_kernel_args_limit);
+
const bool use_cuda = device.is_cuda();
+ const std::string name = "kernel_" + std::to_string(next_kernel_id++);
std::string code =
generateKernel(name, *graph, flat_inputs, flat_outputs, use_cuda);
const FusedKernelConstructor& kernel_ctor =
return false;
if (!node->is_constant(attr::dim))
return false;
+
auto tensors_node = node->namedInput(attr::tensors)->node();
+ if ((tensors_node->inputs().size() + node->outputs().size()) >
+     fusion_kernel_args_limit) {
+   return false;
+ }
if (tensors_node->kind() != prim::ListConstruct)
return false;
// NB: Note that technically other uses of the list aren't a big problem for
if (!shouldFuse) {
return at::nullopt;
}
+
+ if ((consumer->inputs().size() + consumer->outputs().size() +
+      producer->node()->inputs().size() + producer->node()->outputs().size()) >
+     fusion_kernel_args_limit) {
+   return at::nullopt;
+ }
+
if (producer->node()->kind() == aten::_grad_sum_to_size &&
consumer->kind() == prim::FusionGroup) {
// check that we will be able to move the _grad_sum_to_size to be fused
producer->node(), before_check)) {
return false;
}
+
+ // If the number of kernel args could exceed the limit, skip.
+ if ((before_check->inputs().size() + before_check->outputs().size() +
+      producer->node()->inputs().size() + producer->node()->outputs().size()) >
+     fusion_kernel_args_limit) {
+   return false;
+ }
+
// Fusion groups can be merged with concat's group if and only if
// - the value they produce isn't already coming from a concat and
// - the fusion group does not contain GradSumToSize
namespace torch {
namespace jit {
+// nvrtc has a limit on the number of arguments allowed in a CUDA kernel.
+// The exact limit depends on the constant memory size, how much of it is
+// available for passing arguments, and other implementation details, so a
+// conservative value is chosen here.
+// The same limit is applied to other devices in the fuser as well, because a
+// kernel with such a large number of arguments is unlikely to be profitable.
+constexpr size_t fusion_kernel_args_limit = 128;
+
// NB: Be sure to run DCE before fusion, because dead instructions
// can prevent fusion opportunities from being exploited.
// On Windows will noop, NYI
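
For intuition on the 128 value chosen for fusion_kernel_args_limit above, a back-of-the-envelope sketch (the per-argument byte cost is an assumption for illustration; the real cost depends on tensor rank and how the fuser packs sizes and strides):

    # CUDA __global__ kernel parameters are passed through constant memory and
    # are limited to 4 KiB on the devices targeted here.
    KERNEL_PARAM_BYTES = 4096
    # Assumed average per-argument cost: a data pointer plus packed size/stride
    # metadata (illustrative estimate only).
    ASSUMED_BYTES_PER_ARG = 32
    assert 128 <= KERNEL_PARAM_BYTES // ASSUMED_BYTES_PER_ARG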