ENH Adds no_batch_dim for NLLLoss (#62651)
authorThomas J. Fan <thomasjpfan@gmail.com>
Tue, 24 Aug 2021 15:26:21 +0000 (08:26 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 24 Aug 2021 15:27:27 +0000 (08:27 -0700)
Summary:
Towards https://github.com/pytorch/pytorch/issues/60585

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62651

Reviewed By: VitalyFedyunin

Differential Revision: D30303340

Pulled By: jbschlosser

fbshipit-source-id: 7ab478cf63bf6cd1f850cad5fd101e74a2cfe3f5

aten/src/ATen/native/LossNLL.cpp
aten/src/ATen/native/cuda/Loss.cu
torch/nn/modules/loss.py
torch/testing/_internal/common_nn.py

index 7c306c2..c7c65f7 100644 (file)
@@ -22,10 +22,12 @@ TORCH_META_FUNC(nll_loss_forward)
   TORCH_CHECK(
       self.dim() > 0 && self.dim() <= 2, "input tensor should be 1D or 2D");
   TORCH_CHECK(
-      target.dim() == 1,
-      "1D target tensor expected, multi-target not supported");
+      target.dim() <= 1,
+      "0D or 1D target tensor expected, multi-target not supported");
+
+  auto no_batch_dim = self.dim() == 1  && target.dim() == 0;
   TORCH_CHECK(
-      self.size(0) == target.size(0),
+      no_batch_dim || (self.size(0) == target.size(0)),
       "size mismatch (got input: ",
       self.sizes(),
       ", target: ",
@@ -66,10 +68,12 @@ TORCH_META_FUNC(nll_loss_backward)
   TORCH_CHECK(
       self.dim() > 0 && self.dim() <= 2, "input tensor should be 1D or 2D");
   TORCH_CHECK(
-      target.dim() == 1,
-      "1D target tensor expected, multi-target not supported");
+      target.dim() <= 1,
+      "0D or 1D target tensor expected, multi-target not supported");
+
+  auto no_batch_dim = self.dim() == 1  && target.dim() == 0;
   TORCH_CHECK(
-      self.size(0) == target.size(0),
+      no_batch_dim || (self.size(0) == target.size(0)),
       "size mismatch (got input: ",
       self.sizes(),
       ", target: ",
@@ -181,7 +185,6 @@ static void nll_loss_out_frame(
   const int64_t ndim = input.dim();
   TORCH_CHECK(ndim <= 2);
   const int64_t batch_size = ndim == 1 ? 1 : input.size(0);
-  TORCH_CHECK(target.size(0) == batch_size);
 
   constexpr int64_t cascade_sum_num_levels = 8;
   const int64_t level_power =
@@ -298,7 +301,11 @@ static void nll_loss_backward_out_frame(
   const auto n_dims = input.dim();
   const auto n_classes = input.size(-1);
 
-  auto target_acc = target.accessor<target_t, 1>();
+  auto target_ = target;
+  if (target.dim() == 0) {
+    target_ = target.unsqueeze(0);
+  }
+  auto target_acc = target_.accessor<target_t, 1>();
 
   auto weight_contiguous = optional_contiguous(weight);
   const scalar_t* weight_data = optional_data<scalar_t>(weight_contiguous);
@@ -349,7 +356,6 @@ static void nll_loss_backward_out_frame(
     auto grad_input_acc = grad_input.accessor<scalar_t, 2>();
 
     const auto batch_size = input.size(0);
-    TORCH_CHECK(target.size(0) == batch_size);
 
     for (int64_t i = 0; i < batch_size; i++) {
       const auto cur_target = target_acc[i];
@@ -548,12 +554,12 @@ Tensor nll_loss_nd(
     const c10::optional<Tensor>& weight,
     int64_t reduction,
     int64_t ignore_index) {
-  if (self.dim() < 2) {
+  if (self.dim() < 1) {
     TORCH_CHECK_VALUE(
-        false, "Expected 2 or more dimensions (got ", self.dim(), ")");
+        false, "Expected 1 or more dimensions (got ", self.dim(), ")");
   }
 
-  if (self.sizes()[0] != target.sizes()[0]) {
+  if (self.dim() != 1 && self.sizes()[0] != target.sizes()[0]) {
     TORCH_CHECK_VALUE(
         false,
         "Expected input batch_size (",
@@ -566,7 +572,7 @@ Tensor nll_loss_nd(
   Tensor ret;
   Tensor input_ = self;
   Tensor target_ = target;
-  if (input_.dim() == 2) {
+  if (input_.dim() == 1 || input_.dim() == 2) {
     ret = at::nll_loss(input_, target_, weight, reduction, ignore_index);
   } else if (input_.dim() == 4) {
     ret = at::nll_loss2d(input_, target_, weight, reduction, ignore_index);
index d814eae..ac9c3c0 100644 (file)
@@ -468,7 +468,6 @@ void nll_loss_backward_out_cuda_template(
   int64_t n_dims = input.dim();
   int64_t n_classes = input.size(-1);
   int64_t batch_size = n_dims == 1 ? 1 : input.size(0);
-  int64_t num_targets = target.size(0);
 
   auto weight_ = weight.defined() ? weight.contiguous() : weight;
 
index 7f39db4..03732b6 100644 (file)
@@ -164,10 +164,11 @@ class NLLLoss(_WeightedLoss):
             :attr:`reduction`. Default: ``'mean'``
 
     Shape:
-        - Input: :math:`(N, C)` where `C = number of classes`, or
+        - Input: :math:`(N, C)` or :math:`(C)`, where `C = number of classes`, or
           :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1`
           in the case of `K`-dimensional loss.
-        - Target: :math:`(N)` where each value is :math:`0 \leq \text{targets}[i] \leq C-1`, or
+        - Target: :math:`(N)` or :math:`()`, where each value is
+          :math:`0 \leq \text{targets}[i] \leq C-1`, or
           :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of
           K-dimensional loss.
         - Output: If :attr:`reduction` is ``'none'``, shape :math:`(N)` or
index 6b1bcf6..90024de 100644 (file)
@@ -97,6 +97,7 @@ def get_weight(m):
 # - `test_cpp_api_parity`: if `False`, skips the C++ parity test for this test dict. Default: True.
 # - `has_parity`: if `False`, expects this test dict to fail the C++ parity test. Default: True.
 
+
 module_tests = [
     dict(
         module_name='Linear',
@@ -1308,6 +1309,7 @@ def single_batch_reference_fn(input, parameters, module):
     with freeze_rng_state():
         return module(single_batch_input).squeeze(0)
 
+
 new_module_tests = [
     poissonnllloss_no_reduce_test(),
     bceloss_no_reduce_test(),
@@ -4055,6 +4057,7 @@ def kldivloss_reference(input, target, reduction='mean'):
         return result.sum() / result.size(0)
     return result
 
+
 def kldivloss_log_target_reference(input, target, reduction='mean'):
     result = torch.exp(target) * (target - input)
     if reduction == 'mean':
@@ -5182,6 +5185,7 @@ classification_criterion_no_batch = [
     ('HingeEmbeddingLoss', lambda: torch.randn(9), lambda: torch.tensor([-1, 1, 1] * 3)),
     ('MultiLabelMarginLoss', lambda: torch.randn(4), lambda: torch.tensor([3, 0, -1, 1])),
     ('SoftMarginLoss', lambda: torch.randn(9), lambda: torch.tensor([-1, 1, 1] * 3)),
+    ('NLLLoss', lambda: F.log_softmax(torch.randn(3), dim=0), lambda: torch.tensor(1)),
 ]
 classification_criterion_no_batch_extra_info: Dict[str, dict] = {
     'MultiLabelMarginLoss': {'check_gradgrad': False},
@@ -5580,6 +5584,7 @@ class ModuleTest(TestBase):
 
         self.test_noncontig(test_case, gpu_module, gpu_input_tuple)
 
+
 class InputVariableMixin(object):
     def _get_input(self):
         input = TestBase._get_input(self, False)  # type: ignore[arg-type]
@@ -5888,8 +5893,10 @@ class CriterionTest(InputVariableMixin, TestBase):  # type: ignore[misc]
         test_case.assertEqualIgnoreType(cpu_output, gpu_output,
                                         atol=1e-1 if dtype in {torch.half, torch.bfloat16} else 4e-4, rtol=0)
 
-        cpu_gradInput = test_case._backward_criterion(cpu_module, cpu_input, cpu_output, cpu_target, extra_args=extra_args)
-        gpu_gradInput = test_case._backward_criterion(gpu_module, gpu_input, gpu_output, gpu_target, extra_args=extra_args)
+        cpu_gradInput = test_case._backward_criterion(
+            cpu_module, cpu_input, cpu_output, cpu_target, extra_args=extra_args)
+        gpu_gradInput = test_case._backward_criterion(
+            gpu_module, gpu_input, gpu_output, gpu_target, extra_args=extra_args)
         # dtype used to be able to be None, so set precision in this way instead of a precision map
         # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
         test_case.assertEqualIgnoreType(cpu_gradInput, gpu_gradInput,