Fix dense Embedding to work with double backward (#9078)
authorkshitij12345 <kshitijkalambarkar@gmail.com>
Wed, 3 Apr 2019 16:16:29 +0000 (09:16 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 3 Apr 2019 16:50:34 +0000 (09:50 -0700)
Summary:
Fixes : #6469

1. `ATen/native/native_functions.yml` had [dispatch](https://github.com/pytorch/pytorch/blob/03e7953a98875c0164cb8e2c19b45800e85f4347/aten/src/ATen/native/native_functions.yaml#L451-L455) variants for for `embedding_dense_backward` , however `embedding_backward` explicitly made [call](https://github.com/pytorch/pytorch/blob/03e7953a98875c0164cb8e2c19b45800e85f4347/aten/src/ATen/native/Embedding.cpp#L35-L45) to it, thus leading to error.

2. In case of CUDA type tensor, the function crashed used to crash on dereferencing of indices's data [pointer](https://github.com/pytorch/pytorch/blob/03e7953a98875c0164cb8e2c19b45800e85f4347/aten/src/ATen/native/Embedding.cpp#L93).

Both have been solved and checked against (on CUDA and CPU)

1.  As mentioned in the issue
```
import torch

class Test(torch.nn.Module):

    def __init__(self):
        super(Test,self).__init__()
        self.embd = torch.nn.Embedding(1000, 100)
        self.dense = torch.nn.Linear(100, 1)

    def forward(self, inp):
        inp = self.embd(inp)
        return self.dense(inp)

test = Test()
inp = torch.tensor([0,1,2,1,1])
out = test(inp)
raw_loss = out.mean(dim=0)

loss_grad = torch.autograd.grad(outputs=raw_loss,
                         inputs=list(test.parameters()),
                         retain_graph=True, create_graph=True, only_inputs=True)
norm = sum([param.norm()**2 for param in loss_grad])
loss = raw_loss + norm

loss.backward(retain_graph=True)

print(test.embd.weight.grad)

```

2. Test Script
```
import torch
import time
start = time.time()
l = [1,1]*100
input = torch.tensor([[1,0],[1,0]],device='cpu')
embedding_matrix = torch.tensor([[1.0,3.0],[2.0,4]],requires_grad=True,device='cpu')

sq = embedding_matrix * embedding_matrix
emb = torch.nn.functional.embedding(input, sq,scale_grad_by_freq=False)

print('Embedding Matrix')
print(embedding_matrix)
print('-----------------')

sum_ = emb.sum()#prod.sum()

loss_grad, = torch.autograd.grad(outputs=sum_,inputs=embedding_matrix,create_graph=True)

print('Gradient')
print(loss_grad)
print('-----------------')

sum2_ = sum_ + loss_grad.sum()
print(sum2_)
sum2_.backward()

print(embedding_matrix.grad)
print(time.time() - start)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9078

Reviewed By: ezyang

Differential Revision: D14691901

Pulled By: soumith

fbshipit-source-id: 78e2612ba39080be564c876311671eb5a0119a0f

aten/src/ATen/native/native_functions.yaml
test/test_nn.py
tools/autograd/derivatives.yaml
tools/autograd/templates/Functions.cpp

index a25bb58..e1cd053 100644 (file)
 - func: embedding_backward(Tensor grad, Tensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor
   matches_jit_signature: True
 
-- func: embedding_dense_backward(Tensor grad, IndexTensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq) -> Tensor
+- func: embedding_dense_backward(Tensor grad_output, IndexTensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq) -> Tensor
   dispatch:
     CPU: embedding_dense_backward_cpu
     CUDA: embedding_dense_backward_cuda
index 25fb4fd..5ba8e62 100644 (file)
@@ -2049,6 +2049,27 @@ class TestNN(NNTestCase):
         self.assertTrue(embedding.weight.grad.is_sparse)
         self.assertEqual(embedding.weight.grad.shape, embedding.weight.shape)
 
+    def _test_embedding_dense_grad(self, dev):
+        embd = nn.Embedding(20, 20).to(dev)
+        weight = embd.weight
+
+        def fn_wrapper(dev):
+            def fn(weight):
+                inp = torch.tensor([[0, 1, 1, 2], [3, 5, 7, 11]], dtype=torch.long).to(dev)
+                return torch.nn.functional.embedding(inp, weight)
+            return fn
+
+        fn = fn_wrapper(dev)
+        _assertGradAndGradgradChecks(self, fn, (weight, ))
+
+    def test_embedding_dense_grad(self):
+        self._test_embedding_dense_grad("cpu")
+
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    @skipIfRocm
+    def test_embedding_dense_grad_cuda(self):
+        self._test_embedding_dense_grad("cuda")
+
     def test_embedding_sparse_backward(self):
         embedding = nn.Embedding(10, 3, sparse=True)
         embedding.zero_grad()
@@ -2111,6 +2132,15 @@ class TestNN(NNTestCase):
                 embedding.zero_grad()
                 self.assertEqual(after, pre)
 
+                # test double backward
+                emb_sum = embedding(indices).sum()
+                emb_grad = torch.autograd.grad(outputs=emb_sum, inputs=list(embedding.parameters()), retain_graph=True)
+                scalar = emb_grad[0].sum() + emb_sum
+                scalar.backward()
+                after = (embedding.weight + embedding.weight.grad)[padding_idx]
+                embedding.zero_grad()
+                self.assertEqual(after, pre)
+
     def test_embedding_max_norm(self):
         embedding = nn.Embedding(22, 5, max_norm=1.0)
         input = Variable(torch.LongTensor([2, 8, 8, 6]))
index bea70b5..8e9fd2d 100644 (file)
   indices: not_differentiable
   weight: embedding_backward(grad, indices, weight.size(0), padding_idx, scale_grad_by_freq, sparse)
 
+- name: embedding_dense_backward(Tensor grad_output, Tensor indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq)
+  grad_output: embedding_dense_double_backward(grad, indices)
+
 - name: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int64_t mode, bool sparse)
   indices: not_differentiable
   offsets: not_differentiable
index c927624..4736ca2 100644 (file)
@@ -2083,6 +2083,19 @@ Tensor constant_pad_nd_backward(const Tensor& grad, IntArrayRef pad) {
   return at::constant_pad_nd(grad, negated_pad, 0);
 }
 
+Tensor embedding_dense_double_backward(const Tensor & grad, const Tensor & indices) {
+  // since first backward takes care of padding_idx
+  // and scaling by frequency, we don't need to worry
+  // about it here.
+  auto gg_weight = grad.index_select(0, indices.reshape(-1));
+
+  // reshape gradient as per the shape of indices
+  auto size = indices.sizes().vec();
+  size.push_back(-1);
+
+  return gg_weight.view(size);
+}
+
 } // anonymous namespace
 
 ${autograd_function_definitions}