From 29f096cc70cf9cc1a317ae7107228215b7dde60b Mon Sep 17 00:00:00 2001
From: Thomas Viehmann
Date: Mon, 11 Feb 2019 12:26:47 -0800
Subject: [PATCH] optionally zero infinite losses in CTCLoss (#16199)

Summary:
Here is a stab at implementing an option to zero out infinite losses (and NaN gradients).
It might be nicer to move the zeroing to the respective kernels.
The default is currently `False` to mimic the old behaviour, but I'd be half inclined to
set the default to `True`, because the behaviour wasn't consistent between CuDNN and
Native anyways and the NaN gradients aren't terribly useful.

This topic seems to come up regularly, e.g. in #14335
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16199

Differential Revision: D14020462

Pulled By: ezyang

fbshipit-source-id: 5ba8936c66ec6e61530aaf01175dc49f389ae428
---
 aten/src/ATen/native/LossCTC.cpp           | 34 +++++++++++++++++++-----------
 aten/src/ATen/native/cuda/LossCTC.cu       | 27 +++++++++++++++---------
 aten/src/ATen/native/cudnn/LossCTC.cpp     |  5 +++--
 aten/src/ATen/native/native_functions.yaml | 10 ++++-----
 test/test_nn.py                            | 23 ++++++++++++++++++++
 tools/autograd/derivatives.yaml            |  8 +++----
 torch/nn/functional.py                     | 12 ++++++++---
 torch/nn/modules/loss.py                   | 11 ++++++++--
 8 files changed, 92 insertions(+), 38 deletions(-)
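Before the individual diffs, a minimal sketch of the behaviour this option targets (shapes and values chosen arbitrarily, not taken from the patch): a target that is longer than its input has no valid CTC alignment, so its loss is infinite and, previously, its gradient came back as NaN; with `zero_infinity=True` both are zeroed.

    import torch
    import torch.nn.functional as F

    # 5 time steps, batch of 2, 4 classes (class 0 is the blank)
    T, N, C = 5, 2, 4
    log_probs = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
    targets = torch.randint(1, C, (N, 10), dtype=torch.long)
    input_lengths = torch.full((N,), T, dtype=torch.long)
    target_lengths = torch.tensor([3, 10])  # the second target cannot fit into 5 frames -> infinite loss

    loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths,
                      reduction='none', zero_infinity=True)
    loss.sum().backward()
    # the impossible sample contributes 0 instead of inf, and its gradient holds no NaNs
    assert (log_probs.grad == log_probs.grad).all().item()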
diff --git a/aten/src/ATen/native/LossCTC.cpp b/aten/src/ATen/native/LossCTC.cpp
index b9b1c58..96db1a8 100644
--- a/aten/src/ATen/native/LossCTC.cpp
+++ b/aten/src/ATen/native/LossCTC.cpp
@@ -162,7 +162,7 @@ std::tuple<Tensor, Tensor> ctc_loss_cpu_template(const Tensor& log_probs, const
 // b) collecting the per-activation characters for all s and wrapping the gradient (eq (16), the collection is the sum)
 template<typename scalar_t, ScalarType target_scalar_type>
 Tensor ctc_loss_backward_cpu_template(const Tensor& grad_out, const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths,
-                                      const Tensor& neg_log_likelihood, const Tensor& log_alpha, int64_t BLANK) {
+                                      const Tensor& neg_log_likelihood, const Tensor& log_alpha, int64_t BLANK, bool zero_infinity) {
   constexpr scalar_t neginf = -std::numeric_limits<scalar_t>::infinity();
   using target_t = typename std::conditional<target_scalar_type == kInt, int, int64_t>::type;
   int64_t max_input_length = log_probs.size(0);
@@ -207,6 +207,12 @@ Tensor ctc_loss_backward_cpu_template(const Tensor& grad_out, const Tensor& log_

   #pragma omp parallel for
   for (int64_t b = 0; b < batch_size; b++) {
+    scalar_t nll = neg_log_likelihood.accessor<scalar_t, 1>()[b];
+    if (zero_infinity && nll == std::numeric_limits<scalar_t>::infinity()) {
+      grad.narrow(1, b, 1).zero_();
+      continue;
+    }
+
     auto log_probs_a = log_probs_a_global[b];
     auto log_alpha_a = log_alpha_a_global[b];
     auto log_beta_a = log_beta_a_global[b];
@@ -281,7 +287,6 @@ Tensor ctc_loss_backward_cpu_template(const Tensor& grad_out, const Tensor& log_
     // now we wrap up the calculation by adding in the remaining items of eq (16)
     // this could be a great target for further vectorization.
     // grad is the output gradient, nll is the loss. Note that the likelihood -nll is the Z of eq (16)
-    scalar_t nll = neg_log_likelihood.accessor<scalar_t, 1>()[b];
     scalar_t gr = grad_out.accessor<scalar_t, 1>()[b];
     for (int64_t t = 0; t < input_length; t++) { // or go for the full thing?
       for (int64_t c = 0; c < num_labels; c++) {
@@ -300,7 +305,8 @@ Tensor ctc_loss_backward_cpu_template(const Tensor& grad_out, const Tensor& log_

 } // namespace

-std::tuple<Tensor, Tensor> ctc_loss_cpu(const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t BLANK) {
+std::tuple<Tensor, Tensor> ctc_loss_cpu(const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t BLANK, bool zero_infinity) {
+  (void)zero_infinity; // only used for backwards
   return AT_DISPATCH_FLOATING_TYPES(log_probs.type(), "ctc_loss", [&] {
     if (targets.type().scalarType() == kLong) {
       return ctc_loss_cpu_template<scalar_t, kLong>(log_probs, targets, input_lengths, target_lengths, BLANK);
@@ -311,12 +317,12 @@ std::tuple<Tensor, Tensor> ctc_loss_cpu(const Tensor& log_probs, const Tensor& t
 }

 Tensor ctc_loss_backward_cpu(const Tensor& grad, const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths,
-                             const Tensor& neg_log_likelihood, const Tensor& log_alpha, int64_t BLANK) {
+                             const Tensor& neg_log_likelihood, const Tensor& log_alpha, int64_t BLANK, bool zero_infinity) {
   return AT_DISPATCH_FLOATING_TYPES(log_probs.type(), "ctc_loss_backward", [&] {
     if (targets.type().scalarType() == kLong) {
-      return ctc_loss_backward_cpu_template<scalar_t, kLong>(grad, log_probs, targets, input_lengths, target_lengths, neg_log_likelihood, log_alpha, BLANK);
+      return ctc_loss_backward_cpu_template<scalar_t, kLong>(grad, log_probs, targets, input_lengths, target_lengths, neg_log_likelihood, log_alpha, BLANK, zero_infinity);
     } else {
-      return ctc_loss_backward_cpu_template<scalar_t, kInt>(grad, log_probs, targets, input_lengths, target_lengths, neg_log_likelihood, log_alpha, BLANK);
+      return ctc_loss_backward_cpu_template<scalar_t, kInt>(grad, log_probs, targets, input_lengths, target_lengths, neg_log_likelihood, log_alpha, BLANK, zero_infinity);
     }
   });
 }
@@ -324,7 +330,7 @@ Tensor ctc_loss_backward_cpu(const Tensor& grad, const Tensor& log_probs, const
 // this wrapper function dispatches to the native and cudnn implementations and hides the alpha/grad from the user (by just returning the loss)
 // the gradient is implemented for _cudnn_ctc_loss (just in derivatives.yaml) and _ctc_loss and this function has automatic gradients
 // it also handles the reduction if desired
-Tensor ctc_loss(const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t BLANK, int64_t reduction) {
+Tensor ctc_loss(const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t BLANK, int64_t reduction, bool zero_infinity) {
   auto& ctx = at::globalContext();

   bool use_cudnn =
@@ -343,15 +349,19 @@ Tensor ctc_loss(const Tensor& log_probs, const Tensor& targets, IntArrayRef inpu
       use_cudnn &= (input_lengths[b] == max_input_length);
     }
     for (int64_t b = 0; b < target_lengths.size(); b++) {
-      use_cudnn &= (target_lengths[b] <= 256);
+      // target length < 256 is documented, but we see illegal memory accesses when target lengths > input lengths for CuDNN
+      use_cudnn &= (target_lengths[b] <= 256) & (target_lengths[b] <= input_lengths[b]);
     }
   }

   Tensor res;
   if (use_cudnn) {
-    res = std::get<0>(at::_cudnn_ctc_loss(log_probs, targets, input_lengths, target_lengths, BLANK, ctx.deterministicCuDNN()));
+    res = std::get<0>(at::_cudnn_ctc_loss(log_probs, targets, input_lengths, target_lengths, BLANK, ctx.deterministicCuDNN(), zero_infinity));
   } else {
-    res = std::get<0>(at::_ctc_loss(log_probs, targets, input_lengths, target_lengths, BLANK));
+    res = std::get<0>(at::_ctc_loss(log_probs, targets, input_lengths, target_lengths, BLANK, zero_infinity));
+    if (zero_infinity) {
+      res = at::where(res == Scalar(std::numeric_limits<double>::infinity()), at::zeros({}, res.options()), res);
+    }
   }
   if (reduction == Reduction::Mean) {
     auto target_lengths_t = at::tensor(target_lengths, res.options().device(at::Device(at::Device::Type::CPU)).dtype(kLong)).toType(res.type());
@@ -363,7 +373,7 @@ Tensor ctc_loss(const Tensor& log_probs, const Tensor& targets, IntArrayRef inpu
 }

 // Convenience function accepting Tensors
-Tensor ctc_loss(const Tensor& log_probs, const Tensor& targets, const Tensor& input_lengths, const Tensor& target_lengths, int64_t BLANK, int64_t reduction) {
+Tensor ctc_loss(const Tensor& log_probs, const Tensor& targets, const Tensor& input_lengths, const Tensor& target_lengths, int64_t BLANK, int64_t reduction, bool zero_infinity) {
   AT_CHECK(isIntegralType(input_lengths.type().scalarType()), "input_lenghts must be integral");
   AT_CHECK(isIntegralType(target_lengths.type().scalarType()), "target_lenghts must be integral");

@@ -371,7 +381,7 @@ Tensor ctc_loss(const Tensor& log_probs, const Tensor& targets, const Tensor& in
   Tensor tlc = target_lengths.toType(kLong).toBackend(Backend::CPU).contiguous();
   IntArrayRef il(ilc.data<int64_t>(), ilc.numel());
   IntArrayRef tl(tlc.data<int64_t>(), tlc.numel());
-  return at::native::ctc_loss(log_probs, targets, il, tl, BLANK, reduction);
+  return at::native::ctc_loss(log_probs, targets, il, tl, BLANK, reduction, zero_infinity);
 }

 } } // at::native
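The wrapper above keeps the kernels untouched and instead masks the finished per-sample losses; in Python terms the at::where(...) it performs amounts to the following sketch (the helper name is made up for illustration):

    import torch

    def _zero_infinite_losses(per_sample_loss):
        # mirror of the masking above: +inf entries become 0, finite entries pass through
        return torch.where(per_sample_loss == float('inf'),
                           torch.zeros_like(per_sample_loss),
                           per_sample_loss)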
diff --git a/aten/src/ATen/native/cuda/LossCTC.cu b/aten/src/ATen/native/cuda/LossCTC.cu
index 55ec104..88d414e 100644
--- a/aten/src/ATen/native/cuda/LossCTC.cu
+++ b/aten/src/ATen/native/cuda/LossCTC.cu
@@ -379,7 +379,7 @@ ctc_loss_backward_collect_nonblank_gpu_kernel(scalar_t* __restrict__ gradient_da
                                               int64_t la_batch_stride, int64_t la_input_stride, int64_t la_target_stride,
                                               int64_t lb_batch_stride, int64_t lb_input_stride, int64_t lb_target_stride,
                                               const int64_t* __restrict__ tg_batch_offsets, int64_t tg_target_stride,
-                                              int64_t batch_size, int64_t num_labels, int64_t BLANK) {
+                                              int64_t batch_size, int64_t num_labels, int64_t BLANK, bool zero_infinity) {
   int64_t b = threadIdx.y + blockIdx.y * blockDim.y;
   int64_t s = threadIdx.x + blockIdx.x * blockDim.y; // note, this directly indexes into targets, no targets prime!

@@ -401,6 +401,9 @@ ctc_loss_backward_collect_nonblank_gpu_kernel(scalar_t* __restrict__ gradient_da

   scalar_t nll = neg_log_likelihood_data[b];
   scalar_t gr = grad_out_data[b * grad_out_batch_stride];
+  if (zero_infinity && nll == INFINITY)
+    return;
+
   for (int64_t t = 0; t < input_length; t++) {
     scalar_t lp = log_probs_data[lp_batch_offset + t * lp_input_stride + lp_char_stride * target];
     atomicAdd(&gradient_data[gr_batch_offset + t * gr_input_stride + gr_char_stride * target],
@@ -428,7 +431,7 @@ ctc_loss_backward_collect_gpu_kernel(scalar_t* __restrict__ gradient_data,
                                      int64_t la_batch_stride, int64_t la_input_stride, int64_t la_target_stride,
                                      int64_t lb_batch_stride, int64_t lb_input_stride, int64_t lb_target_stride,
                                      const int64_t* __restrict__ tg_batch_offsets, int64_t tg_target_stride,
-                                     int64_t batch_size, int64_t num_labels, int64_t BLANK) {
+                                     int64_t batch_size, int64_t num_labels, int64_t BLANK, bool zero_infinity) {

   constexpr scalar_t neginf = -INFINITY;
   int64_t b = threadIdx.y + blockIdx.y * blockDim.y;
@@ -466,7 +469,7 @@ ctc_loss_backward_collect_gpu_kernel(scalar_t* __restrict__ gradient_data,

     for (int64_t c = 0; c < num_labels; c++) {
       scalar_t& res = gradient_data[gr_batch_offset + t * gr_input_stride + gr_char_stride * c];
-      if (t < input_length) {
+      if (t < input_length && (! zero_infinity || nll != INFINITY)) {
         scalar_t lp = log_probs_data[lp_batch_offset + t * lp_input_stride + lp_char_stride * c];
         res = (std::exp(lp)-std::exp(res + nll - lp)) * gr;
       }
@@ -480,7 +483,7 @@ ctc_loss_backward_collect_gpu_kernel(scalar_t* __restrict__ gradient_data,
 // We don't do a lot of checking as we envision this to be called only when backpropagating through a (well-checked) forward.
 template<typename scalar_t, ScalarType target_scalar_type>
 Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_probs, const Tensor& targets_, IntArrayRef input_lengths, IntArrayRef target_lengths,
-                                      const Tensor& neg_log_likelihood, const Tensor& log_alpha, int64_t BLANK) {
+                                      const Tensor& neg_log_likelihood, const Tensor& log_alpha, int64_t BLANK, bool zero_infinity) {
   constexpr scalar_t neginf = -INFINITY;
   using target_t = typename std::conditional<target_scalar_type == kInt, int, int64_t>::type;
   auto targets = targets_.toType(log_probs.type().toScalarType(target_scalar_type)); // to cuda if it isn't there already
@@ -569,6 +572,9 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_
      );
   // scale by output gradient (blanks and first summand of non-blanks)
   grad *= grad_out.view({1, batch_size, 1});
+  if (zero_infinity) {
+    grad = at::where(neg_log_likelihood.view({1, batch_size, 1}) == Scalar(INFINITY), at::zeros({}, grad.options()), grad);
+  }

   // For the non-blank characters, we use a kernel to compute the subtrahend.
   // Again we might configure block and grid in a better way.
@@ -591,7 +597,7 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_
        log_alpha.stride(0), log_alpha.stride(1), log_alpha.stride(2),
        log_beta.stride(0), log_beta.stride(1), log_beta.stride(2),
        tg_batch_offsets.data<int64_t>(), tg_target_stride,
-       batch_size, num_labels, BLANK);
+       batch_size, num_labels, BLANK, zero_infinity);
     THCudaCheck(cudaGetLastError()); // catch launch errors
   } else { // small problem, use naive algorithm
     // Still no block/grid configuration guru...
@@ -615,7 +621,7 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_
        log_alpha.stride(0), log_alpha.stride(1), log_alpha.stride(2),
        log_beta.stride(0), log_beta.stride(1), log_beta.stride(2),
        tg_batch_offsets.data<int64_t>(), tg_target_stride,
-       batch_size, num_labels, BLANK);
+       batch_size, num_labels, BLANK, zero_infinity);
     THCudaCheck(cudaGetLastError()); // catch launch errors
   }
   return grad;
@@ -623,7 +629,8 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_

 } // namespace

-std::tuple<Tensor, Tensor> ctc_loss_gpu(const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t BLANK) {
+std::tuple<Tensor, Tensor> ctc_loss_gpu(const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t BLANK, bool zero_infinity) {
+  (void)zero_infinity; // only used for backward
   return AT_DISPATCH_FLOATING_TYPES(log_probs.type(), "ctc_loss", [&] {
     if (targets.type().scalarType() == kLong) {
       return ctc_loss_gpu_template<scalar_t, kLong>(log_probs, targets, input_lengths, target_lengths, BLANK);
@@ -634,12 +641,12 @@ std::tuple<Tensor, Tensor> ctc_loss_gpu(const Tensor& log_probs, const Tensor& t
 }

 Tensor ctc_loss_backward_gpu(const Tensor& grad, const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths,
-                             const Tensor& neg_log_likelihood, const Tensor& log_alpha, int64_t BLANK) {
+                             const Tensor& neg_log_likelihood, const Tensor& log_alpha, int64_t BLANK, bool zero_infinity) {
   return AT_DISPATCH_FLOATING_TYPES(log_probs.type(), "ctc_loss_backward", [&] {
     if (targets.type().scalarType() == kLong) {
-      return ctc_loss_backward_gpu_template<scalar_t, kLong>(grad, log_probs, targets, input_lengths, target_lengths, neg_log_likelihood, log_alpha, BLANK);
+      return ctc_loss_backward_gpu_template<scalar_t, kLong>(grad, log_probs, targets, input_lengths, target_lengths, neg_log_likelihood, log_alpha, BLANK, zero_infinity);
     } else {
-      return ctc_loss_backward_gpu_template<scalar_t, kInt>(grad, log_probs, targets, input_lengths, target_lengths, neg_log_likelihood, log_alpha, BLANK);
+      return ctc_loss_backward_gpu_template<scalar_t, kInt>(grad, log_probs, targets, input_lengths, target_lengths, neg_log_likelihood, log_alpha, BLANK, zero_infinity);
     }
   });
 }
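On the backward side the same idea is applied per sample: whenever a sample's negative log likelihood is infinite, its whole gradient slice with respect to log_probs is zeroed instead of turning into NaNs. A rough Python sketch of that masking, using the (T, N, C) layout of the kernels and a made-up helper name:

    import torch

    def _zero_grad_of_infinite_samples(grad, neg_log_likelihood):
        # grad: (T, N, C) gradient w.r.t. log_probs; neg_log_likelihood: (N,) per-sample loss
        mask = (neg_log_likelihood == float('inf')).view(1, -1, 1)  # broadcast over time and classes
        return torch.where(mask, torch.zeros_like(grad), grad)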
diff --git a/aten/src/ATen/native/cudnn/LossCTC.cpp b/aten/src/ATen/native/cudnn/LossCTC.cpp
index 1d3092f..bb7c1fe 100644
--- a/aten/src/ATen/native/cudnn/LossCTC.cpp
+++ b/aten/src/ATen/native/cudnn/LossCTC.cpp
@@ -13,7 +13,7 @@ namespace at { namespace native {

 // See Note [ATen preprocessor philosophy]

-std::tuple<Tensor, Tensor> _cudnn_ctc_loss(const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t BLANK, bool deterministic) {
+std::tuple<Tensor, Tensor> _cudnn_ctc_loss(const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t BLANK, bool deterministic, bool zero_infinity) {
   AT_ERROR("cudnn_ctc_loss: ATen not compiled with cuDNN >= 7 support");
 }

@@ -33,7 +33,8 @@ namespace {

 } // namespace

-std::tuple<Tensor, Tensor> _cudnn_ctc_loss(const Tensor& log_probs_t, const Tensor& targets_t, IntArrayRef input_lengths_, IntArrayRef target_lengths_, int64_t BLANK, bool deterministic) {
+std::tuple<Tensor, Tensor> _cudnn_ctc_loss(const Tensor& log_probs_t, const Tensor& targets_t, IntArrayRef input_lengths_, IntArrayRef target_lengths_, int64_t BLANK, bool deterministic, bool zero_infinity) {
+  (void)zero_infinity; // only used for backward
   CheckedFrom c = "cudnn_ctc_loss";
   TensorArg log_probs { log_probs_t, "log_probs", 1 };
   TensorArg targets { targets_t, "targets", 2 };
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 7fc0aa2..7f2d1a2 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -37,7 +37,7 @@
   matches_jit_signature: True
   variants: function

-- func: _cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank, bool deterministic) -> (Tensor, Tensor)
+- func: _cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor)
   matches_jit_signature: True
   dispatch:
     CUDA: _cudnn_ctc_loss
@@ -709,20 +709,20 @@
 - func: cumprod(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!)
   matches_jit_signature: True

-- func: ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank=0, int reduction=Mean) -> Tensor
+- func: ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank=0, int reduction=Mean, bool zero_infinity=False) -> Tensor
   matches_jit_signature: True

 # convenience function that converts to intlists for you
-- func: ctc_loss(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank=0, int reduction=Mean) -> Tensor
+- func: ctc_loss(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank=0, int reduction=Mean, bool zero_infinity=False) -> Tensor
   matches_jit_signature: True

-- func: _ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank=0) -> (Tensor, Tensor)
+- func: _ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank=0, bool zero_infinity=False) -> (Tensor, Tensor)
   matches_jit_signature: True
   dispatch:
     CPU: ctc_loss_cpu
     CUDA: ctc_loss_gpu

-- func: _ctc_loss_backward(Tensor grad, Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, Tensor neg_log_likelihood, Tensor log_alpha, int blank) -> Tensor
+- func: _ctc_loss_backward(Tensor grad, Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, Tensor neg_log_likelihood, Tensor log_alpha, int blank, bool zero_infinity=False) -> Tensor
   matches_jit_signature: True
   dispatch:
     CPU: ctc_loss_backward_cpu
diff --git a/test/test_nn.py b/test/test_nn.py
index 79486d2..c3761bc 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -4259,6 +4259,29 @@ class TestNN(NNTestCase):
         with self.assertRaises(RuntimeError):
             torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths)

+    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
+    def test_CTCLoss_zero_infinity(self):
+        target_lengths = [60, 25, 20]
+        input_lengths = [50, 50, 50]
+        targets = torch.randint(1, 15, (sum(target_lengths),), dtype=torch.int)
+        log_probs = torch.randn(50, 3, 15, dtype=torch.float, device='cuda').log_softmax(2).requires_grad_()
+        res = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths,
+                                           reduction='sum', zero_infinity=True)
+        with torch.backends.cudnn.flags(enabled=False):
+            res2 = torch.nn.functional.ctc_loss(log_probs, targets.cuda().long(), input_lengths, target_lengths,
+                                                reduction='sum', zero_infinity=True)
+        res_cpu = torch.nn.functional.ctc_loss(log_probs.cpu(), targets.cpu(), input_lengths, target_lengths,
+                                               reduction='sum', zero_infinity=True)
+
+        self.assertAlmostEqual(res2, res, delta=1e-4)
+        self.assertAlmostEqual(res_cpu, res.cpu(), delta=1e-4)
+        g1, = torch.autograd.grad(res, log_probs)
+        g2, = torch.autograd.grad(res2, log_probs)
+        g3, = torch.autograd.grad(res_cpu, log_probs)
+        self.assertAlmostEqual(g2, g3, delta=1e-4)
+        self.assertAlmostEqual(g1, g2, delta=1e-4)
+        self.assertTrue((g1 == g1).all().item())  # check that we don't have NaN
+
     def test_RNN_cell_no_broadcasting(self):
         def test(cell_module, input, hx, input_size, hidden_size):
             cell = cell_module(input_size, hidden_size)
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 56f50ad..641b55d 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -250,8 +250,8 @@
 - name: conv_tbc(Tensor self, Tensor weight, Tensor bias, int64_t pad)
   self, weight, bias: conv_tbc_backward(grad, self, weight, bias, pad)

-- name: _ctc_loss(Tensor log_probs, Tensor targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t blank)
-  log_probs: _ctc_loss_backward(grad, log_probs, targets, input_lengths, target_lengths, result0, result1, blank)
+- name: _ctc_loss(Tensor log_probs, Tensor targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t blank, bool zero_infinity)
+  log_probs: _ctc_loss_backward(grad, log_probs, targets, input_lengths, target_lengths, result0, result1, blank, zero_infinity)

 - name: det(Tensor self)
   self: det_backward(grad, self, result)
@@ -1315,8 +1315,8 @@
   output: -2 * output * grad * grad_output

 # cudnn
-- name: _cudnn_ctc_loss(Tensor log_probs, Tensor targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t blank, bool deterministic)
-  log_probs: result1
+- name: _cudnn_ctc_loss(Tensor log_probs, Tensor targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t blank, bool deterministic, bool zero_infinity)
+  log_probs: "zero_infinity ? where(result0.unsqueeze(0).unsqueeze(2) == 0, zeros_like(result1), result1) : result1"

 - name: cudnn_convolution_transpose(Tensor self, Tensor weight, Tensor bias, IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic)
   self, weight, bias: cudnn_convolution_transpose_backward(self, grad, weight, padding, output_padding, stride, dilation, groups, benchmark, deterministic, grad_input_mask)
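The Python-facing changes below simply thread the new keyword through torch.nn.functional.ctc_loss and nn.CTCLoss. A usage sketch of the functional interface (arbitrary shapes, in the spirit of the docstring example):

    import torch
    import torch.nn.functional as F

    log_probs = torch.randn(50, 16, 20).log_softmax(2).detach().requires_grad_()
    targets = torch.randint(1, 20, (16, 30), dtype=torch.long)
    input_lengths = torch.full((16,), 50, dtype=torch.long)
    target_lengths = torch.randint(10, 30, (16,), dtype=torch.long)
    loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths, zero_infinity=True)
    loss.backward()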
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index 0308d21..635518a 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -1721,8 +1721,8 @@ def local_response_norm(input, size, alpha=1e-4, beta=0.75, k=1.):

 @weak_script
 def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0,
-             reduction='mean'):
-    # type: (Tensor, Tensor, Tensor, Tensor, int, str) -> Tensor
+             reduction='mean', zero_infinity=False):
+    # type: (Tensor, Tensor, Tensor, Tensor, int, str, bool) -> Tensor
     r"""The Connectionist Temporal Classification loss.

     See :class:`~torch.nn.CTCLoss` for details.
@@ -1747,6 +1747,11 @@ def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0,
            'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
            'mean': the output losses will be divided by the target lengths and
            then the mean over the batch is taken. Default: 'mean'
+        zero_infinity (bool, optional):
+            Whether to zero infinite losses and the associated gradients.
+            Default: ``False``
+            Infinite losses mainly occur when the inputs are too short
+            to be aligned to the targets.

     Example::

@@ -1757,7 +1762,8 @@ def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0,
         >>> loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths)
         >>> loss.backward()
     """
-    return torch.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank, _Reduction.get_enum(reduction))
+    return torch.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank, _Reduction.get_enum(reduction),
+                          zero_infinity)


 @weak_script
diff --git a/torch/nn/modules/loss.py b/torch/nn/modules/loss.py
index 10c7a63..9d73b48 100644
--- a/torch/nn/modules/loss.py
+++ b/torch/nn/modules/loss.py
@@ -1218,6 +1218,11 @@ class CTCLoss(_Loss):
            Lengths of the inputs (must each be :math:`\leq T`)
         target_lengths: Tuple or tensor of size :math:`(N)`.
            Lengths of the targets
+        zero_infinity (bool, optional):
+            Whether to zero infinite losses and the associated gradients.
+            Default: ``False``
+            Infinite losses mainly occur when the inputs are too short
+            to be aligned to the targets.

     Example::

@@ -1250,13 +1255,15 @@ class CTCLoss(_Loss):
     """
     __constants__ = ['blank', 'reduction']

-    def __init__(self, blank=0, reduction='mean'):
+    def __init__(self, blank=0, reduction='mean', zero_infinity=False):
         super(CTCLoss, self).__init__(reduction=reduction)
         self.blank = blank
+        self.zero_infinity = zero_infinity

     @weak_script_method
     def forward(self, log_probs, targets, input_lengths, target_lengths):
-        return F.ctc_loss(log_probs, targets, input_lengths, target_lengths, self.blank, self.reduction)
+        return F.ctc_loss(log_probs, targets, input_lengths, target_lengths, self.blank, self.reduction,
+                          self.zero_infinity)

 # TODO: L1HingeEmbeddingCriterion
 # TODO: MSECriterion weight
--
2.7.4
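For completeness, the corresponding module-level usage of the option added in torch/nn/modules/loss.py (again a sketch with arbitrary shapes):

    import torch
    import torch.nn as nn

    ctc = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)
    log_probs = torch.randn(50, 16, 20).log_softmax(2).detach().requires_grad_()
    targets = torch.randint(1, 20, (16, 30), dtype=torch.long)
    input_lengths = torch.full((16,), 50, dtype=torch.long)
    target_lengths = torch.randint(10, 30, (16,), dtype=torch.long)
    loss = ctc(log_probs, targets, input_lengths, target_lengths)
    loss.backward()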