embedding = nn.Embedding.from_pretrained(a)
self.assertEqual(a, embedding.weight.data)
- input = Variable(torch.LongTensor([0, 1]))
+ input = torch.LongTensor([0, 1])
output = embedding(input)
self.assertEqual(a, output)
+ def test_embedding_from_pretrained_options(self):
+ a = torch.Tensor([[1, 2, 3], [4, 5, 6]])
+ opts = {
+ "max_norm": 2.,
+ "norm_type": .5,
+ "scale_grad_by_freq": False,
+ "sparse": True
+ }
+ embedding = nn.Embedding.from_pretrained(a, **opts)
+ input = torch.LongTensor([0, 1])
+ output = embedding(input)
+ # test output and that weight matrix was renormalized
+ self.assertEqual(a, output)
+ self.assertTrue(a.ne(torch.arange(1, 7, dtype=a.dtype).view(2, 3)).all())
+ self.assertTrue(output.data.norm(p=opts["norm_type"], dim=1).le(opts["max_norm"]).all())
+
def test_embedding_functional(self):
a = torch.tensor([
[1, 3, 2],
offset[-1] = 100
self.assertRaises(ValueError, lambda: es(input.view(-1), offset))
+ def test_embeddingbag_from_pretrained(self):
+ a = torch.Tensor([[1, 2, 3], [4, 5, 6]])
+ embeddingbag = nn.EmbeddingBag.from_pretrained(a)
+ self.assertEqual(a, embeddingbag.weight.data)
+
+ input = torch.LongTensor([[0, 1]])
+ output = embeddingbag(input)
+ self.assertEqual(a.mean(0, keepdim=True), output)
+
+ def test_embeddingbag_from_pretrained_options(self):
+ a = torch.Tensor([[1, 2, 3], [4, 5, 6]])
+ opts = {
+ "max_norm": 2.,
+ "norm_type": .5,
+ "scale_grad_by_freq": False,
+ "mode": "max",
+ "sparse": False
+ }
+ embeddingbag = nn.EmbeddingBag.from_pretrained(a, **opts)
+
+ input = torch.LongTensor([[0, 1]])
+ output = embeddingbag(input)
+ self.assertEqual(a.max(0, keepdim=True)[0], output)
+ self.assertTrue(a.ne(torch.arange(1, 7, dtype=a.dtype).view(2, 3)).all())
+ self.assertTrue(a.norm(p=opts["norm_type"], dim=1).le(opts["max_norm"]).all())
+
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
def test_pool3d_size_one_feature_dim(self):
# Tests crazy strides for feature dim of size 1
return s.format(**self.__dict__)
@classmethod
- def from_pretrained(cls, embeddings, freeze=True, sparse=False):
+ def from_pretrained(cls, embeddings, freeze=True, padding_idx=None,
+ max_norm=None, norm_type=2., scale_grad_by_freq=False,
+ sparse=False):
r"""Creates Embedding instance from given 2-dimensional FloatTensor.
Args:
First dimension is being passed to Embedding as 'num_embeddings', second as 'embedding_dim'.
freeze (boolean, optional): If ``True``, the tensor does not get updated in the learning process.
Equivalent to ``embedding.weight.requires_grad = False``. Default: ``True``
- sparse (bool, optional): if ``True``, gradient w.r.t. weight matrix will be a sparse tensor.
- See Notes for more details regarding sparse gradients.
+ padding_idx (int, optional): See module initialization documentation.
+ max_norm (float, optional): See module initialization documentation.
+ norm_type (float, optional): See module initialization documentation. Default ``2``.
+ scale_grad_by_freq (boolean, optional): See module initialization documentation. Default ``False``.
+ sparse (bool, optional): See module initialization documentation.
Examples::
num_embeddings=rows,
embedding_dim=cols,
_weight=embeddings,
- sparse=sparse,
- )
+ padding_idx=padding_idx,
+ max_norm=max_norm,
+ norm_type=norm_type,
+ scale_grad_by_freq=scale_grad_by_freq,
+ sparse=sparse)
embedding.weight.requires_grad = not freeze
return embedding
[ 1.1306, -2.5798, -1.0044]])
"""
__constants__ = ['num_embeddings, embedding_dim', 'max_norm', 'norm_type',
- 'scale_grad_by_freq', 'mode', 'sparse']
+ 'scale_grad_by_freq', 'mode', 'sparse', '_weight']
def __init__(self, num_embeddings, embedding_dim,
max_norm=None, norm_type=2., scale_grad_by_freq=False,
- mode='mean', sparse=False):
+ mode='mean', sparse=False, _weight=None):
super(EmbeddingBag, self).__init__()
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.max_norm = max_norm
self.norm_type = norm_type
self.scale_grad_by_freq = scale_grad_by_freq
- self.weight = Parameter(torch.Tensor(num_embeddings, embedding_dim))
+ if _weight is None:
+ self.weight = Parameter(torch.Tensor(num_embeddings, embedding_dim))
+ self.reset_parameters()
+ else:
+ assert list(_weight.shape) == [num_embeddings, embedding_dim], \
+ 'Shape of weight does not match num_embeddings and embedding_dim'
+ self.weight = Parameter(_weight)
self.mode = mode
self.sparse = sparse
- self.reset_parameters()
-
def reset_parameters(self):
init.normal_(self.weight)
s += ', mode={mode}'
return s.format(**self.__dict__)
+ @classmethod
+ def from_pretrained(cls, embeddings, freeze=True, max_norm=None,
+ norm_type=2., scale_grad_by_freq=False,
+ mode='mean', sparse=False):
+ r"""Creates EmbeddingBag instance from given 2-dimensional FloatTensor.
+
+ Args:
+ embeddings (Tensor): FloatTensor containing weights for the EmbeddingBag.
+ First dimension is being passed to EmbeddingBag as 'num_embeddings', second as 'embedding_dim'.
+ freeze (boolean, optional): If ``True``, the tensor does not get updated in the learning process.
+ Equivalent to ``embeddingbag.weight.requires_grad = False``. Default: ``True``
+ max_norm (float, optional): See module initialization documentation. Default: ``None``
+ norm_type (float, optional): See module initialization documentation. Default ``2``.
+ scale_grad_by_freq (boolean, optional): See module initialization documentation. Default ``False``.
+ mode (string, optional): See module initialization documentation. Default: ``"mean"``
+ sparse (bool, optional): See module initialization documentation. Default: ``False``.
+
+ Examples::
+
+ >>> # FloatTensor containing pretrained weights
+ >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
+ >>> embeddingbag = nn.EmbeddingBag.from_pretrained(weight)
+ >>> # Get embeddings for index 1
+ >>> input = torch.LongTensor([[1, 0]])
+ >>> embeddingbag(input)
+ tensor([[ 2.5000, 3.7000, 4.6500]])
+ """
+ assert embeddings.dim() == 2, \
+ 'Embeddings parameter is expected to be 2-dimensional'
+ rows, cols = embeddings.shape
+ embeddingbag = cls(
+ num_embeddings=rows,
+ embedding_dim=cols,
+ _weight=embeddings,
+ max_norm=max_norm,
+ norm_type=norm_type,
+ scale_grad_by_freq=scale_grad_by_freq,
+ mode=mode,
+ sparse=sparse)
+ embeddingbag.weight.requires_grad = not freeze
+ return embeddingbag
+
# TODO: SparseLinear