From: Thomas J. Fan
Date: Mon, 30 Aug 2021 22:03:40 +0000 (-0700)
Subject: BUG Fixes regression for nllloss gradcheck (#64203)
X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~579
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=a7ae73a2380c3e45394998d2d1d9bceb14f2ee55;p=platform%2Fupstream%2Fpytorch.git

BUG Fixes regression for nllloss gradcheck (#64203)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/64163

This PR includes the fix and the opinfo from https://github.com/pytorch/pytorch/pull/63854/ for non-regression testing.

cc albanD mruberry jbschlosser

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64203

Reviewed By: albanD

Differential Revision: D30647522

Pulled By: jbschlosser

fbshipit-source-id: 2974d299763505908fa93532aca2bd5d5b71f2e9
---

diff --git a/aten/src/ATen/native/cuda/Loss.cu b/aten/src/ATen/native/cuda/Loss.cu
index ac9c3c0..2087f19 100644
--- a/aten/src/ATen/native/cuda/Loss.cu
+++ b/aten/src/ATen/native/cuda/Loss.cu
@@ -1,4 +1,5 @@
 #include
+#include <ATen/AccumulateType.h>
 #include
 #include
 #include
@@ -207,7 +208,7 @@ __global__ void nll_loss_forward_reduce_cuda_kernel_1d(
     bool size_average,
     int n_classes,
     int64_t ignore_index) {
-  CUDA_KERNEL_ASSERT(threadIdx.x == 0 && threadIdx.y == 0 & threadIdx.z == 0);
+  CUDA_KERNEL_ASSERT(threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0);

   int t = static_cast<int>(*target);
   if (t != static_cast<int>(ignore_index)) {
@@ -263,7 +264,7 @@ __global__ void nll_loss_forward_reduce_cuda_kernel_2d(
   *total_weight = static_cast<scalar_t>(total_weight_acc);
   if (size_average && nframe == 0) {
     // Mean reduction on empty tensors produces NaN
-    *output = std::numeric_limits::quiet_NaN();
+    *output = std::numeric_limits::quiet_NaN();
   } else if (size_average && total_weight_acc != 0) {
     *output = static_cast<scalar_t>(output_acc / total_weight_acc);
   } else {
@@ -286,7 +287,7 @@ void nll_loss_forward_out_cuda_template(
   auto weight_ = weight.defined() ? weight.contiguous() : weight;

-  if (reduction == Reduction::None & n_dims == 2) {
+  if (reduction == Reduction::None && n_dims == 2) {
     output.resize_({batch_size});
     if (batch_size == 0) {
       // This guards from unnecessary operations and launching CUDA kernel with
@@ -365,7 +366,8 @@ void nll_loss_forward_out_cuda_template(
           target.scalar_type(),
           "nll_loss_forward_reduce_cuda_kernel_2d_index",
           [&] {
-            nll_loss_forward_reduce_cuda_kernel_2d
+            using accscalar_t = at::acc_type<scalar_t, /*is_cuda*/true>;
+            nll_loss_forward_reduce_cuda_kernel_2d
                 <<<1,
                    NLL_LOSS_THREADS,
                    0,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 1349a29..52e8d73 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -13,7 +13,7 @@ import numpy as np
 from torch._six import inf
 import collections.abc

-from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, Dict

 from torch.testing import \
     (make_non_contiguous, floating_types, floating_types_and, complex_types,
@@ -5221,6 +5221,36 @@ def sample_inputs_grid_sample(op_info, device, dtype, requires_grad, **kwargs):
     return sample_inputs


+def sample_inputs_nll_loss(op_info, device, dtype, requires_grad, **kwargs):
+    batch_size, num_classes = shape = (2, 3)
+
+    input_shape_and_kwargs: List[Tuple[Tuple[int, ...], Dict[str, Any]]] = [
+        ((*shape, 1), dict()),
+        ((*shape, 1, 2), dict()),
+        ((*shape, 1, 2, 3), dict()),
+        (shape, dict(weight=make_tensor((num_classes,), device=device, dtype=dtype).abs())),
+        (shape, dict(ignore_index=num_classes // 2)),
+        (shape, dict(reduction="sum")),
+        (shape, dict(reduction="mean")),
+    ]
+
+    sample_inputs = []
+    for input_shape, kwargs in input_shape_and_kwargs:
+        input = make_tensor(input_shape, device=device, dtype=dtype, requires_grad=requires_grad)
+
+        target = make_tensor(
+            (batch_size, *input_shape[2:]),
+            low=0,
+            high=num_classes,
+            device=device,
+            dtype=torch.long,
+            requires_grad=requires_grad
+        )
+
+        sample_inputs.append(SampleInput(input, args=(target,), kwargs=kwargs))
+
+    return sample_inputs
+
 foreach_unary_op_db: List[OpInfo] = [
     ForeachFuncInfo('exp'),
     ForeachFuncInfo('acos'),
@@ -9044,6 +9074,21 @@ op_db: List[OpInfo] = [
             SkipInfo('TestReductions', 'test_dim_none_keepdim'),
         ),
     ),
+    OpInfo(
+        "nn.functional.nll_loss",
+        ref=_NOTHING,
+        dtypesIfCPU=floating_types_and(torch.bfloat16),
+        dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+        supports_out=False,
+        sample_inputs_func=sample_inputs_nll_loss,
+        skips=(
+            SkipInfo(
+                "TestJit",
+                "test_variant_consistency_jit",
+                dtypes=(torch.float32,),
+            ),
+        ),
+    ),
 ]

 # Common operator groupings
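
Note (editor's addition, not part of the patch): in the old condition "reduction == Reduction::None & n_dims == 2", the "==" comparison binds tighter than the bitwise "&", and Reduction::None is 0, so the expression was always false and the reduction='none' branch was never taken for 2-D inputs. That is the kind of forward-path bug that makes gradcheck fail. A minimal sketch of such a gradcheck follows, assuming a CUDA build; the shapes and the reduction choice are illustrative, not taken from the linked issue:

    import torch
    import torch.nn.functional as F
    from torch.autograd import gradcheck

    # Double-precision CUDA inputs: a batch of 2 samples over 3 classes.
    log_probs = torch.randn(2, 3, dtype=torch.double, device="cuda", requires_grad=True)
    target = torch.randint(0, 3, (2,), device="cuda")

    # gradcheck compares analytical and numerical gradients of the per-sample loss.
    assert gradcheck(lambda x: F.nll_loss(x, target, reduction="none"), (log_probs,))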
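
Note (editor's addition, not part of the patch): a rough sketch of how the new sample_inputs_nll_loss entries could be exercised outside the OpInfo test harness, assuming OpInfo.sample_inputs keeps its usual (device, dtype, requires_grad) signature; the CPU/float32 choice below is arbitrary:

    import torch
    import torch.nn.functional as F
    from torch.testing._internal.common_methods_invocations import op_db

    # Look up the newly added OpInfo entry and run each of its sample inputs forward.
    (nll_info,) = [op for op in op_db if op.name == "nn.functional.nll_loss"]
    for sample in nll_info.sample_inputs("cpu", torch.float32):
        out = F.nll_loss(sample.input, *sample.args, **sample.kwargs)
        print(tuple(sample.input.shape), "->", tuple(out.shape))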