[nn] TripletMarginLoss and PairwiseDistance : no batch dim (#64882)
authorkshitij12345 <kshitijkalambarkar@gmail.com>
Tue, 21 Sep 2021 14:21:15 +0000 (07:21 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 21 Sep 2021 14:29:48 +0000 (07:29 -0700)
Summary:
Reference: https://github.com/pytorch/pytorch/issues/60585

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64882

Reviewed By: malfet

Differential Revision: D31055577

Pulled By: jbschlosser

fbshipit-source-id: 2f0a5a08619b672026b48a78bc7d83a6dccba0bf

aten/src/ATen/native/Distance.cpp
aten/src/ATen/native/Loss.cpp
test/test_nn.py
torch/nn/modules/distance.py
torch/nn/modules/loss.py
torch/testing/_internal/common_nn.py

index b79c3e9..7974840 100644 (file)
@@ -14,7 +14,12 @@ DEFINE_DISPATCH(cdist_stub);
 DEFINE_DISPATCH(cdist_backward_stub);
 
 Tensor pairwise_distance(const Tensor& x1, const Tensor& x2, double p, double eps, bool keepdim) {
-  return at::norm(x1 - x2 + eps, p, 1, keepdim);
+  // Since either x1 or x2 could be broadcasted
+  auto x1_dim = x1.dim();
+  auto x2_dim = x2.dim();
+  auto output_dim = x1_dim > x2_dim ? x1_dim : x2_dim;
+  auto innermost_dim = output_dim - 1;
+  return at::norm(x1 - x2 + eps, p, innermost_dim, keepdim);
 }
 
 // This is to guarantee that the contiguous memory is passed to the backward pass
index 6c4c21b..4054655 100644 (file)
@@ -79,6 +79,18 @@ Tensor hinge_embedding_loss(const Tensor& self, const Tensor& target, double mar
 
 Tensor triplet_margin_loss(const Tensor& anchor, const Tensor& positive, const Tensor& negative, double margin,
                            double p, double eps, bool swap, int64_t reduction) {
+  auto a_dim = anchor.dim();
+  auto p_dim = positive.dim();
+  auto n_dim = negative.dim();
+  TORCH_CHECK(
+      a_dim == p_dim && p_dim == n_dim,
+      "All inputs should have same dimension but got ",
+      a_dim,
+      "D, ",
+      p_dim,
+      "D and ",
+      n_dim,
+      "D inputs.")
   auto dist_pos = at::pairwise_distance(anchor, positive, p, eps);
   auto dist_neg = at::pairwise_distance(anchor, negative, p, eps);
   if (swap) {
index f324350..92357d9 100644 (file)
@@ -9594,6 +9594,21 @@ class TestNN(NNTestCase):
         self.assertEqual(F.triplet_margin_loss(input1, input2, input3, swap=True, reduction='none'),
                          loss_reference_fns['TripletMarginLoss'](input1, input2, input3, swap=True, reduction='none'))
 
+    def test_triplet_margin_loss_invalid(self):
+        input1 = torch.randn(5, 10, requires_grad=True)
+        input2 = torch.randn(5, 10, requires_grad=True)
+        input3 = torch.randn(5, 10, requires_grad=True)
+        input_1d = torch.randn(10, requires_grad=True)
+
+        with self.assertRaisesRegex(RuntimeError, "All inputs should have same dimension"):
+            F.triplet_margin_loss(input1, input2, input_1d)
+
+        with self.assertRaisesRegex(RuntimeError, "All inputs should have same dimension"):
+            F.triplet_margin_loss(input1, input_1d, input3)
+
+        with self.assertRaisesRegex(RuntimeError, "All inputs should have same dimension"):
+            F.triplet_margin_loss(input_1d, input2, input3)
+
     def test_pointwise_loss_target_grad_none_reduction(self):
         i = torch.randn(5, 10)
         t = torch.randn(5, 10, requires_grad=True)
index 29501e9..00513ac 100644 (file)
@@ -6,7 +6,7 @@ from torch import Tensor
 
 class PairwiseDistance(Module):
     r"""
-    Computes the batchwise pairwise distance between vectors :math:`v_1`, :math:`v_2` using the p-norm:
+    Computes the pairwise distance between vectors :math:`v_1`, :math:`v_2` using the p-norm:
 
     .. math ::
         \Vert x \Vert _p = \left( \sum_{i=1}^n  \vert x_i \vert ^ p \right) ^ {1/p}.
@@ -18,9 +18,10 @@ class PairwiseDistance(Module):
         keepdim (bool, optional): Determines whether or not to keep the vector dimension.
             Default: False
     Shape:
-        - Input1: :math:`(N, D)` where `D = vector dimension`
-        - Input2: :math:`(N, D)`, same shape as the Input1
-        - Output: :math:`(N)`. If :attr:`keepdim` is ``True``, then :math:`(N, 1)`.
+        - Input1: :math:`(N, D)` or :math:`(D)` where `N = batch dimension` and `D = vector dimension`
+        - Input2: :math:`(N, D)` or :math:`(D)`, same shape as the Input1
+        - Output: :math:`(N)` or :math:`()` based on input dimension.
+                If :attr:`keepdim` is ``True``, then :math:`(N, 1)` or :math:`(1)` based on input dimension.
     Examples::
         >>> pdist = nn.PairwiseDistance(p=2)
         >>> input1 = torch.randn(100, 128)
index e0989e5..6d295aa 100644 (file)
@@ -1436,9 +1436,9 @@ class TripletMarginLoss(_Loss):
             specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
 
     Shape:
-        - Input: :math:`(N, D)` where :math:`D` is the vector dimension.
-        - Output: A Tensor of shape :math:`(N)` if :attr:`reduction` is ``'none'``, or a scalar
-          otherwise.
+        - Input: :math:`(N, D)` or :math`(D)` where :math:`D` is the vector dimension.
+        - Output: A Tensor of shape :math:`(N)` if :attr:`reduction` is ``'none'`` and
+                  input shape is :math`(N, D)`; a scalar otherwise.
 
     Examples::
 
index 58f30d8..74895bb 100644 (file)
@@ -1305,9 +1305,15 @@ def single_batch_reference_fn(input, parameters, module):
     The module is passed the input and target in batched form with a single item.
     The output is squeezed to compare with the no-batch input.
     """
-    single_batch_input = input.unsqueeze(0)
+    def unsqueeze_inp(inp):
+        if isinstance(inp, (list, tuple)):
+            return [t.unsqueeze(0) for t in inp]
+        return inp.unsqueeze(0)
+
+    single_batch_input = unsqueeze_inp(input)
+    single_batch_input = [single_batch_input] if isinstance(single_batch_input, torch.Tensor) else single_batch_input
     with freeze_rng_state():
-        return module(single_batch_input).squeeze(0)
+        return module(*single_batch_input).squeeze(0)
 
 
 new_module_tests = [
@@ -3944,6 +3950,33 @@ new_module_tests = [
         pickle=False,
     ),
     dict(
+        module_name='PairwiseDistance',
+        input_fn=lambda: (torch.randn(10, 8), torch.randn(10, 8)),
+    ),
+    dict(
+        module_name='PairwiseDistance',
+        input_fn=lambda: (torch.randn(10, 1), torch.randn(10, 8)),
+        desc='broadcast_lhs'
+    ),
+    dict(
+        module_name='PairwiseDistance',
+        input_fn=lambda: (torch.randn(10, 8), torch.randn(1, 8)),
+        desc='broadcast_rhs'
+    ),
+    dict(
+        module_name='PairwiseDistance',
+        constructor_args=(1.5, 1e-05, True),
+        cpp_constructor_args='torch::nn::PairwiseDistanceOptions().p(1.5).eps(1e-05).keepdim(true)',
+        input_fn=lambda: (torch.randn(10, 8), torch.randn(10, 8)),
+        desc='with_non_default_args',
+    ),
+    dict(
+        module_name='PairwiseDistance',
+        input_fn=lambda: (torch.randn(8), torch.randn(8)),
+        reference_fn=single_batch_reference_fn,
+        desc='no_batch_dim',
+    ),
+    dict(
         module_name='TransformerEncoderLayer',
         constructor_args=(4, 2, 16, 0.0),
         cpp_constructor_args='''torch::nn::TransformerEncoderLayerOptions(4, 2)
@@ -5445,6 +5478,8 @@ classification_criterion_no_batch = [
     ('SoftMarginLoss', lambda: torch.randn(9), lambda: torch.tensor([-1, 1, 1] * 3)),
     ('NLLLoss', lambda: F.log_softmax(torch.randn(3), dim=0), lambda: torch.tensor(1)),
     ('CosineEmbeddingLoss', lambda: (torch.randn(9), torch.randn(9)), lambda: torch.tensor(1)),
+    # For TripletMarginLoss, input_fn : (anchor, positive) and target_fn : negative
+    ('TripletMarginLoss', lambda: (torch.randn(9), torch.randn(9)), lambda: torch.randn(9)),
 ]
 classification_criterion_no_batch_extra_info: Dict[str, dict] = {
     'MultiLabelMarginLoss': {'check_gradgrad': False},