BUG Fixes regression for nllloss gradcheck (#64203)
authorThomas J. Fan <thomasjpfan@gmail.com>
Mon, 30 Aug 2021 22:03:40 +0000 (15:03 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 30 Aug 2021 22:13:09 +0000 (15:13 -0700)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/64163

This PR includes the fix and the opinfo from https://github.com/pytorch/pytorch/pull/63854/ for non-regression testing.

cc albanD mruberry jbschlosser

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64203

Reviewed By: albanD

Differential Revision: D30647522

Pulled By: jbschlosser

fbshipit-source-id: 2974d299763505908fa93532aca2bd5d5b71f2e9

aten/src/ATen/native/cuda/Loss.cu
torch/testing/_internal/common_methods_invocations.py

index ac9c3c0..2087f19 100644 (file)
@@ -1,4 +1,5 @@
 #include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
 #include <ATen/NativeFunctions.h>
 #include <ATen/Dispatch.h>
 #include <ATen/cuda/CUDAApplyUtils.cuh>
@@ -207,7 +208,7 @@ __global__ void nll_loss_forward_reduce_cuda_kernel_1d(
     bool size_average,
     int n_classes,
     int64_t ignore_index) {
-  CUDA_KERNEL_ASSERT(threadIdx.x == 0 && threadIdx.y == 0 & threadIdx.z == 0);
+  CUDA_KERNEL_ASSERT(threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0);
 
   int t = static_cast<int>(*target);
   if (t != static_cast<int>(ignore_index)) {
@@ -263,7 +264,7 @@ __global__ void nll_loss_forward_reduce_cuda_kernel_2d(
     *total_weight = static_cast<scalar_t>(total_weight_acc);
     if (size_average && nframe == 0) {
       // Mean reduction on empty tensors produces NaN
-      *output = std::numeric_limits<double>::quiet_NaN();
+      *output = std::numeric_limits<scalar_t>::quiet_NaN();
     } else if (size_average && total_weight_acc != 0) {
       *output = static_cast<scalar_t>(output_acc / total_weight_acc);
     } else {
@@ -286,7 +287,7 @@ void nll_loss_forward_out_cuda_template(
 
   auto weight_ = weight.defined() ? weight.contiguous() : weight;
 
-  if (reduction == Reduction::None & n_dims == 2) {
+  if (reduction == Reduction::None && n_dims == 2) {
     output.resize_({batch_size});
     if (batch_size == 0) {
       // This guards from unnecessary operations and launching CUDA kernel with
@@ -365,7 +366,8 @@ void nll_loss_forward_out_cuda_template(
               target.scalar_type(),
               "nll_loss_forward_reduce_cuda_kernel_2d_index",
               [&] {
-                nll_loss_forward_reduce_cuda_kernel_2d<scalar_t, float, index_t>
+                using accscalar_t = at::acc_type<scalar_t, /*is_cuda*/true>;
+                nll_loss_forward_reduce_cuda_kernel_2d<scalar_t, accscalar_t, index_t>
                     <<<1,
                        NLL_LOSS_THREADS,
                        0,
index 1349a29..52e8d73 100644 (file)
@@ -13,7 +13,7 @@ import numpy as np
 from torch._six import inf
 import collections.abc
 
-from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, Dict
 
 from torch.testing import \
     (make_non_contiguous, floating_types, floating_types_and, complex_types,
@@ -5221,6 +5221,36 @@ def sample_inputs_grid_sample(op_info, device, dtype, requires_grad, **kwargs):
 
     return sample_inputs
 
+def sample_inputs_nll_loss(op_info, device, dtype, requires_grad, **kwargs):
+    batch_size, num_classes = shape = (2, 3)
+
+    input_shape_and_kwargs: List[Tuple[Tuple[int, ...], Dict[str, Any]]] = [
+        ((*shape, 1), dict()),
+        ((*shape, 1, 2), dict()),
+        ((*shape, 1, 2, 3), dict()),
+        (shape, dict(weight=make_tensor((num_classes,), device=device, dtype=dtype).abs())),
+        (shape, dict(ignore_index=num_classes // 2)),
+        (shape, dict(reduction="sum")),
+        (shape, dict(reduction="mean")),
+    ]
+
+    sample_inputs = []
+    for input_shape, kwargs in input_shape_and_kwargs:
+        input = make_tensor(input_shape, device=device, dtype=dtype, requires_grad=requires_grad)
+
+        target = make_tensor(
+            (batch_size, *input_shape[2:]),
+            low=0,
+            high=num_classes,
+            device=device,
+            dtype=torch.long,
+            requires_grad=requires_grad
+        )
+
+        sample_inputs.append(SampleInput(input, args=(target,), kwargs=kwargs))
+
+    return sample_inputs
+
 foreach_unary_op_db: List[OpInfo] = [
     ForeachFuncInfo('exp'),
     ForeachFuncInfo('acos'),
@@ -9044,6 +9074,21 @@ op_db: List[OpInfo] = [
             SkipInfo('TestReductions', 'test_dim_none_keepdim'),
         ),
     ),
+    OpInfo(
+        "nn.functional.nll_loss",
+        ref=_NOTHING,
+        dtypesIfCPU=floating_types_and(torch.bfloat16),
+        dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+        supports_out=False,
+        sample_inputs_func=sample_inputs_nll_loss,
+        skips=(
+            SkipInfo(
+                "TestJit",
+                "test_variant_consistency_jit",
+                dtypes=(torch.float32,),
+            ),
+        ),
+    ),
 ]
 
 # Common operator groupings