[nn] no batch dim support: CosineEmbeddingLoss (#64590)
authorkshitij12345 <kshitijkalambarkar@gmail.com>
Mon, 13 Sep 2021 17:44:04 +0000 (10:44 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 13 Sep 2021 17:45:33 +0000 (10:45 -0700)
Summary:
Reference: https://github.com/pytorch/pytorch/issues/60585

TODO
* [x] Add tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64590

Reviewed By: H-Huang

Differential Revision: D30900775

Pulled By: jbschlosser

fbshipit-source-id: d24e72787017e79afbf8f04a94901a290485b81a

aten/src/ATen/native/Loss.cpp
test/test_nn.py
torch/nn/modules/loss.py
torch/testing/_internal/common_nn.py

index 5bf6fee..6c4c21b 100644 (file)
@@ -30,13 +30,32 @@ DEFINE_DISPATCH(mse_stub);
 DEFINE_DISPATCH(mse_backward_stub);
 
 Tensor cosine_embedding_loss(const Tensor& input1, const Tensor& input2, const Tensor& target, double margin, int64_t reduction) {
+  auto targ_dim = target.dim();
   TORCH_CHECK(
-      target.dim() == 1,
-      "1D target tensor expected, multi-target not supported");
+      targ_dim == 1 || targ_dim == 0,
+      "0D or 1D target tensor expected, multi-target not supported");
+
+  if (targ_dim == 1) {
+    TORCH_CHECK(
+        input1.dim() == 2,
+        "1D target tensor expects 2D input tensors, but found inputs with sizes ",
+        input1.sizes(),
+        " and ",
+        input2.sizes(),
+        ".");
+  } else {
+    TORCH_CHECK(
+        input1.dim() == 1,
+        "0D target tensor expects 1D input tensors, but found inputs with sizes ",
+        input1.sizes(),
+        " and ",
+        input2.sizes(),
+        ".");
+  }
 
-  auto prod_sum = (input1 * input2).sum(1);
-  auto mag_square1 = (input1 * input1).sum(1) + EPSILON;
-  auto mag_square2 = (input2 * input2).sum(1) + EPSILON;
+  auto prod_sum = (input1 * input2).sum(targ_dim);
+  auto mag_square1 = (input1 * input1).sum(targ_dim) + EPSILON;
+  auto mag_square2 = (input2 * input2).sum(targ_dim) + EPSILON;
   auto denom = (mag_square1 * mag_square2).sqrt_();
   auto cos = prod_sum / denom;
 
index 1a6416d..f5b435f 100644 (file)
@@ -9480,7 +9480,7 @@ class TestNN(NNTestCase):
                          loss_reference_fns['CosineEmbeddingLoss'](input1, input2, target,
                                                                    margin=0.5, reduction='none'))
 
-    def test_cosine_embedding_loss_invalid_target_shape(self):
+    def test_cosine_embedding_loss_invalid_shape(self):
         input1 = torch.randn(15, 10)
         input2 = torch.randn(15, 10)
         target = torch.randn(15, 1).sign()
@@ -9488,6 +9488,12 @@ class TestNN(NNTestCase):
         with self.assertRaisesRegex(RuntimeError, "1D target tensor expected"):
             F.cosine_embedding_loss(input1, input2, target)
 
+        with self.assertRaisesRegex(RuntimeError, "1D target tensor expects 2D input tensors"):
+            F.cosine_embedding_loss(torch.randn(10), torch.randn(10), torch.randn(10))
+
+        with self.assertRaisesRegex(RuntimeError, "0D target tensor expects 1D input tensors"):
+            F.cosine_embedding_loss(torch.randn(2, 5), torch.randn(2, 5), torch.randn(()))
+
     def test_margin_ranking_loss_no_reduce(self):
         input1 = torch.randn(15).mul_(10).requires_grad_()
         input2 = torch.randn(15).mul_(10).requires_grad_()
index d72c614..e0989e5 100644 (file)
@@ -1236,9 +1236,9 @@ class CosineEmbeddingLoss(_Loss):
             specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
 
     Shape:
-        - Input1: :math:`(N, D)`, where `N` is the batch size and `D` is the embedding dimension.
-        - Input2: :math:`(N, D)`, same shape as Input1.
-        - Target: :math:`(N)`.
+        - Input1: :math:`(N, D)` or :math:`(D)`, where `N` is the batch size and `D` is the embedding dimension.
+        - Input2: :math:`(N, D)` or :math:`(D)`, same shape as Input1.
+        - Target: :math:`(N)` or :math:`()`.
         - Output: If :attr:`reduction` is ``'none'``, then :math:`(N)`, otherwise scalar.
     """
     __constants__ = ['margin', 'reduction']
index b22b6ab..f3cdc47 100644 (file)
@@ -5384,7 +5384,22 @@ def single_batch_reference_criterion_fn(*args):
     The output is squeezed to compare with the no-batch input.
     """
     criterion = args[-1]
-    single_batch_input_args = [input.unsqueeze(0) for input in args[:-1]]
+
+    def unsqueeze_inp(inp):
+        if isinstance(inp, (list, tuple)):
+            return [t.unsqueeze(0) for t in inp]
+        return inp.unsqueeze(0)
+
+    def flatten(xs):
+        result = []
+        if isinstance(xs, (list, tuple)):
+            for x in xs:
+                result.extend(flatten(x))
+        else:
+            result.append(xs)
+        return result
+
+    single_batch_input_args = flatten([unsqueeze_inp(input) for input in args[:-1]])
 
     output = criterion(*single_batch_input_args)
     reduction = get_reduction(criterion)
@@ -5421,6 +5436,7 @@ classification_criterion_no_batch = [
     ('MultiLabelMarginLoss', lambda: torch.randn(4), lambda: torch.tensor([3, 0, -1, 1])),
     ('SoftMarginLoss', lambda: torch.randn(9), lambda: torch.tensor([-1, 1, 1] * 3)),
     ('NLLLoss', lambda: F.log_softmax(torch.randn(3), dim=0), lambda: torch.tensor(1)),
+    ('CosineEmbeddingLoss', lambda: (torch.randn(9), torch.randn(9)), lambda: torch.tensor(1)),
 ]
 classification_criterion_no_batch_extra_info: Dict[str, dict] = {
     'MultiLabelMarginLoss': {'check_gradgrad': False},