Summary:
Fixes: #6469
1. `ATen/native/native_functions.yaml` had [dispatch](https://github.com/pytorch/pytorch/blob/03e7953a98875c0164cb8e2c19b45800e85f4347/aten/src/ATen/native/native_functions.yaml#L451-L455) variants for `embedding_dense_backward`, but `embedding_backward` explicitly [called](https://github.com/pytorch/pytorch/blob/03e7953a98875c0164cb8e2c19b45800e85f4347/aten/src/ATen/native/Embedding.cpp#L35-L45) it directly, which led to an error.
2. For CUDA tensors, the function used to crash when dereferencing the indices' data [pointer](https://github.com/pytorch/pytorch/blob/03e7953a98875c0164cb8e2c19b45800e85f4347/aten/src/ATen/native/Embedding.cpp#L93), since host code cannot dereference a device pointer.
Both issues have been fixed and verified on CPU and CUDA against the two scripts below (a gradcheck-based sketch follows them as well):
1. As mentioned in the issue
```
import torch
class Test(torch.nn.Module):
    def __init__(self):
        super(Test, self).__init__()
        self.embd = torch.nn.Embedding(1000, 100)
        self.dense = torch.nn.Linear(100, 1)

    def forward(self, inp):
        inp = self.embd(inp)
        return self.dense(inp)

test = Test()
inp = torch.tensor([0, 1, 2, 1, 1])
out = test(inp)
raw_loss = out.mean(dim=0)
loss_grad = torch.autograd.grad(outputs=raw_loss,
                                inputs=list(test.parameters()),
                                retain_graph=True, create_graph=True, only_inputs=True)
norm = sum([param.norm()**2 for param in loss_grad])
loss = raw_loss + norm
loss.backward(retain_graph=True)
print(test.embd.weight.grad)
```
2. Test Script
```
import torch
import time
start = time.time()
l = [1, 1] * 100
input = torch.tensor([[1, 0], [1, 0]], device='cpu')
embedding_matrix = torch.tensor([[1.0, 3.0], [2.0, 4.0]], requires_grad=True, device='cpu')
sq = embedding_matrix * embedding_matrix
emb = torch.nn.functional.embedding(input, sq, scale_grad_by_freq=False)
print('Embedding Matrix')
print(embedding_matrix)
print('-----------------')
sum_ = emb.sum()
loss_grad, = torch.autograd.grad(outputs=sum_, inputs=embedding_matrix, create_graph=True)
print('Gradient')
print(loss_grad)
print('-----------------')
sum2_ = sum_ + loss_grad.sum()
print(sum2_)
sum2_.backward()
print(embedding_matrix.grad)
print(time.time() - start)
```
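In addition to the scripts above, the same behavior is exercised by the gradcheck/gradgradcheck test added in this PR; a minimal standalone sketch of that check (double-precision weights and arbitrary small sizes, chosen here for illustration) looks like:
```
import torch

weight = torch.randn(20, 20, dtype=torch.double, requires_grad=True)

def fn(weight):
    inp = torch.tensor([[0, 1, 1, 2], [3, 5, 7, 11]], dtype=torch.long)
    return torch.nn.functional.embedding(inp, weight)

# numerically checks first- and second-order gradients of the dense embedding path
torch.autograd.gradcheck(fn, (weight,))
torch.autograd.gradgradcheck(fn, (weight,))
```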
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9078
Reviewed By: ezyang
Differential Revision: D14691901
Pulled By: soumith
fbshipit-source-id: 78e2612ba39080be564c876311671eb5a0119a0f
- func: embedding_backward(Tensor grad, Tensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor
  matches_jit_signature: True
-- func: embedding_dense_backward(Tensor grad, IndexTensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq) -> Tensor
+- func: embedding_dense_backward(Tensor grad_output, IndexTensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq) -> Tensor
  dispatch:
    CPU: embedding_dense_backward_cpu
    CUDA: embedding_dense_backward_cuda
        self.assertTrue(embedding.weight.grad.is_sparse)
        self.assertEqual(embedding.weight.grad.shape, embedding.weight.shape)
+    def _test_embedding_dense_grad(self, dev):
+        embd = nn.Embedding(20, 20).to(dev)
+        weight = embd.weight
+
+        def fn_wrapper(dev):
+            def fn(weight):
+                inp = torch.tensor([[0, 1, 1, 2], [3, 5, 7, 11]], dtype=torch.long).to(dev)
+                return torch.nn.functional.embedding(inp, weight)
+            return fn
+
+        fn = fn_wrapper(dev)
+        _assertGradAndGradgradChecks(self, fn, (weight, ))
+
+    def test_embedding_dense_grad(self):
+        self._test_embedding_dense_grad("cpu")
+
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    @skipIfRocm
+    def test_embedding_dense_grad_cuda(self):
+        self._test_embedding_dense_grad("cuda")
+
    def test_embedding_sparse_backward(self):
        embedding = nn.Embedding(10, 3, sparse=True)
        embedding.zero_grad()
        embedding.zero_grad()
        self.assertEqual(after, pre)
+        # test double backward
+        emb_sum = embedding(indices).sum()
+        emb_grad = torch.autograd.grad(outputs=emb_sum, inputs=list(embedding.parameters()), retain_graph=True)
+        scalar = emb_grad[0].sum() + emb_sum
+        scalar.backward()
+        after = (embedding.weight + embedding.weight.grad)[padding_idx]
+        embedding.zero_grad()
+        self.assertEqual(after, pre)
+
    def test_embedding_max_norm(self):
        embedding = nn.Embedding(22, 5, max_norm=1.0)
        input = Variable(torch.LongTensor([2, 8, 8, 6]))
  indices: not_differentiable
  weight: embedding_backward(grad, indices, weight.size(0), padding_idx, scale_grad_by_freq, sparse)
+- name: embedding_dense_backward(Tensor grad_output, Tensor indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq)
+  grad_output: embedding_dense_double_backward(grad, indices)
+
- name: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int64_t mode, bool sparse)
  indices: not_differentiable
  offsets: not_differentiable
  return at::constant_pad_nd(grad, negated_pad, 0);
}
+Tensor embedding_dense_double_backward(const Tensor & grad, const Tensor & indices) {
+  // since first backward takes care of padding_idx
+  // and scaling by frequency, we don't need to worry
+  // about it here.
+  auto gg_weight = grad.index_select(0, indices.reshape(-1));
+
+  // reshape gradient as per the shape of indices
+  auto size = indices.sizes().vec();
+  size.push_back(-1);
+
+  return gg_weight.view(size);
+}
+
} // anonymous namespace
${autograd_function_definitions}
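For intuition, `embedding_dense_double_backward` above just gathers one row of the incoming gradient per index and reshapes the result to the indices' shape; a rough Python equivalent (illustrative only, tensor names and sizes are hypothetical) is:
```
import torch

def embedding_dense_double_backward(grad, indices):
    # padding_idx and frequency scaling are already handled by the first backward,
    # so the double backward only needs to gather and reshape
    gg_weight = grad.index_select(0, indices.reshape(-1))
    return gg_weight.view(list(indices.shape) + [-1])

grad = torch.randn(10, 3)                 # [num_weights, embedding_dim]
indices = torch.tensor([[0, 1], [2, 2]])  # indices used in the forward pass
print(embedding_dense_double_backward(grad, indices).shape)  # torch.Size([2, 2, 3])
```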