From cdb8edce7572342cf6f5d98807678e9ed3a1a32e Mon Sep 17 00:00:00 2001 From: David Pollack Date: Wed, 26 Dec 2018 08:31:00 -0800 Subject: [PATCH] add from_pretrained method to EmbeddingBag (#15273) Summary: The `EmbeddingBag` module does not include a `from_pretrained` method like the `Embedding` module. I added it for consistency between the two modules. Pull Request resolved: https://github.com/pytorch/pytorch/pull/15273 Differential Revision: D13547842 Pulled By: soumith fbshipit-source-id: 8ffde51ff0c1e8fc8310263b6f375da88089ff7d --- test/test_nn.py | 44 ++++++++++++++++++++++++++- torch/nn/modules/sparse.py | 74 +++++++++++++++++++++++++++++++++++++++------- 2 files changed, 107 insertions(+), 11 deletions(-) diff --git a/test/test_nn.py b/test/test_nn.py index bf37b98..b08930c 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -2088,10 +2088,26 @@ class TestNN(NNTestCase): embedding = nn.Embedding.from_pretrained(a) self.assertEqual(a, embedding.weight.data) - input = Variable(torch.LongTensor([0, 1])) + input = torch.LongTensor([0, 1]) output = embedding(input) self.assertEqual(a, output) + def test_embedding_from_pretrained_options(self): + a = torch.Tensor([[1, 2, 3], [4, 5, 6]]) + opts = { + "max_norm": 2., + "norm_type": .5, + "scale_grad_by_freq": False, + "sparse": True + } + embedding = nn.Embedding.from_pretrained(a, **opts) + input = torch.LongTensor([0, 1]) + output = embedding(input) + # test output and that weight matrix was renormalized + self.assertEqual(a, output) + self.assertTrue(a.ne(torch.arange(1, 7, dtype=a.dtype).view(2, 3)).all()) + self.assertTrue(output.data.norm(p=opts["norm_type"], dim=1).le(opts["max_norm"]).all()) + def test_embedding_functional(self): a = torch.tensor([ [1, 3, 2], @@ -2315,6 +2331,32 @@ class TestNN(NNTestCase): offset[-1] = 100 self.assertRaises(ValueError, lambda: es(input.view(-1), offset)) + def test_embeddingbag_from_pretrained(self): + a = torch.Tensor([[1, 2, 3], [4, 5, 6]]) + embeddingbag = nn.EmbeddingBag.from_pretrained(a) + self.assertEqual(a, embeddingbag.weight.data) + + input = torch.LongTensor([[0, 1]]) + output = embeddingbag(input) + self.assertEqual(a.mean(0, keepdim=True), output) + + def test_embeddingbag_from_pretrained_options(self): + a = torch.Tensor([[1, 2, 3], [4, 5, 6]]) + opts = { + "max_norm": 2., + "norm_type": .5, + "scale_grad_by_freq": False, + "mode": "max", + "sparse": False + } + embeddingbag = nn.EmbeddingBag.from_pretrained(a, **opts) + + input = torch.LongTensor([[0, 1]]) + output = embeddingbag(input) + self.assertEqual(a.max(0, keepdim=True)[0], output) + self.assertTrue(a.ne(torch.arange(1, 7, dtype=a.dtype).view(2, 3)).all()) + self.assertTrue(a.norm(p=opts["norm_type"], dim=1).le(opts["max_norm"]).all()) + @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") def test_pool3d_size_one_feature_dim(self): # Tests crazy strides for feature dim of size 1 diff --git a/torch/nn/modules/sparse.py b/torch/nn/modules/sparse.py index cd3ea4e..38670c6 100644 --- a/torch/nn/modules/sparse.py +++ b/torch/nn/modules/sparse.py @@ -132,7 +132,9 @@ class Embedding(Module): return s.format(**self.__dict__) @classmethod - def from_pretrained(cls, embeddings, freeze=True, sparse=False): + def from_pretrained(cls, embeddings, freeze=True, padding_idx=None, + max_norm=None, norm_type=2., scale_grad_by_freq=False, + sparse=False): r"""Creates Embedding instance from given 2-dimensional FloatTensor. Args: @@ -140,8 +142,11 @@ class Embedding(Module): First dimension is being passed to Embedding as 'num_embeddings', second as 'embedding_dim'. freeze (boolean, optional): If ``True``, the tensor does not get updated in the learning process. Equivalent to ``embedding.weight.requires_grad = False``. Default: ``True`` - sparse (bool, optional): if ``True``, gradient w.r.t. weight matrix will be a sparse tensor. - See Notes for more details regarding sparse gradients. + padding_idx (int, optional): See module initialization documentation. + max_norm (float, optional): See module initialization documentation. + norm_type (float, optional): See module initialization documentation. Default ``2``. + scale_grad_by_freq (boolean, optional): See module initialization documentation. Default ``False``. + sparse (bool, optional): See module initialization documentation. Examples:: @@ -160,8 +165,11 @@ class Embedding(Module): num_embeddings=rows, embedding_dim=cols, _weight=embeddings, - sparse=sparse, - ) + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse) embedding.weight.requires_grad = not freeze return embedding @@ -230,23 +238,27 @@ class EmbeddingBag(Module): [ 1.1306, -2.5798, -1.0044]]) """ __constants__ = ['num_embeddings, embedding_dim', 'max_norm', 'norm_type', - 'scale_grad_by_freq', 'mode', 'sparse'] + 'scale_grad_by_freq', 'mode', 'sparse', '_weight'] def __init__(self, num_embeddings, embedding_dim, max_norm=None, norm_type=2., scale_grad_by_freq=False, - mode='mean', sparse=False): + mode='mean', sparse=False, _weight=None): super(EmbeddingBag, self).__init__() self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim self.max_norm = max_norm self.norm_type = norm_type self.scale_grad_by_freq = scale_grad_by_freq - self.weight = Parameter(torch.Tensor(num_embeddings, embedding_dim)) + if _weight is None: + self.weight = Parameter(torch.Tensor(num_embeddings, embedding_dim)) + self.reset_parameters() + else: + assert list(_weight.shape) == [num_embeddings, embedding_dim], \ + 'Shape of weight does not match num_embeddings and embedding_dim' + self.weight = Parameter(_weight) self.mode = mode self.sparse = sparse - self.reset_parameters() - def reset_parameters(self): init.normal_(self.weight) @@ -268,4 +280,46 @@ class EmbeddingBag(Module): s += ', mode={mode}' return s.format(**self.__dict__) + @classmethod + def from_pretrained(cls, embeddings, freeze=True, max_norm=None, + norm_type=2., scale_grad_by_freq=False, + mode='mean', sparse=False): + r"""Creates EmbeddingBag instance from given 2-dimensional FloatTensor. + + Args: + embeddings (Tensor): FloatTensor containing weights for the EmbeddingBag. + First dimension is being passed to EmbeddingBag as 'num_embeddings', second as 'embedding_dim'. + freeze (boolean, optional): If ``True``, the tensor does not get updated in the learning process. + Equivalent to ``embeddingbag.weight.requires_grad = False``. Default: ``True`` + max_norm (float, optional): See module initialization documentation. Default: ``None`` + norm_type (float, optional): See module initialization documentation. Default ``2``. + scale_grad_by_freq (boolean, optional): See module initialization documentation. Default ``False``. + mode (string, optional): See module initialization documentation. Default: ``"mean"`` + sparse (bool, optional): See module initialization documentation. Default: ``False``. + + Examples:: + + >>> # FloatTensor containing pretrained weights + >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]]) + >>> embeddingbag = nn.EmbeddingBag.from_pretrained(weight) + >>> # Get embeddings for index 1 + >>> input = torch.LongTensor([[1, 0]]) + >>> embeddingbag(input) + tensor([[ 2.5000, 3.7000, 4.6500]]) + """ + assert embeddings.dim() == 2, \ + 'Embeddings parameter is expected to be 2-dimensional' + rows, cols = embeddings.shape + embeddingbag = cls( + num_embeddings=rows, + embedding_dim=cols, + _weight=embeddings, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + mode=mode, + sparse=sparse) + embeddingbag.weight.requires_grad = not freeze + return embeddingbag + # TODO: SparseLinear -- 2.7.4