From 7078b2baf57bb6664309c680d4cc35bae23d6c9d Mon Sep 17 00:00:00 2001 From: Asher Mancinelli Date: Fri, 1 Feb 2019 07:59:56 -0800 Subject: [PATCH] Better bounds checks in ctcloss (#16269) Summary: Adds better bounds checks for target lengths in CTC loss, checks for integral types for target and prediction lengths, and adds tests for each, according to #15946 Pull Request resolved: https://github.com/pytorch/pytorch/pull/16269 Differential Revision: D13847567 Pulled By: ezyang fbshipit-source-id: 5d7a975565e02baf78fe388813a1d1ef56dfb212 --- aten/src/ATen/native/LossCTC.cpp | 11 +++++++---- aten/src/ATen/native/cuda/LossCTC.cu | 8 ++++---- test/test_nn.py | 29 +++++++++++++++++++++++++++++ 3 files changed, 40 insertions(+), 8 deletions(-) diff --git a/aten/src/ATen/native/LossCTC.cpp b/aten/src/ATen/native/LossCTC.cpp index 8e02ebf..c91aa49 100644 --- a/aten/src/ATen/native/LossCTC.cpp +++ b/aten/src/ATen/native/LossCTC.cpp @@ -53,16 +53,15 @@ std::tuple ctc_loss_cpu_template(const Tensor& log_probs, const AT_CHECK((int64_t) target_lengths.size() == batch_size, "target_lengths must be of size batch_size"); size_t tg_target_stride; - int64_t max_target_length; + int64_t max_target_length = 0; std::vector tg_batch_offsets(batch_size); if (targets.dim() == 1) { // concatenated targets int64_t pos = 0; - max_target_length = 0; for (int64_t i = 0; i < batch_size; i++) { tg_batch_offsets[i] = pos; pos += target_lengths[i]; if (max_target_length < target_lengths[i]) - max_target_length = target_lengths[i]; + max_target_length = target_lengths[i]; } tg_target_stride = targets.stride(0); checkSize(c, targets_arg, 0, pos); @@ -72,9 +71,10 @@ std::tuple ctc_loss_cpu_template(const Tensor& log_probs, const int64_t tg_batch_stride = targets.stride(0); for (int64_t i = 0; i < batch_size; i++) { tg_batch_offsets[i] = i * tg_batch_stride; + if (max_target_length < target_lengths[i]) + max_target_length = target_lengths[i]; } tg_target_stride = targets.stride(1); - max_target_length 
= targets.size(1); checkSize(c, targets_arg, 0, batch_size); AT_CHECK(targets.size(1) >= max_target_length, "Expected tensor to have size at least ", max_target_length, " at dimension 1, but got size ", targets.size(1), " for ", targets_arg, @@ -364,6 +364,9 @@ Tensor ctc_loss(const Tensor& log_probs, const Tensor& targets, IntList input_le // Convenience function accepting Tensors Tensor ctc_loss(const Tensor& log_probs, const Tensor& targets, const Tensor& input_lengths, const Tensor& target_lengths, int64_t BLANK, int64_t reduction) { + AT_CHECK(isIntegralType(input_lengths.type().scalarType()), "input_lengths must be integral"); + AT_CHECK(isIntegralType(target_lengths.type().scalarType()), "target_lengths must be integral"); + Tensor ilc = input_lengths.toType(kLong).toBackend(Backend::CPU).contiguous(); Tensor tlc = target_lengths.toType(kLong).toBackend(Backend::CPU).contiguous(); IntList il(ilc.data(), ilc.numel()); diff --git a/aten/src/ATen/native/cuda/LossCTC.cu b/aten/src/ATen/native/cuda/LossCTC.cu index 7759b6c..d36185d 100644 --- a/aten/src/ATen/native/cuda/LossCTC.cu +++ b/aten/src/ATen/native/cuda/LossCTC.cu @@ -189,17 +189,16 @@ std::tuple ctc_loss_gpu_template(const Tensor& log_probs, const int64_t lp_char_stride = log_probs.stride(2); int64_t tg_target_stride; - int64_t max_target_length; + int64_t max_target_length = 0; auto tg_batch_offsets = at::empty({batch_size}, at::device(at::kCPU).dtype(at::kLong)); auto tg_batch_offsets_data = tg_batch_offsets.data(); if (targets.dim() == 1) { // concatenated targets int64_t pos = 0; - max_target_length = 0; for (int64_t i = 0; i < batch_size; i++) { tg_batch_offsets_data[i] = pos; pos += target_lengths[i]; if (max_target_length < target_lengths[i]) - max_target_length = target_lengths[i]; + max_target_length = target_lengths[i]; } tg_target_stride = targets.stride(0); checkSize(c, targets_arg, 0, pos); @@ -209,9 +208,10 @@ std::tuple ctc_loss_gpu_template(const Tensor& log_probs, const int64_t
tg_batch_stride = targets.stride(0); for (int64_t i = 0; i < batch_size; i++) { tg_batch_offsets_data[i] = i * tg_batch_stride; + if (max_target_length < target_lengths[i]) + max_target_length = target_lengths[i]; } tg_target_stride = targets.stride(1); - max_target_length = targets.size(1); checkSize(c, targets_arg, 0, batch_size); AT_CHECK(targets.size(1) >= max_target_length, "Expected tensor to have size at least ", max_target_length, " at dimension 1, but got size ", targets.size(1), " for ", targets_arg, diff --git a/test/test_nn.py b/test/test_nn.py index a263698..57ab17f 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -4238,6 +4238,35 @@ class TestNN(NNTestCase): self.assertEqual(res, expected) self.assertEqual(res2, res) + def test_CTCLoss_typechecks(self): + target_lengths = torch.tensor([30, 25, 20]) + input_lengths = torch.tensor([50, 50, 50]) + targets = torch.randint(1, 15, (sum(target_lengths),), dtype=torch.int) + log_probs = torch.randn(50, 3, 15, dtype=torch.float).log_softmax(2) + with self.assertRaises(RuntimeError): + _input_lengths = input_lengths.to(dtype=torch.float) + torch.nn.functional.ctc_loss(log_probs, targets, _input_lengths, target_lengths) + with self.assertRaises(RuntimeError): + target_lengths = target_lengths.to(dtype=torch.float) + torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths) + + @unittest.skipIf(not TEST_CUDA, 'CUDA not available') + def test_CTCLoss_lengthchecks_cuda(self): + target_lengths = [30, 25, 20] + input_lengths = [50, 50, 50] + targets = torch.randint(1, 15, (3, 29), dtype=torch.long, device='cuda') + log_probs = torch.randn(50, 3, 15, dtype=torch.float, device='cuda').log_softmax(2) + with self.assertRaises(RuntimeError): + torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths) + + def test_CTCLoss_lengthchecks_cpu(self): + target_lengths = [30, 25, 20] + input_lengths = [50, 50, 50] + targets = torch.randint(1, 15, (3, 29), dtype=torch.int) + 
log_probs = torch.randn(50, 3, 15, dtype=torch.float).log_softmax(2) + with self.assertRaises(RuntimeError): + torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths) + def test_RNN_cell_no_broadcasting(self): def test(cell_module, input, hx, input_size, hidden_size): cell = cell_module(input_size, hidden_size) -- 2.7.4