add from_pretrained method to EmbeddingBag (#15273)
author    David Pollack <david@da3.net>
Wed, 26 Dec 2018 16:31:00 +0000 (08:31 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 26 Dec 2018 16:35:39 +0000 (08:35 -0800)
Summary:
The `EmbeddingBag` module does not include a `from_pretrained` method like the `Embedding` module does. I added one for consistency between the two modules.
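
A minimal usage sketch of the new classmethod (it mirrors the existing `Embedding.from_pretrained`; the tensor values here are illustrative):

    import torch
    import torch.nn as nn

    # Two pretrained embeddings of dimension three.
    weight = torch.FloatTensor([[1., 2., 3.], [4., 5., 6.]])

    # Build the module directly from the weights. freeze=True (the default)
    # sets weight.requires_grad = False, so the weights stay fixed in training.
    bag = nn.EmbeddingBag.from_pretrained(weight)

    # One bag containing indices 0 and 1; the default 'mean' mode averages
    # the two rows, giving tensor([[2.5000, 3.5000, 4.5000]]).
    output = bag(torch.LongTensor([[0, 1]]))
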
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15273

Differential Revision: D13547842

Pulled By: soumith

fbshipit-source-id: 8ffde51ff0c1e8fc8310263b6f375da88089ff7d
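
A short sketch of the `freeze` semantics (the classmethod simply sets `weight.requires_grad = not freeze`; values illustrative):

    import torch
    import torch.nn as nn

    weight = torch.Tensor([[1., 2., 3.], [4., 5., 6.]])

    # Default freeze=True: the weight is excluded from gradient updates.
    frozen = nn.EmbeddingBag.from_pretrained(weight.clone())
    assert not frozen.weight.requires_grad

    # freeze=False: the pretrained weights become a starting point for fine-tuning.
    tunable = nn.EmbeddingBag.from_pretrained(weight.clone(), freeze=False)
    assert tunable.weight.requires_grad
    optimizer = torch.optim.SGD(tunable.parameters(), lr=0.1)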

test/test_nn.py
torch/nn/modules/sparse.py

diff --git a/test/test_nn.py b/test/test_nn.py
index bf37b98..b08930c 100644
@@ -2088,10 +2088,26 @@ class TestNN(NNTestCase):
         embedding = nn.Embedding.from_pretrained(a)
         self.assertEqual(a, embedding.weight.data)
 
-        input = Variable(torch.LongTensor([0, 1]))
+        input = torch.LongTensor([0, 1])
         output = embedding(input)
         self.assertEqual(a, output)
 
+    def test_embedding_from_pretrained_options(self):
+        a = torch.Tensor([[1, 2, 3], [4, 5, 6]])
+        opts = {
+            "max_norm": 2.,
+            "norm_type": .5,
+            "scale_grad_by_freq": False,
+            "sparse": True
+        }
+        embedding = nn.Embedding.from_pretrained(a, **opts)
+        input = torch.LongTensor([0, 1])
+        output = embedding(input)
+        # test output and that weight matrix was renormalized
+        self.assertEqual(a, output)
+        self.assertTrue(a.ne(torch.arange(1, 7, dtype=a.dtype).view(2, 3)).all())
+        self.assertTrue(output.data.norm(p=opts["norm_type"], dim=1).le(opts["max_norm"]).all())
+
     def test_embedding_functional(self):
         a = torch.tensor([
             [1, 3, 2],
@@ -2315,6 +2331,32 @@ class TestNN(NNTestCase):
         offset[-1] = 100
         self.assertRaises(ValueError, lambda: es(input.view(-1), offset))
 
+    def test_embeddingbag_from_pretrained(self):
+        a = torch.Tensor([[1, 2, 3], [4, 5, 6]])
+        embeddingbag = nn.EmbeddingBag.from_pretrained(a)
+        self.assertEqual(a, embeddingbag.weight.data)
+
+        input = torch.LongTensor([[0, 1]])
+        output = embeddingbag(input)
+        self.assertEqual(a.mean(0, keepdim=True), output)
+
+    def test_embeddingbag_from_pretrained_options(self):
+        a = torch.Tensor([[1, 2, 3], [4, 5, 6]])
+        opts = {
+            "max_norm": 2.,
+            "norm_type": .5,
+            "scale_grad_by_freq": False,
+            "mode": "max",
+            "sparse": False
+        }
+        embeddingbag = nn.EmbeddingBag.from_pretrained(a, **opts)
+
+        input = torch.LongTensor([[0, 1]])
+        output = embeddingbag(input)
+        self.assertEqual(a.max(0, keepdim=True)[0], output)
+        self.assertTrue(a.ne(torch.arange(1, 7, dtype=a.dtype).view(2, 3)).all())
+        self.assertTrue(a.norm(p=opts["norm_type"], dim=1).le(opts["max_norm"]).all())
+
     @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
     def test_pool3d_size_one_feature_dim(self):
         # Tests crazy strides for feature dim of size 1
diff --git a/torch/nn/modules/sparse.py b/torch/nn/modules/sparse.py
index cd3ea4e..38670c6 100644
@@ -132,7 +132,9 @@ class Embedding(Module):
         return s.format(**self.__dict__)
 
     @classmethod
-    def from_pretrained(cls, embeddings, freeze=True, sparse=False):
+    def from_pretrained(cls, embeddings, freeze=True, padding_idx=None,
+                        max_norm=None, norm_type=2., scale_grad_by_freq=False,
+                        sparse=False):
         r"""Creates Embedding instance from given 2-dimensional FloatTensor.
 
         Args:
@@ -140,8 +142,11 @@ class Embedding(Module):
                 First dimension is being passed to Embedding as 'num_embeddings', second as 'embedding_dim'.
             freeze (boolean, optional): If ``True``, the tensor does not get updated in the learning process.
                 Equivalent to ``embedding.weight.requires_grad = False``. Default: ``True``
-            sparse (bool, optional): if ``True``, gradient w.r.t. weight matrix will be a sparse tensor.
-                See Notes for more details regarding sparse gradients.
+            padding_idx (int, optional): See module initialization documentation.
+            max_norm (float, optional): See module initialization documentation.
+            norm_type (float, optional): See module initialization documentation. Default: ``2``.
+            scale_grad_by_freq (boolean, optional): See module initialization documentation. Default: ``False``.
+            sparse (bool, optional): See module initialization documentation.
 
         Examples::
 
@@ -160,8 +165,11 @@ class Embedding(Module):
             num_embeddings=rows,
             embedding_dim=cols,
             _weight=embeddings,
-            sparse=sparse,
-        )
+            padding_idx=padding_idx,
+            max_norm=max_norm,
+            norm_type=norm_type,
+            scale_grad_by_freq=scale_grad_by_freq,
+            sparse=sparse)
         embedding.weight.requires_grad = not freeze
         return embedding
 
@@ -230,23 +238,27 @@ class EmbeddingBag(Module):
                 [ 1.1306, -2.5798, -1.0044]])
     """
     __constants__ = ['num_embeddings', 'embedding_dim', 'max_norm', 'norm_type',
-                     'scale_grad_by_freq', 'mode', 'sparse']
+                     'scale_grad_by_freq', 'mode', 'sparse', '_weight']
 
     def __init__(self, num_embeddings, embedding_dim,
                  max_norm=None, norm_type=2., scale_grad_by_freq=False,
-                 mode='mean', sparse=False):
+                 mode='mean', sparse=False, _weight=None):
         super(EmbeddingBag, self).__init__()
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
         self.max_norm = max_norm
         self.norm_type = norm_type
         self.scale_grad_by_freq = scale_grad_by_freq
-        self.weight = Parameter(torch.Tensor(num_embeddings, embedding_dim))
+        if _weight is None:
+            self.weight = Parameter(torch.Tensor(num_embeddings, embedding_dim))
+            self.reset_parameters()
+        else:
+            assert list(_weight.shape) == [num_embeddings, embedding_dim], \
+                'Shape of weight does not match num_embeddings and embedding_dim'
+            self.weight = Parameter(_weight)
         self.mode = mode
         self.sparse = sparse
 
-        self.reset_parameters()
-
     def reset_parameters(self):
         init.normal_(self.weight)
 
@@ -268,4 +280,46 @@ class EmbeddingBag(Module):
         s += ', mode={mode}'
         return s.format(**self.__dict__)
 
+    @classmethod
+    def from_pretrained(cls, embeddings, freeze=True, max_norm=None,
+                        norm_type=2., scale_grad_by_freq=False,
+                        mode='mean', sparse=False):
+        r"""Creates EmbeddingBag instance from given 2-dimensional FloatTensor.
+
+        Args:
+            embeddings (Tensor): FloatTensor containing weights for the EmbeddingBag.
+                The first dimension is passed to EmbeddingBag as 'num_embeddings', the second as 'embedding_dim'.
+            freeze (boolean, optional): If ``True``, the tensor does not get updated in the learning process.
+                Equivalent to ``embeddingbag.weight.requires_grad = False``. Default: ``True``
+            max_norm (float, optional): See module initialization documentation. Default: ``None``.
+            norm_type (float, optional): See module initialization documentation. Default: ``2``.
+            scale_grad_by_freq (boolean, optional): See module initialization documentation. Default: ``False``.
+            mode (string, optional): See module initialization documentation. Default: ``"mean"``.
+            sparse (bool, optional): See module initialization documentation. Default: ``False``.
+
+        Examples::
+
+            >>> # FloatTensor containing pretrained weights
+            >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
+            >>> embeddingbag = nn.EmbeddingBag.from_pretrained(weight)
+            >>> # Get the mean of the embeddings at indices 1 and 0
+            >>> input = torch.LongTensor([[1, 0]])
+            >>> embeddingbag(input)
+            tensor([[ 2.5000,  3.7000,  4.6500]])
+        """
+        assert embeddings.dim() == 2, \
+            'Embeddings parameter is expected to be 2-dimensional'
+        rows, cols = embeddings.shape
+        embeddingbag = cls(
+            num_embeddings=rows,
+            embedding_dim=cols,
+            _weight=embeddings,
+            max_norm=max_norm,
+            norm_type=norm_type,
+            scale_grad_by_freq=scale_grad_by_freq,
+            mode=mode,
+            sparse=sparse)
+        embeddingbag.weight.requires_grad = not freeze
+        return embeddingbag
+
 # TODO: SparseLinear
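
A note on the renorm assertions in the new tests: `from_pretrained` hands the given tensor to the constructor as `_weight`, and `Parameter(_weight)` shares its storage, so the in-place renormalization that `max_norm` triggers during `forward` is observable on the original tensor. A minimal sketch (values illustrative, matching the tests):

    import torch
    import torch.nn as nn

    a = torch.Tensor([[1., 2., 3.], [4., 5., 6.]])
    bag = nn.EmbeddingBag.from_pretrained(a, max_norm=2., norm_type=0.5)

    # The Parameter wraps the same storage as `a`.
    assert bag.weight.data_ptr() == a.data_ptr()

    # forward() renormalizes, in place, every referenced row whose 0.5-norm
    # exceeds max_norm, so `a` no longer holds its original values afterwards.
    before = a.clone()
    output = bag(torch.LongTensor([[0, 1]]))
    assert a.ne(before).all()
    assert a.norm(p=0.5, dim=1).le(2.).all()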