From 4b20fc826da7d15c67ab632d95b0a22ea55daa84 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Thu, 11 Apr 2019 08:04:32 -0700 Subject: [PATCH] Import MultiheadAttention to PyTorch (#18334) Summary: Import MultiheadAttention into the core pytorch framework. Users now can import MultiheadAttention directly from torch.nn. See "Attention Is All You Need" for more details related to MultiheadAttention function. Pull Request resolved: https://github.com/pytorch/pytorch/pull/18334 Differential Revision: D14577966 Pulled By: zhangguanheng66 fbshipit-source-id: 756c0deff623f3780651d9f9a70ce84516c806d3 --- test/test_nn.py | 187 +++++++++++++++++++++++++++++++++ torch/nn/modules/__init__.py | 4 +- torch/nn/modules/activation.py | 227 ++++++++++++++++++++++++++++++++++++++++- 3 files changed, 415 insertions(+), 3 deletions(-) diff --git a/test/test_nn.py b/test/test_nn.py index 2ac6188..df8e645 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -37,6 +37,8 @@ from common_nn import NNTestCase, ModuleTest, CriterionTest, TestBase, \ get_weight, smoothl1loss_reference, kldivloss_reference, \ ctcloss_reference, new_module_tests +from torch.nn import MultiheadAttention + # load_tests from common_utils is used to automatically filter tests for # sharding on sandcastle. This line silences flake warnings load_tests = load_tests @@ -3176,6 +3178,191 @@ class TestNN(NNTestCase): output = m(sigmoid(input), target) verify_reduction_scalars(input, reduction, output) + def test_multihead_attention(self): + def _scaled_dot_attn_ref(Q, K, V, dims, unseen_mask=False, src_lengths=None): + """ Numpy-based reference implementation of scaled dot attention + for testing""" + QKT = _batchmatmul( + Q, + np.transpose(K, axes=[0, 1, 3, 2]) + / np.sqrt(dims[3], dtype=np.float32), # divide by sqrt(d_head) + ) + if unseen_mask or src_lengths is not None: + b1, b2, s1, s2 = QKT.shape + # assert s1 == s2 + for i in range(b1): + for j in range(b2): + for m in range(s1): + for n in range(s2): + if unseen_mask and n > m: + QKT[i, j, m, n] = -np.inf + if src_lengths is not None and n >= src_lengths[i]: + QKT[i, j, m, n] = -np.inf + reference = _softmax(QKT) + reference = _batchmatmul(reference, V) + return reference + + def _batchmatmul(a, b): # batchmatmul over 4 dim matrix + """ Numpy-based batch matrix multiply over 4 dim matrix""" + assert a.shape[0] == b.shape[0] + assert a.shape[1] == b.shape[1] + retval = np.zeros( + (a.shape[0], a.shape[1], a.shape[2], b.shape[3]), dtype=np.float32 + ) + for i in range(a.shape[0]): + for j in range(a.shape[1]): + retval[i, j, :, :] = np.matmul(a[i, j, :, :], b[i, j, :, :]) + return retval + + def _softmax(x): # softmax over 4 dim matrix + """ Numpy-based reference softmax over 4 dim matrix""" + output = np.zeros(x.shape, dtype=np.float32) + for i in range(x.shape[0]): + for j in range(x.shape[1]): + for k in range(x.shape[2]): + x_curr = x[i, j, k, :] + e_x = np.exp(x_curr - np.amax(x_curr)) + output[i, j, k, :] = e_x / np.sum(e_x) + return output + + def _generate_src_lengths(batch_size, seq_len): + src_lengths = np.array([random.randint(1, seq_len) for i in range(batch_size)]) + + # max source length has to equal seq_len, so randomly choose + # one example to have source length = seq_len + max_len_example_i = random.randint(0, batch_size - 1) + src_lengths[max_len_example_i] = seq_len + + src_lengths_tensor = torch.from_numpy(src_lengths).int() + return src_lengths, src_lengths_tensor + + def _split_heads_ref(X, dims, nheads, d_head): + X_split = np.reshape(X, dims[:2] + [nheads, d_head]) + X_split_transposed = np.transpose(X_split, [0, 2, 1, 3]) + reference = np.reshape(X_split_transposed, [dims[0], nheads, dims[1], d_head]) + return reference + + def _combine_heads_ref(X, dims, nheads, d_head): + X_transposed = np.transpose(X, [0, 2, 1, 3]) + reference = np.reshape(X_transposed, dims[:2] + [nheads * d_head]) + return reference + + def _fc(X, X_name, module, start=None, end=None): + X_fc_b = None + X_fc_w = None + for name, param in module.named_parameters(): + if X_name + "weight" in name: + if X_fc_w is not None: + raise Exception("Duplicate FC name found") + X_fc_w = param[start:end, :].detach().numpy() + elif X_name + "bias" in name: + if X_fc_b is not None: + raise Exception("Duplicate FC name found") + X_fc_b = param[start:end].detach().numpy() + return np.matmul(X, np.transpose(X_fc_w)) + X_fc_b + + def _create_src_lengths_mask(batch_size, src_lengths): + """ + Generate boolean mask to prevent attention beyond the end of source + + Inputs: + batch_size : int + src_lengths : [batch_size] of sentence lengths + + Outputs: + [batch_size, max_src_len] + """ + max_srclen = src_lengths.max() + src_indices = torch.arange(0, max_srclen).unsqueeze(0).type_as(src_lengths) + src_indices = src_indices.expand(batch_size, max_srclen) + src_lengths = src_lengths.unsqueeze(dim=1).expand(batch_size, max_srclen) + # returns [batch_size, max_seq_len] + return (src_indices < src_lengths).int().detach() + + def _multihead_attn_test_helper(use_src_lengths): + for _ in range(100): + batch_sz, seq_len = [random.randint(2, 10) for r in range(2)] + d_head = random.randint(3, 10) + nheads = random.randint(3, 10) + d_model = d_head * nheads + dims = [batch_sz, seq_len, d_model] + + src_lengths = None + src_lengths_tensor = None + if use_src_lengths: + src_lengths, src_lengths_tensor = _generate_src_lengths( + batch_size=batch_sz, seq_len=seq_len + ) + + decoder_state = np.random.rand(batch_sz, d_model).astype(np.float64) + K = np.random.rand(*dims).astype(np.float64) + V = K + Q = np.expand_dims(decoder_state, 1) + + decoder_state_tensor = torch.from_numpy(decoder_state).double() + source_hid_tensor = torch.from_numpy(K).double().transpose(0, 1) + + multihead_attn_module = MultiheadAttention(d_model, nheads) + + _batch_size = decoder_state_tensor.shape[0] + _Q = decoder_state_tensor.unsqueeze(1).transpose(0, 1) + _V = source_hid_tensor + _K = source_hid_tensor + src_len_mask = None + if src_lengths is not None and use_src_lengths: + # [batch_size, 1, seq_len] + src_len_mask_int = _create_src_lengths_mask( + batch_size=_batch_size, src_lengths=src_lengths_tensor + ) + src_len_mask = src_len_mask_int != 1 + + result = multihead_attn_module( + _Q, _K, _V, + key_padding_mask=src_len_mask, + need_weights=True)[0].squeeze(0).detach().numpy() + + Q_fc = _fc(Q, "in_proj_", multihead_attn_module, end=d_model) + K_fc = _fc( + K, "in_proj_", multihead_attn_module, start=d_model, end=2 * d_model + ) + V_fc = _fc(V, "in_proj_", multihead_attn_module, start=2 * d_model) + + Q_split = _split_heads_ref( + Q_fc, [batch_sz, 1, d_model], nheads, d_head + ) + K_split = _split_heads_ref(K_fc, dims, nheads, d_head) + V_split = _split_heads_ref(V_fc, dims, nheads, d_head) + + attn_heads = _scaled_dot_attn_ref( + Q=Q_split, + K=K_split, + V=V_split, + dims=Q_split.shape, + src_lengths=src_lengths, + ) + + combined_attn_heads = _combine_heads_ref( + X=attn_heads, dims=[batch_sz, 1], nheads=nheads, d_head=d_head + ) + + reference = _fc( + combined_attn_heads, "out_proj.", multihead_attn_module + ) + reference = np.squeeze(reference, axis=1) + + # result = reference + self.assertEqual(tuple(result.shape), (batch_sz, d_model)) + np.testing.assert_allclose(result, reference, atol=1e-5) + + def test_multihead_attn_no_masking(): + _multihead_attn_test_helper(use_src_lengths=None) + + def test_multihead_attn_with_src_lengths(): + _multihead_attn_test_helper(use_src_lengths=True) + + test_multihead_attn_no_masking() # Test MultiheadAttention without masking + test_multihead_attn_with_src_lengths() # Test MultiheadAttention with src lengths + def test_normalize(self): inputs = torch.randn(1, 3, 4, 4, requires_grad=True) self.assertTrue(gradcheck(lambda x: F.normalize(x, p=1, dim=-1), (inputs,))) diff --git a/torch/nn/modules/__init__.py b/torch/nn/modules/__init__.py index 71c1267..55b294b 100644 --- a/torch/nn/modules/__init__.py +++ b/torch/nn/modules/__init__.py @@ -4,7 +4,7 @@ from .conv import Conv1d, Conv2d, Conv3d, \ ConvTranspose1d, ConvTranspose2d, ConvTranspose3d from .activation import Threshold, ReLU, Hardtanh, ReLU6, Sigmoid, Tanh, \ Softmax, Softmax2d, LogSoftmax, ELU, SELU, CELU, Hardshrink, LeakyReLU, LogSigmoid, \ - Softplus, Softshrink, PReLU, Softsign, Softmin, Tanhshrink, RReLU, GLU + Softplus, Softshrink, MultiheadAttention, PReLU, Softsign, Softmin, Tanhshrink, RReLU, GLU from .loss import L1Loss, NLLLoss, KLDivLoss, MSELoss, BCELoss, BCEWithLogitsLoss, NLLLoss2d, \ CosineEmbeddingLoss, CTCLoss, HingeEmbeddingLoss, MarginRankingLoss, \ MultiLabelMarginLoss, MultiLabelSoftMarginLoss, MultiMarginLoss, \ @@ -32,7 +32,7 @@ __all__ = [ 'Module', 'Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d', 'Threshold', 'ReLU', 'Hardtanh', 'ReLU6', 'Sigmoid', 'Tanh', 'Softmax', 'Softmax2d', 'LogSoftmax', 'ELU', 'SELU', 'CELU', 'GLU', 'Hardshrink', - 'LeakyReLU', 'LogSigmoid', 'Softplus', 'Softshrink', 'PReLU', 'Softsign', 'Softmin', + 'LeakyReLU', 'LogSigmoid', 'Softplus', 'Softshrink', 'MultiheadAttention', 'PReLU', 'Softsign', 'Softmin', 'Tanhshrink', 'RReLU', 'L1Loss', 'NLLLoss', 'KLDivLoss', 'MSELoss', 'BCELoss', 'BCEWithLogitsLoss', 'NLLLoss2d', 'PoissonNLLLoss', 'CosineEmbeddingLoss', 'CTCLoss', 'HingeEmbeddingLoss', 'MarginRankingLoss', 'MultiLabelMarginLoss', 'MultiLabelSoftMarginLoss', 'MultiMarginLoss', 'SmoothL1Loss', diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index 3851654..e6ce9b8 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -1,7 +1,10 @@ import warnings import torch +from . import Linear +from torch.nn.init import xavier_uniform_ +from torch.nn.init import constant_ +from torch.nn.init import xavier_normal_ from torch.nn.parameter import Parameter - from .module import Module from .. import functional as F from ..._jit_internal import weak_module, weak_script_method @@ -670,6 +673,228 @@ class Softshrink(Module): @weak_module +class MultiheadAttention(Module): + r"""Allows the model to jointly attend to information + from different representation subspaces. + See reference: Attention Is All You Need + + .. math:: + \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O + \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) + + Args: + embed_dim: total dimension of the model + num_heads: parallel attention layers, or heads + + Examples:: + + >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value) + """ + + def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False): + super(MultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim ** -0.5 + + self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim)) + if bias: + self.in_proj_bias = Parameter(torch.empty(3 * embed_dim)) + else: + self.register_parameter('in_proj_bias', None) + self.out_proj = Linear(embed_dim, embed_dim, bias=bias) + + if add_bias_kv: + self.bias_k = Parameter(torch.empty(1, 1, embed_dim)) + self.bias_v = Parameter(torch.empty(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self._reset_parameters() + + def _reset_parameters(self): + xavier_uniform_(self.in_proj_weight[:self.embed_dim, :]) + xavier_uniform_(self.in_proj_weight[self.embed_dim:(self.embed_dim * 2), :]) + xavier_uniform_(self.in_proj_weight[(self.embed_dim * 2):, :]) + + xavier_uniform_(self.out_proj.weight) + if self.in_proj_bias is not None: + constant_(self.in_proj_bias, 0.) + constant_(self.out_proj.bias, 0.) + if self.bias_k is not None: + xavier_normal_(self.bias_k) + if self.bias_v is not None: + xavier_normal_(self.bias_v) + + @weak_script_method + def forward(self, query, key, value, key_padding_mask=None, incremental_state=None, + need_weights=True, static_kv=False, attn_mask=None): + """ + Inputs of forward function + query: [target length, batch size, embed dim] + key: [sequence length, batch size, embed dim] + value: [sequence length, batch size, embed dim] + key_padding_mask: if True, mask padding based on batch size + incremental_state: if provided, previous time steps are cashed + need_weights: output attn_output_weights + static_kv: key and value are static + + Outputs of forward function + attn_output: [target length, batch size, embed dim] + attn_output_weights: [batch size, target length, sequence length] + """ + qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr() + kv_same = key.data_ptr() == value.data_ptr() + + tgt_len, bsz, embed_dim = query.size() + assert embed_dim == self.embed_dim + assert list(query.size()) == [tgt_len, bsz, embed_dim] + assert key.size() == value.size() + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if 'prev_key' in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert kv_same and not qkv_same + key = value = None + else: + saved_state = None + + if qkv_same: + # self-attention + q, k, v = self._in_proj_qkv(query) + elif kv_same: + # encoder-decoder attention + q = self._in_proj_q(query) + if key is None: + assert value is None + k = v = None + else: + k, v = self._in_proj_kv(key) + else: + q = self._in_proj_q(query) + k = self._in_proj_k(key) + v = self._in_proj_v(value) + q *= self.scaling + + if self.bias_k is not None: + assert self.bias_v is not None + k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1) + + q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) + if k is not None: + k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) + if v is not None: + v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) + + if saved_state is not None: + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) + if 'prev_key' in saved_state: + prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + k = prev_key + else: + k = torch.cat((prev_key, k), dim=1) + if 'prev_value' in saved_state: + prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + v = prev_value + else: + v = torch.cat((prev_value, v), dim=1) + saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim) + saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim) + + self._set_input_buffer(incremental_state, saved_state) + + src_len = k.size(1) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + src_len += 1 + k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) + v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) + if attn_mask is not None: + attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1) + + attn_output_weights = torch.bmm(q, k.transpose(1, 2)) + assert list(attn_output_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + attn_output_weights += attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float('-inf'), + ) + attn_output_weights = attn_output_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_output_weights = F.softmax( + attn_output_weights.float(), dim=-1, + dtype=torch.float32 if attn_output_weights.dtype == torch.float16 else attn_output_weights.dtype) + attn_output_weights = F.dropout(attn_output_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_output_weights, v) + assert list(attn_output.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output = self.out_proj(attn_output) + + if need_weights: + # average attention weights over heads + attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_output_weights = attn_output_weights.sum(dim=1) / self.num_heads + else: + attn_output_weights = None + + return attn_output, attn_output_weights + + def _in_proj_qkv(self, query): + return self._in_proj(query).chunk(3, dim=-1) + + def _in_proj_kv(self, key): + return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1) + + def _in_proj_q(self, query): + return self._in_proj(query, end=self.embed_dim) + + def _in_proj_k(self, key): + return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim) + + def _in_proj_v(self, value): + return self._in_proj(value, start=2 * self.embed_dim) + + def _in_proj(self, input, start=0, end=None): + weight = self.in_proj_weight + bias = self.in_proj_bias + weight = weight[start:end, :] + if bias is not None: + bias = bias[start:end] + return F.linear(input, weight, bias) + + +@weak_module class PReLU(Module): r"""Applies the element-wise function: -- 2.7.4