Scaled Dot-Product Attention is the building block of the multi-head attention used in the Transformer encoder.
Because Scaled Dot-Product Attention sits inside multi-head attention, the inputs q, k, v are usually reshaped to the following shape:
(batch, n_head, seqLen, dim), where n_head is the number of attention heads and n_head * dim = embedSize.
From input to output, the tensor shape stays the same; the sketch below shows how the heads are split.
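As an illustration only, here is a minimal sketch of splitting an input of shape (batch, seqLen, embedSize) into heads. The sizes and the projection w_q are assumptions, not taken from the original post:

import torch
import torch.nn as nn

batch, seq_len, embed_size, n_head = 2, 5, 512, 8
dim = embed_size // n_head  # n_head * dim == embed_size

x = torch.randn(batch, seq_len, embed_size)
w_q = nn.Linear(embed_size, embed_size)  # hypothetical Q projection

# project, split the last dimension into (n_head, dim), then move n_head forward
q = w_q(x).view(batch, seq_len, n_head, dim).transpose(1, 2)
print(q.shape)  # torch.Size([2, 8, 5, 64]) -> (batch, n_head, seqLen, dim)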
temperature denotes the scaling factor, i.e. dim ** 0.5.
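For reference, this is the formula from the original Transformer paper ("Attention Is All You Need"), where d_k is the per-head dimension (dim above):

\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V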
mask marks, for each sample in a batch, which sequence positions are padding (mask is False at pad positions). The mask therefore starts with shape (batchSize, seqLen) and, for the computation, is expanded to (batchSize, 1, 1, seqLen).
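A minimal sketch of this expansion; pad_id and the example token tensor are assumptions for illustration:

import torch

pad_id = 0  # hypothetical padding token id
tokens = torch.tensor([[5, 7, 2, 0, 0],
                       [3, 9, 0, 0, 0]])  # (batchSize, seqLen)
mask = (tokens != pad_id)                 # (batchSize, seqLen), False at pad positions
mask = mask.unsqueeze(1).unsqueeze(2)     # (batchSize, 1, 1, seqLen), broadcasts over heads and query positions
print(mask.shape)  # torch.Size([2, 1, 1, 5])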
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class ScaledDotProductAttention(nn.Module):
    """Compute 'Scaled Dot-Product Attention'."""

    def forward(self, query, key, value, mask=None, dropout=None):
        # query, key, value: (batch, n_head, seq_len, dim)
        scores = torch.matmul(query, key.transpose(-2, -1)) / np.sqrt(query.size(-1))
        # scores: (batch, n_head, seq_len_q, seq_len_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)  # block attention to pad positions
        p_attn = F.softmax(scores, dim=-1)
        if dropout is not None:
            p_attn = dropout(p_attn)
        # output: (batch, n_head, seq_len_q, dim)
        return torch.matmul(p_attn, value), p_attn
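A hedged usage sketch of the class above; the tensor sizes and the dropout module are assumptions:

import torch
import torch.nn as nn

batch, n_head, seq_len, dim = 2, 8, 5, 64
query = torch.randn(batch, n_head, seq_len, dim)
key = torch.randn(batch, n_head, seq_len, dim)
value = torch.randn(batch, n_head, seq_len, dim)
mask = torch.ones(batch, 1, 1, seq_len, dtype=torch.bool)  # (batchSize, 1, 1, seqLen)

attn = ScaledDotProductAttention()
out, p_attn = attn(query, key, value, mask=mask, dropout=nn.Dropout(0.1))
print(out.shape)     # torch.Size([2, 8, 5, 64]) -> same shape as the input
print(p_attn.shape)  # torch.Size([2, 8, 5, 5])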
Note:
When Q, K, and V all come from matrix transformations of the same vector, this is called self-attention;
When Q and K, V come from matrix transformations of different vectors, it is called soft-attention (this case is also commonly known as cross-attention).
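To make the distinction concrete, here is a minimal sketch reusing the class above; x, y, the sizes, and the projection layers are assumptions for illustration. Only the source of Q versus K, V changes:

import torch
import torch.nn as nn

batch, n_head, seq_len_x, seq_len_y, dim = 2, 8, 5, 7, 64
embed_size = n_head * dim

# hypothetical projections; in a real model these live inside the multi-head attention module
w_q = nn.Linear(embed_size, embed_size)
w_k = nn.Linear(embed_size, embed_size)
w_v = nn.Linear(embed_size, embed_size)

def split_heads(t):
    # (batch, seq_len, embed_size) -> (batch, n_head, seq_len, dim)
    return t.view(t.size(0), t.size(1), n_head, dim).transpose(1, 2)

x = torch.randn(batch, seq_len_x, embed_size)
y = torch.randn(batch, seq_len_y, embed_size)

attn = ScaledDotProductAttention()

# self-attention: Q, K, V all derived from x
out_self, _ = attn(split_heads(w_q(x)), split_heads(w_k(x)), split_heads(w_v(x)))

# cross-attention ("soft-attention" in the text above): Q from x, K and V from y
out_cross, _ = attn(split_heads(w_q(x)), split_heads(w_k(y)), split_heads(w_v(y)))
print(out_self.shape, out_cross.shape)  # (2, 8, 5, 64) (2, 8, 5, 64)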