From 0916b5419a5703f7f81822b66f31642c0613ccad Mon Sep 17 00:00:00 2001
From: kshitij12345
Date: Wed, 3 Apr 2019 09:16:29 -0700
Subject: [PATCH] Fix dense Embedding to work with double backward (#9078)

Summary:
Fixes: #6469

1. `ATen/native/native_functions.yaml` had [dispatch](https://github.com/pytorch/pytorch/blob/03e7953a98875c0164cb8e2c19b45800e85f4347/aten/src/ATen/native/native_functions.yaml#L451-L455) variants for `embedding_dense_backward`, however `embedding_backward` explicitly made a [call](https://github.com/pytorch/pytorch/blob/03e7953a98875c0164cb8e2c19b45800e85f4347/aten/src/ATen/native/Embedding.cpp#L35-L45) to it, thus leading to an error.

2. For a CUDA tensor, the function used to crash when dereferencing the indices' data [pointer](https://github.com/pytorch/pytorch/blob/03e7953a98875c0164cb8e2c19b45800e85f4347/aten/src/ATen/native/Embedding.cpp#L93).

Both issues have been fixed and verified (on CUDA and CPU) with the scripts below.

1. As mentioned in the issue:

```
import torch

class Test(torch.nn.Module):
    def __init__(self):
        super(Test, self).__init__()
        self.embd = torch.nn.Embedding(1000, 100)
        self.dense = torch.nn.Linear(100, 1)

    def forward(self, inp):
        inp = self.embd(inp)
        return self.dense(inp)

test = Test()
inp = torch.tensor([0, 1, 2, 1, 1])
out = test(inp)
raw_loss = out.mean(dim=0)

loss_grad = torch.autograd.grad(outputs=raw_loss,
                                inputs=list(test.parameters()),
                                retain_graph=True, create_graph=True, only_inputs=True)
norm = sum([param.norm()**2 for param in loss_grad])
loss = raw_loss + norm

loss.backward(retain_graph=True)
print(test.embd.weight.grad)
```

2. Test script:

```
import torch
import time

start = time.time()
l = [1, 1] * 100
input = torch.tensor([[1, 0], [1, 0]], device='cpu')
embedding_matrix = torch.tensor([[1.0, 3.0], [2.0, 4.0]], requires_grad=True, device='cpu')

sq = embedding_matrix * embedding_matrix
emb = torch.nn.functional.embedding(input, sq, scale_grad_by_freq=False)

print('Embedding Matrix')
print(embedding_matrix)
print('-----------------')

sum_ = emb.sum()

loss_grad, = torch.autograd.grad(outputs=sum_, inputs=embedding_matrix, create_graph=True)

print('Gradient')
print(loss_grad)
print('-----------------')

sum2_ = sum_ + loss_grad.sum()
print(sum2_)

sum2_.backward()

print(embedding_matrix.grad)
print(time.time() - start)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/9078

Reviewed By: ezyang

Differential Revision: D14691901

Pulled By: soumith

fbshipit-source-id: 78e2612ba39080be564c876311671eb5a0119a0f
---
 aten/src/ATen/native/native_functions.yaml |  2 +-
 test/test_nn.py                            | 30 ++++++++++++++++++++++++++++++
 tools/autograd/derivatives.yaml            |  3 +++
 tools/autograd/templates/Functions.cpp     | 13 +++++++++++++
 4 files changed, 47 insertions(+), 1 deletion(-)

diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index a25bb58..e1cd053 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -800,7 +800,7 @@
 - func: embedding_backward(Tensor grad, Tensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor
   matches_jit_signature: True
 
-- func: embedding_dense_backward(Tensor grad, IndexTensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq) -> Tensor
+- func: embedding_dense_backward(Tensor grad_output, IndexTensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq) -> Tensor
   dispatch:
     CPU: embedding_dense_backward_cpu
     CUDA: embedding_dense_backward_cuda
diff --git a/test/test_nn.py b/test/test_nn.py
index 25fb4fd..5ba8e62 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -2049,6 +2049,27 @@ class TestNN(NNTestCase):
         self.assertTrue(embedding.weight.grad.is_sparse)
         self.assertEqual(embedding.weight.grad.shape, embedding.weight.shape)
 
+    def _test_embedding_dense_grad(self, dev):
+        embd = nn.Embedding(20, 20).to(dev)
+        weight = embd.weight
+
+        def fn_wrapper(dev):
+            def fn(weight):
+                inp = torch.tensor([[0, 1, 1, 2], [3, 5, 7, 11]], dtype=torch.long).to(dev)
+                return torch.nn.functional.embedding(inp, weight)
+            return fn
+
+        fn = fn_wrapper(dev)
+        _assertGradAndGradgradChecks(self, fn, (weight, ))
+
+    def test_embedding_dense_grad(self):
+        self._test_embedding_dense_grad("cpu")
+
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    @skipIfRocm
+    def test_embedding_dense_grad_cuda(self):
+        self._test_embedding_dense_grad("cuda")
+
     def test_embedding_sparse_backward(self):
         embedding = nn.Embedding(10, 3, sparse=True)
         embedding.zero_grad()
@@ -2111,6 +2132,15 @@ class TestNN(NNTestCase):
         embedding.zero_grad()
         self.assertEqual(after, pre)
 
+        # test double backward
+        emb_sum = embedding(indices).sum()
+        emb_grad = torch.autograd.grad(outputs=emb_sum, inputs=list(embedding.parameters()), retain_graph=True)
+        scalar = emb_grad[0].sum() + emb_sum
+        scalar.backward()
+        after = (embedding.weight + embedding.weight.grad)[padding_idx]
+        embedding.zero_grad()
+        self.assertEqual(after, pre)
+
     def test_embedding_max_norm(self):
         embedding = nn.Embedding(22, 5, max_norm=1.0)
         input = Variable(torch.LongTensor([2, 8, 8, 6]))
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index bea70b5..8e9fd2d 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -945,6 +945,9 @@
   indices: not_differentiable
   weight: embedding_backward(grad, indices, weight.size(0), padding_idx, scale_grad_by_freq, sparse)
 
+- name: embedding_dense_backward(Tensor grad_output, Tensor indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq)
+  grad_output: embedding_dense_double_backward(grad, indices)
+
 - name: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int64_t mode, bool sparse)
   indices: not_differentiable
   offsets: not_differentiable
diff --git a/tools/autograd/templates/Functions.cpp b/tools/autograd/templates/Functions.cpp
index c927624..4736ca2 100644
--- a/tools/autograd/templates/Functions.cpp
+++ b/tools/autograd/templates/Functions.cpp
@@ -2083,6 +2083,19 @@ Tensor constant_pad_nd_backward(const Tensor& grad, IntArrayRef pad) {
   return at::constant_pad_nd(grad, negated_pad, 0);
 }
 
+Tensor embedding_dense_double_backward(const Tensor & grad, const Tensor & indices) {
+  // since first backward takes care of padding_idx
+  // and scaling by frequency, we don't need to worry
+  // about it here.
+  auto gg_weight = grad.index_select(0, indices.reshape(-1));
+
+  // reshape gradient as per the shape of indices
+  auto size = indices.sizes().vec();
+  size.push_back(-1);
+
+  return gg_weight.view(size);
+}
+
 } // anonymous namespace

${autograd_function_definitions}
-- 
2.7.4
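
For context on the Functions.cpp hunk above: the dense embedding backward sums `grad_output[p]` into `grad_weight[indices[p]]`, so it is linear in `grad_output`, and its own gradient with respect to `grad_output` is just a row gather of the incoming gradient; that is exactly what the `index_select`/`view` pair in `embedding_dense_double_backward` computes (`padding_idx` and `scale_grad_by_freq` are already handled by the first backward, as the C++ comment notes). The sketch below is not part of the patch; the shapes and the cotangent are illustrative. It exercises the newly differentiable path and checks it against that closed form.

```
import torch
import torch.nn.functional as F

torch.manual_seed(0)
indices = torch.tensor([[0, 1, 1, 2], [3, 5, 7, 4]])
num_weights, dim = 10, 3

weight = torch.randn(num_weights, dim, requires_grad=True)
grad_output = torch.randn(*indices.shape, dim, requires_grad=True)

# First backward of the embedding, kept in the graph: this is the
# embedding_dense_backward call that the patch makes differentiable.
out = F.embedding(indices, weight)
grad_weight, = torch.autograd.grad(out, weight, grad_outputs=grad_output,
                                   create_graph=True)

# Second backward: differentiate grad_weight w.r.t. grad_output with an
# arbitrary cotangent...
cotangent = torch.randn(num_weights, dim)
grad_weight.backward(cotangent)

# ...and compare with the closed form used in embedding_dense_double_backward:
# gather the cotangent rows at the indices and reshape to grad_output's shape.
expected = cotangent.index_select(0, indices.reshape(-1)).view(*indices.shape, -1)
assert torch.allclose(grad_output.grad, expected)
```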