Allow TransformerEncoder and TransformerDecoder to accept 0-dim batch sized tensors.
author    Sameer Deshmukh <sameer.deshmukh93@gmail.com>
          Fri, 13 Aug 2021 23:08:01 +0000 (16:08 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
          Fri, 13 Aug 2021 23:11:57 +0000 (16:11 -0700)
Summary:
This PR fixes part of https://github.com/pytorch/pytorch/issues/12013, which is summarized concretely in https://github.com/pytorch/pytorch/issues/38115.

This PR allows `TransformerEncoder` and `TransformerDecoder` (along with the inner `TransformerEncoderLayer` and `TransformerDecoderLayer` classes) to accept inputs whose batch dimension has size 0, as sketched below.
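A minimal usage sketch of the new behavior (module sizes and shapes are illustrative, not taken from the PR): a zero-sized batch now flows through forward and backward without error.

```python
import torch
import torch.nn as nn

# Illustrative sizes: d_model=512, 8 heads, batch_first=True.
layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=2)

src = torch.rand(0, 10, 512, requires_grad=True)  # batch size 0, seq len 10
out = encoder(src)        # previously raised inside multi_head_attention_forward
print(out.shape)          # torch.Size([0, 10, 512])
out.backward(torch.rand_like(out))  # all resulting gradients are zero
```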

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62800

Reviewed By: VitalyFedyunin

Differential Revision: D30303240

Pulled By: jbschlosser

fbshipit-source-id: 8f8082a6f2a9f9d7ce0b22a942d286d5db62bd12

test/test_nn.py
index a1da171..6ccf81e 100644
@@ -12398,6 +12398,20 @@ class TestNNDeviceType(NNTestCase):
                 self.assertEqual(p.grad, torch.zeros_like(p.grad))
         self.assertEqual(inp.grad, torch.zeros_like(inp))
 
+    def _test_module_empty_inputs(self, module, inputs):
+        for _inp in inputs:
+            _inp.requires_grad_(True)
+        out = module(*inputs)
+        gO = torch.rand_like(out)
+        out.backward(gO)
+
+        for p in module.parameters():
+            if p.requires_grad:
+                self.assertEqual(p.grad, torch.zeros_like(p.grad))
+
+        for _inp in inputs:
+            self.assertEqual(_inp.grad, torch.zeros_like(_inp))
+
     @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'),
                      "Scipy v1.0 and/or numpy not found")
     @tf32_on_and_off()
@@ -13226,6 +13240,50 @@ class TestNNDeviceType(NNTestCase):
         self.assertEqual(inp2.grad, torch.zeros_like(inp2))
 
     @onlyOnCPUAndCUDA
+    def test_TransformerEncoderLayer_empty(self, device):
+        for batch_first, input_shape in [(True, (0, 10, 512)),
+                                         (False, (10, 0, 512))]:
+            input = torch.rand(*input_shape, device=device)
+            encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=batch_first).to(device)
+            self._test_module_empty_input(encoder_layer, input, check_size=False)
+
+    @onlyOnCPUAndCUDA
+    def test_TransformerEncoder_empty(self, device):
+        for batch_first, input_shape in [(True, (0, 10, 512)),
+                                         (False, (10, 0, 512))]:
+            input = torch.rand(*input_shape, device=device)
+            encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=batch_first).to(device)
+            transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6).to(device)
+            self._test_module_empty_input(transformer_encoder, input, check_size=False)
+
+    @onlyOnCPUAndCUDA
+    def test_TransformerDecoderLayer_empty(self, device):
+        for batch_first, memory_shape, tgt_shape in [(True, (0, 10, 512), (0, 20, 512)),
+                                                     (False, (10, 0, 512), (20, 0, 512))]:
+            memory = torch.rand(*memory_shape, device=device)
+            tgt = torch.rand(*tgt_shape, requires_grad=True, device=device)
+            decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=batch_first).to(device)
+            self._test_module_empty_inputs(decoder_layer, [tgt, memory])
+
+    @onlyOnCPUAndCUDA
+    def test_TransformerDecoder_empty(self, device):
+        for batch_first, memory_shape, tgt_shape in [(True, (0, 10, 512), (0, 20, 512)),
+                                                     (False, (10, 0, 512), (20, 0, 512))]:
+            memory = torch.rand(*memory_shape, device=device)
+            tgt = torch.rand(*tgt_shape, requires_grad=True, device=device)
+            decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=batch_first).to(device)
+            transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6).to(device)
+            self._test_module_empty_inputs(transformer_decoder, [tgt, memory])
+
+    @onlyOnCPUAndCUDA
+    def test_Transformer_empty(self, device):
+        for batch_first, src_shape, tgt_shape in [(False, (10, 0, 512), (20, 0, 512))]:
+            transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12, batch_first=batch_first).to(device)
+            src = torch.rand(*src_shape, requires_grad=True, device=device)
+            tgt = torch.rand(*tgt_shape, requires_grad=True, device=device)
+            self._test_module_empty_inputs(transformer_model, [src, tgt])
+
+    @onlyOnCPUAndCUDA
     @dtypes(torch.float32, torch.complex64)
     def test_ReflectionPad_empty(self, device, dtype):
         for mod, inp in [
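For context, a standalone sketch of the property these tests assert (module sizes and shapes assumed for illustration): with an empty batch, forward and backward both succeed, and every gradient that is produced is exactly zero.

```python
import torch
import torch.nn as nn

decoder_layer = nn.TransformerDecoderLayer(d_model=8, nhead=2, batch_first=True)
tgt = torch.rand(0, 4, 8, requires_grad=True)     # batch 0, target length 4
memory = torch.rand(0, 6, 8, requires_grad=True)  # batch 0, source length 6

out = decoder_layer(tgt, memory)                  # shape (0, 4, 8)
out.backward(torch.rand_like(out))

for p in decoder_layer.parameters():
    if p.grad is not None:
        assert torch.equal(p.grad, torch.zeros_like(p.grad))
assert torch.equal(tgt.grad, torch.zeros_like(tgt))
```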
torch/nn/functional.py
index bb08427..5212586 100644
@@ -5037,7 +5037,7 @@ def multi_head_attention_forward(
     #
     q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
     if static_k is None:
-        k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
+        k = k.contiguous().view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
     else:
         # TODO finish disentangling control flow so we don't do in-projections when statics are passed
         assert static_k.size(0) == bsz * num_heads, \
@@ -5046,7 +5046,7 @@ def multi_head_attention_forward(
             f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
         k = static_k
     if static_v is None:
-        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
+        v = v.contiguous().view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
     else:
         # TODO finish disentangling control flow so we don't do in-projections when statics are passed
         assert static_v.size(0) == bsz * num_heads, \
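The functional.py change addresses the root cause: `Tensor.view` cannot infer a `-1` dimension when the tensor has zero elements, since any size would satisfy the constraint, so the first dimension is pinned to the actual sequence length instead. A minimal repro (shapes chosen for illustration):

```python
import torch

bsz, num_heads, head_dim = 0, 8, 64
k = torch.rand(10, bsz, 512)  # source length 10, batch size 0, embed dim 512

try:
    k.view(-1, bsz * num_heads, head_dim)
except RuntimeError as e:
    print(e)  # cannot infer -1 for a tensor of 0 elements (wording varies by version)

# Pinning the leading dimension works for empty and non-empty tensors alike:
k2 = k.view(k.shape[0], bsz * num_heads, head_dim)
print(k2.shape)  # torch.Size([10, 0, 64])
```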